[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)

* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow

* [Docs] Fix thinking_budget markdown formatting

* [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
jackyYang6
2026-03-23 14:15:55 +08:00
committed by GitHub
parent 7a78001be2
commit 634d23a38a
10 changed files with 663 additions and 285 deletions
+28 -12
View File
@@ -2,7 +2,9 @@
## Overview
`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the `<think> ... </think>` segment. When the budget is reached, it forces a line break token and then the `</think>` token to terminate the thinking section.
`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the `<think> ... </think>`
segment. When the budget is reached, it terminates thinking by forcing `</think>`. If
`think_stop_sentence` is configured, it forces the custom sentence first and then `</think>`.
## When to Use
@@ -11,19 +13,22 @@
## How It Works
1. **CPU precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
2. **Per-step update**: during decoding, the processor tracks `last_token_id` and `tokens_after_start`.
3. **Budget enforcement**: once the budget is reached, it forces a line break and then the thinking end token.
3. **Budget enforcement**: once the budget is reached, it forces `</think>` directly. If `think_stop_sentence`
is configured, it forces that sentence first and then `</think>`.
## Requirements
- The model must provide valid token ids for `think_start_id`, `think_end_id`, and `line_break_id` (via `ModelConfig`).
- If any of these ids are invalid, the processor is disabled and `thinking_budget` will not take effect.
- The model must provide valid token ids for `think_start_id` and `think_end_id` (via `ModelConfig`).
- If either of these ids is invalid, the processor is disabled and `thinking_budget` will not take effect.
## Request Parameters
- `thinking_budget` (int, required to enable): maximum number of tokens after `<think>` before forced termination.
- `think_stop_sentence` (string, optional): a stop sentence that will be tokenized on the CPU side and enforced near the budget boundary.
- `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `<think>` before forced
termination.
- `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side
and enforced near the budget boundary.
## Operator-Level vs LogitsProcessor
@@ -41,16 +46,25 @@ FastDeploy has two ways to limit thinking length:
In short:
- If you only need a hard cap on thinking length, prefer `reasoning_max_tokens`.
- If you need custom behavior (for example, injecting custom sentence tokens), use `ThinkingBudgetLogitsProcessor`.
- If you need custom behavior (for example, inserting a custom sentence before `</think>`), use
`ThinkingBudgetLogitsProcessor`.
## Practical guidance
`reasoning_max_tokens` and `thinking_budget` are not mutually exclusive in current implementation.
If both are configured for the same request, both constraints can take effect, and whichever triggers first will end the thinking phase.
- To use **operator-level-only** behavior: this is request-level config only. Set `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`, and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave `reasoning_max_tokens` unset.
- Avoid enabling both for strict custom sentence insertion requirements, because operator-level termination may cut the custom sentence path earlier.
- To use **operator-level-only** behavior: this is request-level config only. Set
`enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires
service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`,
and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave
`reasoning_max_tokens` unset.
- `thinking_budget` itself does not require `enable_thinking=true`.
- If an ERNIE chat template already appends `<think>` in the prompt, `thinking_budget` should still take effect; it
does not require the model to emit another `<think>` during decoding.
- Avoid enabling both for strict custom sentence insertion requirements, because operator-level
termination may cut the custom sentence path earlier.
## Online Usage
@@ -120,4 +134,6 @@ print(outputs[0].outputs.text)
## Performance Note
This processor runs `update_state` and `apply` on every decode step. If you only need a hard thinking-length cap and care most about throughput, consider the operator-level reasoning-length controls instead of per-step logits processing.
This processor runs `update_state` and `apply` on every decode step. If you only need a hard
thinking-length cap and care most about throughput, consider the operator-level reasoning-length
controls instead of per-step logits processing.
+20 -11
View File
@@ -2,7 +2,9 @@
## 概述
`ThinkingBudgetLogitsProcessor` 用于限制 `<think> ... </think>` 区间的生成长度。当预算达到阈值时,会强制生成换行符 token,再强制生成 `</think>`,从而结束思考段。
`ThinkingBudgetLogitsProcessor` 用于限制 `<think> ... </think>` 区间的生成长度。当预算达到阈值时,
会直接强制生成 `</think>` 来结束思考段;如果配置了 `think_stop_sentence`,则会先强制输出该自定义
文案,再输出 `</think>`
## 适用场景
@@ -11,19 +13,20 @@
## 工作原理
1. **CPU 侧预计算(DataProcessor**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。
1. **请求侧预计算(DataProcessor**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。
2. **每步更新**:解码过程中跟踪 `last_token_id``tokens_after_start`
3. **预算约束**:达到预算后,依次强制换行符与思考结束 token
3. **预算约束**:达到预算后,默认直接强制 `</think>`;如果配置了 `think_stop_sentence`,则先逐 token
强制输出该文案,再输出 `</think>`
## 前置要求
- 模型需提供有效的 `think_start_id``think_end_id``line_break_id`(来自 `ModelConfig`)。
- 若任意 id 无效,处理器会禁用,`thinking_budget` 不生效。
- 模型需提供有效的 `think_start_id``think_end_id`(来自 `ModelConfig`)。
-其中任意 id 无效,处理器会禁用,`thinking_budget` 不生效。
## 请求参数
- `thinking_budget`int,启用所需):`<think>` 之后允许的最大 token 数。
- `think_stop_sentence`string,可选):CPU 侧会将该字符串编码为 token ids,并在预算边界附近强制输出。
- `thinking_budget`int,启用所需):`<think>` 之后允许的最大 decode 阶段 token 数。
- `think_stop_sentence`string,可选):按字面串编码的自定义终止文案,并在预算边界附近强制输出。
## 算子级限制 vs LogitsProcessor
@@ -35,21 +38,27 @@ FastDeploy 当前有两种思考长度控制方式:
- 适合“只限制思考长度”的简单场景。
- **`ThinkingBudgetLogitsProcessor`**`logits_processors_args.thinking_budget`):
- 由每步 Python 侧 logits 处理实现。
- 支持更灵活的行为,例如 `think_stop_sentence`(在结束前插入自定义话术)
- 支持更灵活的行为,例如 `think_stop_sentence`
- 相比算子级限制,在高并发下通常有更高开销。
可按以下原则选择:
- 仅需限制思考长度:优先用 `reasoning_max_tokens`
- 需要更灵活控制(如插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`
- 需要更灵活控制(如`</think>`插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`
## 建议实践
当前实现中,`reasoning_max_tokens``thinking_budget` 不是互斥关系。
同一请求如果同时配置,两套约束都可能生效,谁先触发就先结束思考段。
- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`不要传 `thinking_budget`
- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过 `logits_processors_args``thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置 `reasoning_max_tokens`
- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`
不要传 `thinking_budget`
- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。
服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过
`logits_processors_args``thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置
`reasoning_max_tokens`
- `thinking_budget` 本身不依赖 `enable_thinking=true`
- 如果 ERNIE 的 chat template 已经在 prompt 里拼入 `<think>``thinking_budget` 也应正常生效,不要求模型在 decode 阶段再次输出 `<think>`
- 如果业务要求“必须完整插入自定义话术”,不建议与算子级限制同时开启,否则可能被算子级提前截断。
## 在线使用
+9
View File
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt"):
@@ -143,6 +148,10 @@ class Ernie4_5Processor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request["prompt_token_ids"], request.get("logits_processors_args") or {}
)
request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request["prompt_token_ids"])
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_tokens)
@@ -210,6 +210,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
if request.get("prompt_token_ids"):
messages = request.get("messages")
if messages:
@@ -257,6 +262,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
# 截断超过长度限制的prompt
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request["prompt_token_ids"], request.get("logits_processors_args") or {}
)
request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request["prompt_token_ids"])
if request.get("max_tokens") is None:
+97 -67
View File
@@ -92,6 +92,9 @@ class BaseDataProcessor(ABC):
"""
Encode text into token ids with a small LRU cache.
"""
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
@@ -107,6 +110,25 @@ class BaseDataProcessor(ABC):
self._tokenize_cache.popitem(last=False)
return token_ids
def _encode_literal_text_with_cache(self, text):
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = ("literal_text", text)
cached = self._tokenize_cache.get(key)
if cached is not None:
self._tokenize_cache.move_to_end(key)
return cached
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
elif not isinstance(token_ids, list):
token_ids = list(token_ids)
self._tokenize_cache[key] = token_ids
if len(self._tokenize_cache) > self._tokenize_cache_capacity:
self._tokenize_cache.popitem(last=False)
return token_ids
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -119,6 +141,77 @@ class BaseDataProcessor(ABC):
"""
raise NotImplementedError
def _get_think_token_ids(self):
think_token_ids = getattr(self, "_think_token_ids", None)
if think_token_ids is not None:
return think_token_ids
tokenizer = getattr(self, "tokenizer", None)
vocab = tokenizer.get_vocab() if tokenizer is not None else {}
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
return logits_processors_args
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = False
ended = False
tokens_after_start = 0
last_token_id = None
in_thinking = False
for token_id in token_list:
if token_id == think_start_id:
started = True
ended = False
in_thinking = True
elif token_id == think_end_id and in_thinking:
ended = True
in_thinking = False
if started and token_list:
# Align with operator-level reasoning_max_tokens: prompt-side tokens
# inside <think> do not consume thinking budget.
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def ids2tokens(self, token_id, task_id=None):
"""
token ids to strings
@@ -207,65 +300,6 @@ class DataProcessor(BaseDataProcessor):
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
self._think_token_ids = None
def _get_think_token_ids(self):
if self._think_token_ids is not None:
return self._think_token_ids
vocab = self.tokenizer.get_vocab()
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = think_start_id in token_list
ended = False
tokens_after_start = 0
last_token_id = None
if started:
start_pos = token_list.index(think_start_id)
tokens_after = token_list[start_pos + 1 :]
if think_end_id in tokens_after:
end_pos = tokens_after.index(think_end_id)
tokens_after_start = end_pos + 1
ended = True
else:
tokens_after_start = len(tokens_after)
if token_list:
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def process_request_dict(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -292,14 +326,10 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = request.get("logits_processors_args") or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request["logits_processors_args"] = logits_processors_args
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
+20 -1
View File
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
f"request_ids: {request.request_id}, prompt: {prompt}, "
f"tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
task = request.to_dict()
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, request.get("logits_processors_args") or {}
)
request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if request.get("max_tokens") is None:
request.set("max_tokens", max(1, max_tokens))
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids:
if request.prompt:
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
)
request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
request.sampling_params.max_tokens = max(1, max_tokens)
@@ -216,6 +216,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
if request.prompt_token_ids:
messages = request.messages
if messages:
@@ -267,6 +272,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
# 截断超过长度限制的prompt
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
)
request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
+101 -75
View File
@@ -115,6 +115,9 @@ class BaseDataProcessor(ABC):
"""
Encode text into token ids with a small LRU cache.
"""
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
@@ -130,6 +133,25 @@ class BaseDataProcessor(ABC):
self._tokenize_cache.popitem(last=False)
return token_ids
def _encode_literal_text_with_cache(self, text):
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = ("literal_text", text)
cached = self._tokenize_cache.get(key)
if cached is not None:
self._tokenize_cache.move_to_end(key)
return cached
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
elif not isinstance(token_ids, list):
token_ids = list(token_ids)
self._tokenize_cache[key] = token_ids
if len(self._tokenize_cache) > self._tokenize_cache_capacity:
self._tokenize_cache.popitem(last=False)
return token_ids
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -142,6 +164,77 @@ class BaseDataProcessor(ABC):
"""
raise NotImplementedError
def _get_think_token_ids(self):
think_token_ids = getattr(self, "_think_token_ids", None)
if think_token_ids is not None:
return think_token_ids
tokenizer = getattr(self, "tokenizer", None)
vocab = tokenizer.get_vocab() if tokenizer is not None else {}
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
return logits_processors_args
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = False
ended = False
tokens_after_start = 0
last_token_id = None
in_thinking = False
for token_id in token_list:
if token_id == think_start_id:
started = True
ended = False
in_thinking = True
elif token_id == think_end_id and in_thinking:
ended = True
in_thinking = False
if started and token_list:
# Align with operator-level reasoning_max_tokens: prompt-side tokens
# inside <think> do not consume thinking budget.
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def ids2tokens(self, token_id, task_id=None):
"""
token ids to strings
@@ -221,65 +314,6 @@ class DataProcessor(BaseDataProcessor):
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
self._think_token_ids = None
def _get_think_token_ids(self):
if self._think_token_ids is not None:
return self._think_token_ids
vocab = self.tokenizer.get_vocab()
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = think_start_id in token_list
ended = False
tokens_after_start = 0
last_token_id = None
if started:
start_pos = token_list.index(think_start_id)
tokens_after = token_list[start_pos + 1 :]
if think_end_id in tokens_after:
end_pos = tokens_after.index(think_end_id)
tokens_after_start = end_pos + 1
ended = True
else:
tokens_after_start = len(tokens_after)
if token_list:
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def process_request(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -306,14 +340,10 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = request.get("logits_processors_args") or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request["logits_processors_args"] = logits_processors_args
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
@@ -411,14 +441,10 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request.sampling_params.logits_processors_args = logits_processors_args
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids:
@@ -42,14 +42,15 @@ class _ThinkingState:
class ThinkingBudgetLogitsProcessor(LogitsProcessor):
"""Limit the number of tokens generated in the thinking phase.
The processor tracks per-request thinking state and forces a newline token
when the budget is reached, followed by the thinking end token on the next step.
The processor tracks per-request thinking state and forces the thinking end
token when the budget is reached. If a stop sentence is configured, the
processor emits the stop sentence first and then the thinking end token.
Request-specific configuration is provided via logits_processors_args:
{"thinking_budget": <int>}
Requires model_config to provide think_start_id, think_end_id, and line_break_id.
If any of these are missing or invalid (-1), the processor will be disabled.
Requires model_config to provide think_start_id and think_end_id. If any of
these are missing or invalid (-1), the processor will be disabled.
"""
def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
self._enabled = (
self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
)
self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
if not self._enabled:
logger.warning(
"ThinkingBudgetLogitsProcessor disabled: missing token ids "
f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
"Ensure model vocab contains <think>, </think> tokens and line_break_id is configured."
f"(think_start={think_start_id}, think_end={think_end_id}). "
"Ensure model vocab contains <think> and </think> tokens."
)
self._states: Dict[str, _ThinkingState] = {}
self._active_req_ids: list[str] = []
self._active_budgets: list[int] = []
self._active_slots: list[int] = []
def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
started = False
ended = False
in_thinking = False
for token_id in prompt_slice:
if token_id == self.think_start_token_id:
started = True
ended = False
in_thinking = True
elif token_id == self.think_end_token_id and in_thinking:
ended = True
in_thinking = False
last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
return started, ended, 0, last_token_id
def update_state(self, share_inputs: dict) -> None:
if not self._enabled:
return
@@ -82,6 +96,7 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
req_ids = share_inputs["req_ids"]
logits_processors_args = share_inputs["logits_processors_args"]
prompt_ids = share_inputs.get("prompt_ids")
token_ids_all = share_inputs.get("token_ids_all")
prompt_lens = share_inputs.get("prompt_lens")
pre_ids = share_inputs.get("pre_ids")
next_tokens = share_inputs.get("next_tokens")
@@ -153,6 +168,8 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
for idx, slot_id in enumerate(candidate_slots):
next_token_by_slot[slot_id] = int(next_sel[idx])
prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
for idx, slot_id in enumerate(candidate_slots):
req_id = candidate_req_ids[idx]
logit_proc_args = candidate_args[idx]
@@ -190,25 +207,28 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
state.current_step_idx = current_step_idx
if not state.started and not state.prompt_checked:
if prompt_ids is not None and prompt_lens is not None:
if prompt_source is not None and prompt_lens is not None:
if prompt_lens_np is not None:
prompt_len = int(prompt_lens_np[slot_id])
else:
prompt_len = int(prompt_lens[slot_id])
prompt_slice = prompt_ids[slot_id, :prompt_len]
prompt_slice = prompt_slice.numpy().tolist()
if self.think_start_token_id in prompt_slice:
prompt_slice = prompt_source[slot_id, :prompt_len]
if hasattr(prompt_slice, "numpy"):
prompt_slice = prompt_slice.numpy().tolist()
elif hasattr(prompt_slice, "tolist"):
prompt_slice = prompt_slice.tolist()
else:
prompt_slice = list(prompt_slice)
if prompt_ids is None:
prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
self._scan_prompt_state(prompt_slice)
)
if prompt_started:
state.started = True
start_pos = prompt_slice.index(self.think_start_token_id)
tokens_after = prompt_slice[start_pos + 1 :]
if self.think_end_token_id in tokens_after:
end_pos = tokens_after.index(self.think_end_token_id)
state.tokens_after_start = end_pos + 1
state.ended = True
else:
state.tokens_after_start = len(tokens_after)
if prompt_slice:
state.last_token_id = int(prompt_slice[-1])
state.ended = prompt_ended
state.tokens_after_start = prompt_tokens_after_start
state.last_token_id = prompt_last_token_id
if current_step_idx is not None and state.last_step_idx is None:
state.last_step_idx = current_step_idx
state.prompt_checked = True
@@ -304,13 +324,6 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
if state.tokens_after_start < budget:
continue
if state.last_token_id != self.line_break_token_id:
logits[slot_id, :] = -float("inf")
logits[slot_id, self.line_break_token_id] = 0.0
state.last_token_id = self.line_break_token_id
state.last_step_idx = state.current_step_idx
continue
logits[slot_id, :] = -float("inf")
logits[slot_id, self.think_end_token_id] = 0.0
state.last_token_id = self.think_end_token_id
+327 -89
View File
@@ -20,7 +20,19 @@ from fastdeploy.engine import common_engine as common_engine_module
from fastdeploy.engine import engine as engine_module
from fastdeploy.engine.args_utils import EngineArgs # Import EngineArgs
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.input.ernie4_5_processor import (
Ernie4_5Processor as ErnieTextDataProcessor,
)
from fastdeploy.input.ernie4_5_vl_processor import (
Ernie4_5_VLProcessor as ErnieVLDataProcessor,
)
from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
from fastdeploy.input.v1.ernie4_5_processor import (
Ernie4_5Processor as V1ErnieTextDataProcessor,
)
from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
)
from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
from fastdeploy.scheduler import SchedulerConfig
@@ -52,6 +64,7 @@ class MockModelRunner:
"req_ids": [f"req_{i}" for i in range(max_num_seqs)],
"logits_processors_args": [{} for _ in range(max_num_seqs)],
"prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10
"token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
"pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
@@ -73,6 +86,12 @@ class MockModelRunner:
self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
)
self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
if req.sampling_params.logits_processors_args:
self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
@@ -152,7 +171,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
return logits
def test_thinking_budget_not_reached(self):
# Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase
# Scenario: Thinking budget is 5, and prompt-side tokens after <think> do not consume budget.
req_id = "test_req_1"
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
@@ -167,9 +186,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertFalse(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # (3, 4, 5) after THINKING_START
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
# Step 2: Simulate one generation step (budget 5, generated 3 -> 4)
# Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -181,9 +200,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # (3, 4, 5, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
# Step 3: Simulate another generation step (budget 5, generated 4 -> 5)
# Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -191,14 +210,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
self.assertEqual(processor._states[req_id].tokens_after_start, 5) # (3, 4, 5, 0, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
# LogitsProcessor should still not restrict as NEW_LINE_TOKEN is not yet last token
def test_thinking_budget_reached_forces_newline(self):
# Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end
def test_thinking_budget_reached_forces_think_end(self):
# Scenario: Budget is 3 and only decode-time tokens count toward the budget.
req_id = "test_req_2"
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] # Initial 1 token after start
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -206,16 +225,26 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
line_break_id = processor.line_break_token_id
think_end_id = processor.think_end_token_id
# Step 1: Initial state update (1 token after start)
# Step 1: Initial state update (prompt-side tokens do not count)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
self.assertFalse(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].last_token_id, 3)
# Step 2: Generate 2nd token (budget 3, generated 1 -> 2)
# Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, 0) # Normal generation
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
self.assertEqual(processor._states[req_id].last_token_id, 0)
# Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -226,42 +255,17 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
self.assertEqual(processor._states[req_id].last_token_id, 0)
# Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE.
# Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, 0) # Normal generation
self.assertEqual(next_token, 0)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # Budget is now met
self.assertEqual(processor._states[req_id].last_token_id, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 3)
# Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
# Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0
# Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues
other_logits = paddle.concat(
[
processed_logits[0, :line_break_id],
processed_logits[0, line_break_id + 1 : VOCAB_SIZE],
],
axis=0,
)
self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item())
self.assertEqual(processed_logits[0, line_break_id].item(), 0.0)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, line_break_id) # Forces NEW_LINE
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # Still increment
self.assertEqual(processor._states[req_id].last_token_id, line_break_id)
# Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
# Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -324,9 +328,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(next_token, stop_sentence_token_ids[0])
def test_thinking_budget_no_stop_sentence_defaults(self):
# Scenario: No stop sentence, budget reached should force newline.
# Scenario: No stop sentence, budget reached should force thinking_end directly.
req_id = "test_req_no_stop_sentence"
prompt_ids = [THINKING_START_TOKEN_ID, 42] # 1 token after start
prompt_ids = [THINKING_START_TOKEN_ID, 42]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -334,35 +338,42 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
line_break_id = processor.line_break_token_id
processor.update_state(mock_runner.share_inputs)
logits = self._get_initial_logits(1)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, line_break_id)
self.assertEqual(next_token, 0)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, THINKING_END_TOKEN_ID)
def test_thinking_budget_uses_config_token_ids(self):
# Scenario: Processor should use token ids from model config.
self.fd_config.model_config.think_start_id = 123
self.fd_config.model_config.think_end_id = 124
self.fd_config.model_config.line_break_id = 125
self.fd_config.model_config.line_break_id = -1
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertEqual(processor.think_start_token_id, 123)
self.assertEqual(processor.think_end_token_id, 124)
self.assertEqual(processor.line_break_token_id, 125)
self.assertEqual(processor.line_break_token_id, -1)
self.assertTrue(processor._enabled)
def test_thinking_budget_disabled_when_token_ids_missing(self):
# Scenario: Processor should be disabled when token ids are not configured.
self.fd_config.model_config.think_start_id = -1
self.fd_config.model_config.think_end_id = -1
self.fd_config.model_config.line_break_id = -1
self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertFalse(processor._enabled)
self.assertEqual(processor.think_start_token_id, -1)
self.assertEqual(processor.think_end_token_id, -1)
self.assertEqual(processor.line_break_token_id, -1)
self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)
# update_state and apply should be no-op when disabled
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
@@ -493,6 +504,28 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(state.tokens_after_start, 2)
self.assertEqual(state.last_token_id, 99)
def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
req_id = "req_gpu_fallback"
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
mock_runner.share_inputs["req_ids"][0] = req_id
mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
mock_runner.share_inputs["prompt_ids"] = None
mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
[1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
)
mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
processor.update_state(mock_runner.share_inputs)
state = processor._states[req_id]
self.assertTrue(state.prompt_checked)
self.assertTrue(state.started)
self.assertFalse(state.ended)
self.assertEqual(state.tokens_after_start, 0)
self.assertEqual(state.last_token_id, 3)
def test_thinking_budget_not_configured(self):
# Scenario: Processor is active, but request does not provide thinking_budget
req_id = "test_req_3"
@@ -532,20 +565,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(processor._states[req_id].tokens_after_start, 0) # No tokens after start yet
self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)
# Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1) # Still increments
self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
# Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
# Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -572,7 +592,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertTrue(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].tokens_after_start, 2) # Tokens 2, THINKING_END after start
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
@@ -582,7 +602,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
def test_multiple_requests(self):
# Scenario: Multiple requests with different thinking states
req_id_1 = "req_a"
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] # budget 2, last_token=11, tokens_after_start=2
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
@@ -592,7 +612,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)
req_id_3 = "req_c"
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] # budget 1, last_token=30, tokens_after_start=1
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
@@ -609,14 +629,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
# Verify initial states
self.assertTrue(processor._states[req_id_1].started)
self.assertFalse(processor._states[req_id_1].ended)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_1].last_token_id, 11)
self.assertNotIn(req_id_2, processor._states) # No budget specified for req_2
self.assertTrue(processor._states[req_id_3].started)
self.assertFalse(processor._states[req_id_3].ended)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_3].last_token_id, 30)
# Simulate logits for the batch
@@ -624,16 +644,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs) # Ensure state is updated before apply
processed_batch_logits = processor.apply(batch_logits)
# Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
# Req 1: prompt-side content does not consume budget, so first step is normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: No thinking budget, normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
# Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
# Req 3: prompt-side content does not consume budget, so first step is normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)
# Simulate generating next tokens and updating state
next_tokens = mock_runner.generate_next_token(processed_batch_logits)
@@ -643,20 +661,25 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
# Verify updated states for next step
self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
self.assertEqual(processor._states[req_id_1].last_token_id, 0)
self.assertEqual(processor._states[req_id_3].last_token_id, 0)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
self.assertFalse(processor._states[req_id_1].ended)
self.assertFalse(processor._states[req_id_3].ended)
batch_logits = self._get_initial_logits(3)
processor.update_state(mock_runner.share_inputs)
processed_batch_logits = processor.apply(batch_logits)
# Req 1: last token was NEW_LINE. Should force THINKING_END
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
# Req 1: budget 2, tokens_after_start 1. Still normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: Still normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
# Req 3: last token was NEW_LINE. Should force THINKING_END
# Req 3: budget 1 reached after one generated token, should now force THINKING_END.
self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
@@ -724,7 +747,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertTrue(updated["think_prompt_checked"])
self.assertTrue(updated["think_prompt_started"])
self.assertTrue(updated["think_prompt_ended"])
self.assertEqual(updated["think_prompt_tokens_after_start"], 2)
self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
self.assertEqual(updated["think_prompt_last_token_id"], 3)
def test_v1_process_request_missing_logits_processors_args(self):
@@ -833,6 +856,58 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
self.assertNotIn(("np", False), processor._tokenize_cache)
def test_text_encode_with_cache_lazy_init(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([51, 52], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertEqual(call_counter["count"], 1)
def test_v1_encode_with_cache_lazy_init(self):
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([61, 62], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertEqual(call_counter["count"], 1)
def test_ernie_encode_literal_text_with_cache(self):
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
processor.tokenizer = SimpleNamespace(
tokenize=lambda text: ["token_a", "token_b"],
convert_tokens_to_ids=lambda tokens: [71, 72],
)
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
def test_v1_ernie_encode_literal_text_with_cache(self):
processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
processor.tokenizer = SimpleNamespace(
tokenize=lambda text: ["token_c", "token_d"],
convert_tokens_to_ids=lambda tokens: [81, 82],
)
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
def test_text_update_thinking_prompt_state_branches(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
processor._think_token_ids = None
@@ -868,7 +943,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# 命中 _get_think_token_ids 的缓存分支
@@ -891,7 +966,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# 命中 _get_think_token_ids 的缓存分支
@@ -903,7 +978,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
processor._encode_literal_text_with_cache = lambda text: [201, 202]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -924,7 +999,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
[23, 201, 202],
[201, 202],
)
self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
@@ -934,7 +1009,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
processor._encode_literal_text_with_cache = lambda text: [301, 302]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -955,7 +1030,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request(request, max_model_len=16)
self.assertEqual(
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
[23, 301, 302],
[301, 302],
)
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
@@ -965,7 +1040,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
processor._encode_literal_text_with_cache = lambda text: [401, 402]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -992,10 +1067,173 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
[23, 401, 402],
[401, 402],
)
self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [501, 502]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
request = {
"request_id": "req_ernie_text",
"eos_token_ids": [1],
"prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
"prompt": None,
"messages": None,
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
"bad_words": None,
"bad_words_token_ids": None,
"max_tokens": 1,
"temperature": 1.0,
"top_p": 0.9,
"response_max_tokens": None,
"enable_thinking": True,
}
with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [601, 602]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
request = DummyRequestV1(
request_id="req_v1_ernie_text",
eos_token_ids=[1],
prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
prompt=None,
messages=None,
chat_template_kwargs=None,
enable_thinking=True,
sampling_params=SimpleNamespace(
bad_words=None,
bad_words_token_ids=None,
max_tokens=1,
temperature=1.0,
top_p=0.9,
repetition_penalty=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
response_max_tokens=None,
n=1,
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
),
)
with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602])
self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [701, 702]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
processor._check_mm_limits = lambda *args, **kwargs: None
processor.append_completion_tokens = lambda *args, **kwargs: None
processor.pack_outputs = lambda outs: outs
processor.ernie4_5_processor = SimpleNamespace(
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
)
request = {
"request_id": "req_ernie_vl",
"eos_token_ids": [1],
"messages": [{"role": "user", "content": "hi"}],
"bad_words": None,
"bad_words_token_ids": None,
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
"max_tokens": 1,
"top_p": 0.9,
"response_max_tokens": None,
}
with patch(
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
lambda *args, **kwargs: None,
):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702])
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [801, 802]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
processor._check_mm_limits = lambda *args, **kwargs: None
processor.append_completion_tokens = lambda *args, **kwargs: None
processor.pack_outputs = lambda outs: outs
processor.ernie4_5_processor = SimpleNamespace(
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
)
request = DummyRequestV1(
request_id="req_v1_ernie_vl",
eos_token_ids=[1],
prompt_token_ids=None,
prompt=None,
messages=[{"role": "user", "content": "hi"}],
chat_template_kwargs=None,
enable_thinking=True,
completion_token_ids=None,
multimodal_data=None,
sampling_params=SimpleNamespace(
bad_words=None,
bad_words_token_ids=None,
max_tokens=1,
temperature=1.0,
top_p=0.9,
repetition_penalty=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
response_max_tokens=None,
reasoning_max_tokens=None,
n=1,
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
),
)
with patch(
"fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
lambda *args, **kwargs: None,
):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802])
self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
if __name__ == "__main__":
unittest.main()