mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)
* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow
* [Docs] Fix thinking_budget markdown formatting
* [Test] Align ernie thinking budget test with process_request_dict
@@ -2,7 +2,9 @@
 
 ## Overview
 
-`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the `<think> ... </think>` segment. When the budget is reached, it forces a line break token and then the `</think>` token to terminate the thinking section.
+`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the `<think> ... </think>`
+segment. When the budget is reached, it terminates thinking by forcing `</think>`. If
+`think_stop_sentence` is configured, it forces the custom sentence first and then `</think>`.
 
 ## When to Use
 
@@ -11,19 +13,22 @@
 
 ## How It Works
 
-1. **CPU precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
+1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
 2. **Per-step update**: during decoding, the processor tracks `last_token_id` and `tokens_after_start`.
-3. **Budget enforcement**: once the budget is reached, it forces a line break and then the thinking end token.
+3. **Budget enforcement**: once the budget is reached, it forces `</think>` directly. If `think_stop_sentence`
+   is configured, it forces that sentence first and then `</think>`.
 
 ## Requirements
 
-- The model must provide valid token ids for `think_start_id`, `think_end_id`, and `line_break_id` (via `ModelConfig`).
-- If any of these ids are invalid, the processor is disabled and `thinking_budget` will not take effect.
+- The model must provide valid token ids for `think_start_id` and `think_end_id` (via `ModelConfig`).
+- If either of these ids is invalid, the processor is disabled and `thinking_budget` will not take effect.
 
 ## Request Parameters
 
-- `thinking_budget` (int, required to enable): maximum number of tokens after `<think>` before forced termination.
-- `think_stop_sentence` (string, optional): a stop sentence that will be tokenized on the CPU side and enforced near the budget boundary.
+- `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `<think>` before forced
+  termination.
+- `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side
+  and enforced near the budget boundary.
 
 ## Operator-Level vs LogitsProcessor
 
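The budget-enforcement step described in the doc change above can be reduced to a toy sketch. This is illustrative only, not FastDeploy's actual processor: the function signature, the token ids (`THINK_END_ID`, `STOP_SENTENCE_IDS`), and the plain-list logits are all made-up placeholders.

```python
import math

THINK_END_ID = 2               # hypothetical id for </think>
STOP_SENTENCE_IDS = [7, 8, 9]  # hypothetical ids for a think_stop_sentence

def apply_thinking_budget(logits, tokens_after_start, thinking_budget, forced_queue):
    """Force-feed tokens once the thinking budget is exhausted.

    logits: list of floats (one per vocab id); mutated into a one-hot mask
    when a token is forced.
    forced_queue: remaining tokens that must still be emitted
    (stop sentence first, then </think>).
    Returns the forced token id, or None if no constraint applies yet.
    """
    if not forced_queue and tokens_after_start < thinking_budget:
        return None  # budget not reached: leave logits untouched
    if not forced_queue:
        # budget just reached: queue the custom sentence (if any), then </think>
        forced_queue.extend(STOP_SENTENCE_IDS + [THINK_END_ID])
    forced = forced_queue.pop(0)
    for i in range(len(logits)):
        logits[i] = -math.inf
    logits[forced] = 0.0
    return forced
```

Calling this once per decode step drains the queue token by token, which is why the sentence cannot be truncated mid-way unless some other mechanism (such as the operator-level cap) terminates generation first.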
@@ -41,16 +46,25 @@ FastDeploy has two ways to limit thinking length:
 In short:
 
 - If you only need a hard cap on thinking length, prefer `reasoning_max_tokens`.
-- If you need custom behavior (for example, injecting custom sentence tokens), use `ThinkingBudgetLogitsProcessor`.
+- If you need custom behavior (for example, inserting a custom sentence before `</think>`), use
+  `ThinkingBudgetLogitsProcessor`.
 
 ## Practical guidance
 
 `reasoning_max_tokens` and `thinking_budget` are not mutually exclusive in current implementation.
 If both are configured for the same request, both constraints can take effect, and whichever triggers first will end the thinking phase.
 
-- To use **operator-level-only** behavior: this is request-level config only. Set `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
-- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`, and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave `reasoning_max_tokens` unset.
-- Avoid enabling both for strict custom sentence insertion requirements, because operator-level termination may cut the custom sentence path earlier.
+- To use **operator-level-only** behavior: this is request-level config only. Set
+  `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
+- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires
+  service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`,
+  and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave
+  `reasoning_max_tokens` unset.
+- `thinking_budget` itself does not require `enable_thinking=true`.
+- If an ERNIE chat template already appends `<think>` in the prompt, `thinking_budget` should still take effect; it
+  does not require the model to emit another `<think>` during decoding.
+- Avoid enabling both for strict custom sentence insertion requirements, because operator-level
+  termination may cut the custom sentence path earlier.
 
 ## Online Usage
 
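The logits-processor-only configuration described above can be expressed as a request payload sketch. The field names follow the doc's description (`logits_processors_args`, `thinking_budget`, `think_stop_sentence`); the surrounding message structure is a generic chat-style example, not a verified FastDeploy request schema.

```python
# Hypothetical request body for a service started with
#   --logits-processors ThinkingBudgetLogitsProcessor
# reasoning_max_tokens is intentionally left unset so only the
# logits processor governs the thinking length.
request = {
    "messages": [{"role": "user", "content": "Explain KV cache."}],
    "logits_processors_args": {
        "thinking_budget": 512,
        "think_stop_sentence": "Time to answer.",
    },
}
```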
@@ -120,4 +134,6 @@ print(outputs[0].outputs.text)
 
 ## Performance Note
 
-This processor runs `update_state` and `apply` on every decode step. If you only need a hard thinking-length cap and care most about throughput, consider the operator-level reasoning-length controls instead of per-step logits processing.
+This processor runs `update_state` and `apply` on every decode step. If you only need a hard
+thinking-length cap and care most about throughput, consider the operator-level reasoning-length
+controls instead of per-step logits processing.
@@ -2,7 +2,9 @@
 
 ## Overview
 
-`ThinkingBudgetLogitsProcessor` limits the generated length of the `<think> ... </think>` segment. When the budget threshold is reached, it forces a line-break token and then `</think>` to end the thinking section.
+`ThinkingBudgetLogitsProcessor` limits the generated length of the `<think> ... </think>` segment. When the budget threshold is reached,
+it forces `</think>` directly to end the thinking section; if `think_stop_sentence` is configured, the custom
+sentence is forced first, followed by `</think>`.
 
 ## When to Use
 
@@ -11,19 +13,20 @@
 
 ## How It Works
 
-1. **CPU-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are used to compute whether the thinking section has started, whether it has already ended, and the existing thinking length.
+1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are used to compute whether the thinking section has started, whether it has already ended, and the existing thinking length.
 2. **Per-step update**: `last_token_id` and `tokens_after_start` are tracked during decoding.
-3. **Budget enforcement**: once the budget is reached, the line break and the thinking end token are forced in turn.
+3. **Budget enforcement**: once the budget is reached, `</think>` is forced directly by default; if `think_stop_sentence` is configured, the sentence
+   is forced token by token first, followed by `</think>`.
 
 ## Requirements
 
-- The model must provide valid `think_start_id`, `think_end_id`, and `line_break_id` (from `ModelConfig`).
+- The model must provide valid `think_start_id` and `think_end_id` (from `ModelConfig`).
-- If any of these ids is invalid, the processor is disabled and `thinking_budget` does not take effect.
+- If either of these ids is invalid, the processor is disabled and `thinking_budget` does not take effect.
 
 ## Request Parameters
 
-- `thinking_budget` (int, required to enable): maximum number of tokens allowed after `<think>`.
+- `thinking_budget` (int, required to enable): maximum number of decode-time tokens allowed after `<think>`.
-- `think_stop_sentence` (string, optional): the string is encoded into token ids on the CPU side and forced near the budget boundary.
+- `think_stop_sentence` (string, optional): a custom stop sentence encoded as literal text and forced near the budget boundary.
 
 ## Operator-Level Limits vs LogitsProcessor
 
@@ -35,21 +38,27 @@ FastDeploy currently has two ways to control thinking length:
   - Suitable for simple "only limit thinking length" scenarios.
 - **`ThinkingBudgetLogitsProcessor`** (`logits_processors_args.thinking_budget`):
   - Implemented by per-step Python-side logits processing.
-  - Supports more flexible behavior, such as `think_stop_sentence` (inserting a custom sentence before termination).
+  - Supports more flexible behavior, such as `think_stop_sentence`.
   - Usually has higher overhead than operator-level limits under high concurrency.
 
 Choose according to the following principles:
 
 - If you only need to limit thinking length: prefer `reasoning_max_tokens`.
-- If you need more flexible control (such as inserting a custom sentence): use `ThinkingBudgetLogitsProcessor`.
+- If you need more flexible control (such as inserting a custom sentence before `</think>`): use `ThinkingBudgetLogitsProcessor`.
 
 ## Recommended Practices
 
 In the current implementation, `reasoning_max_tokens` and `thinking_budget` are not mutually exclusive.
 If both are configured for the same request, both constraints may take effect, and whichever triggers first ends the thinking section.
 
-- **Operator-level only**: this is request-level config. Set only `enable_thinking=true` + `reasoning_max_tokens` in the request, and do not pass `thinking_budget`.
-- **LogitsProcessor only** (especially when using `think_stop_sentence`): this is two-level "service start + request parameter" config. The service must be started with `--logits-processors ThinkingBudgetLogitsProcessor`, and the request must pass `thinking_budget` (and the optional `think_stop_sentence`) via `logits_processors_args`; do not set `reasoning_max_tokens` at the same time.
+- **Operator-level only**: this is request-level config. Set only `enable_thinking=true` + `reasoning_max_tokens` in the request,
+  and do not pass `thinking_budget`.
+- **LogitsProcessor only** (especially when using `think_stop_sentence`): this is two-level "service start + request parameter" config.
+  The service must be started with `--logits-processors ThinkingBudgetLogitsProcessor`, and the request must pass
+  `thinking_budget` (and the optional `think_stop_sentence`) via `logits_processors_args`; do not set
+  `reasoning_max_tokens` at the same time.
+- `thinking_budget` itself does not depend on `enable_thinking=true`.
+- If the ERNIE chat template has already appended `<think>` to the prompt, `thinking_budget` should still take effect; the model is not required to emit another `<think>` during decoding.
 - If the business requires the custom sentence to be inserted in full, enabling both is not recommended, since the operator-level limit may truncate it early.
 
 ## Online Usage
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request["bad_words_token_ids"] = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
+
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
             if request.get("prompt"):
@@ -143,6 +148,10 @@ class Ernie4_5Processor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request["prompt_token_ids"], request.get("logits_processors_args") or {}
+        )
+        request["logits_processors_args"] = logits_processors_args
         max_tokens = max_model_len - len(request["prompt_token_ids"])
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_tokens)
@@ -210,6 +210,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request["bad_words_token_ids"] = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
+
         if request.get("prompt_token_ids"):
             messages = request.get("messages")
             if messages:
@@ -257,6 +262,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         # 截断超过长度限制的prompt
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request["prompt_token_ids"], request.get("logits_processors_args") or {}
+        )
+        request["logits_processors_args"] = logits_processors_args
 
         max_tokens = max_model_len - len(request["prompt_token_ids"])
         if request.get("max_tokens") is None:
@@ -92,6 +92,9 @@ class BaseDataProcessor(ABC):
         """
         Encode text into token ids with a small LRU cache.
         """
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
         key = (text, bool(add_special_tokens))
         cached = self._tokenize_cache.get(key)
         if cached is not None:
@@ -107,6 +110,25 @@ class BaseDataProcessor(ABC):
                 self._tokenize_cache.popitem(last=False)
         return token_ids
 
+    def _encode_literal_text_with_cache(self, text):
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
+        key = ("literal_text", text)
+        cached = self._tokenize_cache.get(key)
+        if cached is not None:
+            self._tokenize_cache.move_to_end(key)
+            return cached
+        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+        if hasattr(token_ids, "tolist"):
+            token_ids = token_ids.tolist()
+        elif not isinstance(token_ids, list):
+            token_ids = list(token_ids)
+        self._tokenize_cache[key] = token_ids
+        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+            self._tokenize_cache.popitem(last=False)
+        return token_ids
+
     def messages2ids(self, messages):
         """
         Convert multi-turn messages into ID sequences.
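The caching pattern this commit relies on (an `OrderedDict` used as a bounded LRU cache, with `move_to_end` on hit and `popitem(last=False)` for eviction) can be isolated as a minimal sketch. The class below is illustrative, not part of FastDeploy.

```python
from collections import OrderedDict

class LRUTokenCache:
    """Tiny LRU cache mirroring the OrderedDict pattern used in the diff."""

    def __init__(self, capacity=128):
        self.capacity = capacity
        self._cache = OrderedDict()

    def get(self, key):
        value = self._cache.get(key)
        if value is not None:
            self._cache.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key, value):
        self._cache[key] = value
        self._cache.move_to_end(key)
        if len(self._cache) > self.capacity:
            self._cache.popitem(last=False)  # evict least recently used
```

Because `popitem(last=False)` removes the oldest entry, repeated hits on a key keep it alive while cold entries age out, which keeps the tokenization cache bounded regardless of request volume.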
@@ -119,6 +141,77 @@ class BaseDataProcessor(ABC):
         """
         raise NotImplementedError
 
+    def _get_think_token_ids(self):
+        think_token_ids = getattr(self, "_think_token_ids", None)
+        if think_token_ids is not None:
+            return think_token_ids
+        tokenizer = getattr(self, "tokenizer", None)
+        vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+        think_start_id = vocab.get("<think>", -1)
+        think_end_id = vocab.get("</think>", -1)
+        self._think_token_ids = (think_start_id, think_end_id)
+        return self._think_token_ids
+
+    def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+        if isinstance(think_stop_sentence, str) and think_stop_sentence:
+            sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+            logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+            logits_processors_args.pop("think_stop_sentence", None)
+        return logits_processors_args
+
+    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        thinking_budget = logits_processors_args.get("thinking_budget")
+        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+            return logits_processors_args
+        if logits_processors_args.get("think_prompt_checked"):
+            return logits_processors_args
+        if prompt_token_ids is None:
+            return logits_processors_args
+        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+        if token_len == 0:
+            return logits_processors_args
+        think_start_id, think_end_id = self._get_think_token_ids()
+        if think_start_id < 0 or think_end_id < 0:
+            return logits_processors_args
+
+        if hasattr(prompt_token_ids, "tolist"):
+            token_list = prompt_token_ids.tolist()
+        else:
+            token_list = list(prompt_token_ids)
+
+        started = False
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        in_thinking = False
+        for token_id in token_list:
+            if token_id == think_start_id:
+                started = True
+                ended = False
+                in_thinking = True
+            elif token_id == think_end_id and in_thinking:
+                ended = True
+                in_thinking = False
+        if started and token_list:
+            # Align with operator-level reasoning_max_tokens: prompt-side tokens
+            # inside <think> do not consume thinking budget.
+            last_token_id = int(token_list[-1])
+
+        logits_processors_args["think_prompt_checked"] = True
+        logits_processors_args["think_prompt_started"] = started
+        logits_processors_args["think_prompt_ended"] = ended
+        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+        if last_token_id is not None:
+            logits_processors_args["think_prompt_last_token_id"] = last_token_id
+        else:
+            logits_processors_args.pop("think_prompt_last_token_id", None)
+        return logits_processors_args
+
     def ids2tokens(self, token_id, task_id=None):
         """
         token ids to strings
@@ -207,65 +300,6 @@ class DataProcessor(BaseDataProcessor):
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
 
-        self._think_token_ids = None
-
-    def _get_think_token_ids(self):
-        if self._think_token_ids is not None:
-            return self._think_token_ids
-        vocab = self.tokenizer.get_vocab()
-        think_start_id = vocab.get("<think>", -1)
-        think_end_id = vocab.get("</think>", -1)
-        self._think_token_ids = (think_start_id, think_end_id)
-        return self._think_token_ids
-
-    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
-        if not isinstance(logits_processors_args, dict):
-            return logits_processors_args
-        thinking_budget = logits_processors_args.get("thinking_budget")
-        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
-            return logits_processors_args
-        if logits_processors_args.get("think_prompt_checked"):
-            return logits_processors_args
-        if prompt_token_ids is None:
-            return logits_processors_args
-        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
-        if token_len == 0:
-            return logits_processors_args
-        think_start_id, think_end_id = self._get_think_token_ids()
-        if think_start_id < 0 or think_end_id < 0:
-            return logits_processors_args
-
-        if hasattr(prompt_token_ids, "tolist"):
-            token_list = prompt_token_ids.tolist()
-        else:
-            token_list = list(prompt_token_ids)
-
-        started = think_start_id in token_list
-        ended = False
-        tokens_after_start = 0
-        last_token_id = None
-        if started:
-            start_pos = token_list.index(think_start_id)
-            tokens_after = token_list[start_pos + 1 :]
-            if think_end_id in tokens_after:
-                end_pos = tokens_after.index(think_end_id)
-                tokens_after_start = end_pos + 1
-                ended = True
-            else:
-                tokens_after_start = len(tokens_after)
-            if token_list:
-                last_token_id = int(token_list[-1])
-
-        logits_processors_args["think_prompt_checked"] = True
-        logits_processors_args["think_prompt_started"] = started
-        logits_processors_args["think_prompt_ended"] = ended
-        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
-        if last_token_id is not None:
-            logits_processors_args["think_prompt_last_token_id"] = last_token_id
-        else:
-            logits_processors_args.pop("think_prompt_last_token_id", None)
-        return logits_processors_args
-
     def process_request_dict(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
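The key semantic difference between the removed prompt scan (first `<think>` occurrence, via `list.index`) and the replacement (a forward scan that tracks the last `<think>` span) shows up on multi-turn prompts that contain an earlier, already-closed thinking section. A reduced sketch of both behaviors, with placeholder token ids 100 (`<think>`) and 101 (`</think>`):

```python
def scan_first_think(token_list, think_start_id, think_end_id):
    """Old-style sketch: only the FIRST <think> occurrence is considered."""
    if think_start_id not in token_list:
        return False, False
    after = token_list[token_list.index(think_start_id) + 1:]
    return True, think_end_id in after

def scan_last_think(token_list, think_start_id, think_end_id):
    """New-style sketch: each <think> restarts the span, so the LAST one wins."""
    started, ended, in_thinking = False, False, False
    for t in token_list:
        if t == think_start_id:
            started, ended, in_thinking = True, False, True
        elif t == think_end_id and in_thinking:
            ended, in_thinking = True, False
    return started, ended
```

On a prompt like `[100, 1, 101, 100, 2]` (a closed thinking section followed by a fresh `<think>` appended by the chat template), the first-occurrence scan reports thinking as already ended, while the last-span scan correctly reports an open thinking section, so the budget still applies.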
@@ -292,14 +326,10 @@ class DataProcessor(BaseDataProcessor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request["bad_words_token_ids"] = bad_words_token_ids
 
-        logits_processors_args = request.get("logits_processors_args") or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-            logits_processors_args.pop("think_stop_sentence", None)
-            request["logits_processors_args"] = logits_processors_args
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
 
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request["bad_words_token_ids"] = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
+
         # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
                 token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                 request.prompt_token_ids = token_ids
                 data_processor_logger.debug(
-                    f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    f"request_ids: {request.request_id}, prompt: {prompt}, "
+                    f"tokens: {tokens}, token_ids: {token_ids}"
                 )
             elif request.messages is not None:
                 task = request.to_dict()
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, request.get("logits_processors_args") or {}
+        )
+        request["logits_processors_args"] = logits_processors_args
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if request.get("max_tokens") is None:
             request.set("max_tokens", max(1, max_tokens))
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
+
         # processing prompt_token_ids
         if not request.prompt_token_ids:
             if request.prompt:
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if getattr(request.sampling_params, "max_tokens", None) is None:
             request.sampling_params.max_tokens = max(1, max_tokens)
@@ -216,6 +216,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
+
         if request.prompt_token_ids:
             messages = request.messages
             if messages:
@@ -267,6 +272,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         # 截断超过长度限制的prompt
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
 
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if getattr(request.sampling_params, "max_tokens", None) is None:
@@ -115,6 +115,9 @@ class BaseDataProcessor(ABC):
         """
        Encode text into token ids with a small LRU cache.
         """
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
         key = (text, bool(add_special_tokens))
         cached = self._tokenize_cache.get(key)
         if cached is not None:
@@ -130,6 +133,25 @@ class BaseDataProcessor(ABC):
             self._tokenize_cache.popitem(last=False)
         return token_ids

+    def _encode_literal_text_with_cache(self, text):
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
+        key = ("literal_text", text)
+        cached = self._tokenize_cache.get(key)
+        if cached is not None:
+            self._tokenize_cache.move_to_end(key)
+            return cached
+        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+        if hasattr(token_ids, "tolist"):
+            token_ids = token_ids.tolist()
+        elif not isinstance(token_ids, list):
+            token_ids = list(token_ids)
+        self._tokenize_cache[key] = token_ids
+        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+            self._tokenize_cache.popitem(last=False)
+        return token_ids
+
     def messages2ids(self, messages):
         """
        Convert multi-turn messages into ID sequences.
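The `_encode_literal_text_with_cache` helper added above follows the bounded-LRU pattern built on `collections.OrderedDict`: a hit is moved to the end, a miss is inserted, and the least recently used entry is evicted once capacity is exceeded. A minimal standalone sketch of that pattern (the `TokenizeCache` class, its capacity, and the `tokenize` callable below are illustrative stand-ins, not FastDeploy APIs):

```python
from collections import OrderedDict

class TokenizeCache:
    """Minimal bounded LRU cache, mirroring the OrderedDict idiom in the diff."""

    def __init__(self, capacity=128):
        self._cache = OrderedDict()
        self._capacity = capacity

    def get_or_compute(self, key, compute):
        cached = self._cache.get(key)
        if cached is not None:
            # Mark the entry as most recently used.
            self._cache.move_to_end(key)
            return cached
        value = compute(key)
        self._cache[key] = value
        if len(self._cache) > self._capacity:
            # Evict the least recently used entry.
            self._cache.popitem(last=False)
        return value

cache = TokenizeCache(capacity=2)
calls = []
tokenize = lambda text: calls.append(text) or [ord(c) for c in text]
cache.get_or_compute("ab", tokenize)
cache.get_or_compute("ab", tokenize)  # served from cache, no tokenizer call
cache.get_or_compute("cd", tokenize)
cache.get_or_compute("ef", tokenize)  # exceeds capacity, evicts "ab"
print(len(calls))  # → 3
```

Keying the cache on `("literal_text", text)` keeps these entries distinct from the `encode_with_cache` entries that share the same `OrderedDict`.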
@@ -142,6 +164,77 @@ class BaseDataProcessor(ABC):
         """
         raise NotImplementedError

+    def _get_think_token_ids(self):
+        think_token_ids = getattr(self, "_think_token_ids", None)
+        if think_token_ids is not None:
+            return think_token_ids
+        tokenizer = getattr(self, "tokenizer", None)
+        vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+        think_start_id = vocab.get("<think>", -1)
+        think_end_id = vocab.get("</think>", -1)
+        self._think_token_ids = (think_start_id, think_end_id)
+        return self._think_token_ids
+
+    def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+        if isinstance(think_stop_sentence, str) and think_stop_sentence:
+            sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+            logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+            logits_processors_args.pop("think_stop_sentence", None)
+        return logits_processors_args
+
+    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        thinking_budget = logits_processors_args.get("thinking_budget")
+        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+            return logits_processors_args
+        if logits_processors_args.get("think_prompt_checked"):
+            return logits_processors_args
+        if prompt_token_ids is None:
+            return logits_processors_args
+        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+        if token_len == 0:
+            return logits_processors_args
+        think_start_id, think_end_id = self._get_think_token_ids()
+        if think_start_id < 0 or think_end_id < 0:
+            return logits_processors_args
+
+        if hasattr(prompt_token_ids, "tolist"):
+            token_list = prompt_token_ids.tolist()
+        else:
+            token_list = list(prompt_token_ids)
+
+        started = False
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        in_thinking = False
+        for token_id in token_list:
+            if token_id == think_start_id:
+                started = True
+                ended = False
+                in_thinking = True
+            elif token_id == think_end_id and in_thinking:
+                ended = True
+                in_thinking = False
+        if started and token_list:
+            # Align with operator-level reasoning_max_tokens: prompt-side tokens
+            # inside <think> do not consume thinking budget.
+            last_token_id = int(token_list[-1])
+
+        logits_processors_args["think_prompt_checked"] = True
+        logits_processors_args["think_prompt_started"] = started
+        logits_processors_args["think_prompt_ended"] = ended
+        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+        if last_token_id is not None:
+            logits_processors_args["think_prompt_last_token_id"] = last_token_id
+        else:
+            logits_processors_args.pop("think_prompt_last_token_id", None)
+        return logits_processors_args
+
     def ids2tokens(self, token_id, task_id=None):
         """
        token ids to strings
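The scan loop in `_update_thinking_prompt_state` above is a small state machine over the prompt: `<think>` opens a thinking section, `</think>` closes it only when one is open, and (per the new ERNIE-aligned behavior) prompt-side tokens inside `<think>` never consume the budget, so `tokens_after_start` stays 0. A self-contained sketch of just that scan (token ids 100 and 101 are placeholders for `<think>` / `</think>`, not the model's real vocabulary ids):

```python
# Placeholder ids for <think> and </think>; real ids come from the tokenizer vocab.
THINK_START_ID, THINK_END_ID = 100, 101

def scan_prompt(token_list):
    """Return (started, ended, last_token_id) for a prompt token list."""
    started = ended = in_thinking = False
    for token_id in token_list:
        if token_id == THINK_START_ID:
            # A new <think> (re)opens the section.
            started, ended, in_thinking = True, False, True
        elif token_id == THINK_END_ID and in_thinking:
            # </think> only counts if a section is currently open.
            ended, in_thinking = True, False
    last_token_id = token_list[-1] if started and token_list else None
    return started, ended, last_token_id

print(scan_prompt([1, 2, 100, 3, 4]))    # → (True, False, 4)  thinking still open
print(scan_prompt([1, 100, 3, 101, 5]))  # → (True, True, 5)   thinking closed
print(scan_prompt([1, 2, 3]))            # → (False, False, None)  no thinking section
```

A stray `</think>` with no preceding `<think>` is ignored by the `in_thinking` guard, which is why the loop tracks that flag instead of just membership tests.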
@@ -221,65 +314,6 @@ class DataProcessor(BaseDataProcessor):
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id

-        self._think_token_ids = None
-
-    def _get_think_token_ids(self):
-        if self._think_token_ids is not None:
-            return self._think_token_ids
-        vocab = self.tokenizer.get_vocab()
-        think_start_id = vocab.get("<think>", -1)
-        think_end_id = vocab.get("</think>", -1)
-        self._think_token_ids = (think_start_id, think_end_id)
-        return self._think_token_ids
-
-    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
-        if not isinstance(logits_processors_args, dict):
-            return logits_processors_args
-        thinking_budget = logits_processors_args.get("thinking_budget")
-        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
-            return logits_processors_args
-        if logits_processors_args.get("think_prompt_checked"):
-            return logits_processors_args
-        if prompt_token_ids is None:
-            return logits_processors_args
-        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
-        if token_len == 0:
-            return logits_processors_args
-        think_start_id, think_end_id = self._get_think_token_ids()
-        if think_start_id < 0 or think_end_id < 0:
-            return logits_processors_args
-
-        if hasattr(prompt_token_ids, "tolist"):
-            token_list = prompt_token_ids.tolist()
-        else:
-            token_list = list(prompt_token_ids)
-
-        started = think_start_id in token_list
-        ended = False
-        tokens_after_start = 0
-        last_token_id = None
-        if started:
-            start_pos = token_list.index(think_start_id)
-            tokens_after = token_list[start_pos + 1 :]
-            if think_end_id in tokens_after:
-                end_pos = tokens_after.index(think_end_id)
-                tokens_after_start = end_pos + 1
-                ended = True
-            else:
-                tokens_after_start = len(tokens_after)
-            if token_list:
-                last_token_id = int(token_list[-1])
-
-        logits_processors_args["think_prompt_checked"] = True
-        logits_processors_args["think_prompt_started"] = started
-        logits_processors_args["think_prompt_ended"] = ended
-        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
-        if last_token_id is not None:
-            logits_processors_args["think_prompt_last_token_id"] = last_token_id
-        else:
-            logits_processors_args.pop("think_prompt_last_token_id", None)
-        return logits_processors_args
-
     def process_request(self, request, max_model_len=None, **kwargs):
         """
        Preprocess the request
@@ -306,14 +340,10 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids

-        logits_processors_args = request.get("logits_processors_args") or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-            logits_processors_args.pop("think_stop_sentence", None)
-        request["logits_processors_args"] = logits_processors_args
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args

         # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
@@ -411,14 +441,10 @@ class DataProcessor(BaseDataProcessor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request.sampling_params.bad_words_token_ids = bad_words_token_ids

-        logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-            logits_processors_args.pop("think_stop_sentence", None)
-        request.sampling_params.logits_processors_args = logits_processors_args
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args

         # processing prompt_token_ids
         if not request.prompt_token_ids:
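Both call sites above now funnel through `_prepare_think_stop_sentence`, which rewrites the request-level args dict in place: the raw `think_stop_sentence` string is replaced by a `think_stop_sentence_token_ids` list, and (unlike the old inline code) no newline tokens are prepended. A rough sketch of that transformation, assuming a toy `fake_encode` in place of the real tokenizer:

```python
def prepare_think_stop_sentence(args, encode):
    """Sketch of the args rewrite: swap the raw string for its token ids."""
    if not isinstance(args, dict):
        return args
    sentence = args.get("think_stop_sentence")
    if isinstance(sentence, str) and sentence:
        # Store the encoded form and drop the raw string key.
        args["think_stop_sentence_token_ids"] = encode(sentence)
        args.pop("think_stop_sentence", None)
    return args

fake_encode = lambda text: [ord(c) for c in text]  # stand-in tokenizer
args = {"thinking_budget": 8, "think_stop_sentence": "ok"}
print(prepare_think_stop_sentence(args, fake_encode))
# → {'thinking_budget': 8, 'think_stop_sentence_token_ids': [111, 107]}
```

When `think_stop_sentence` is absent or empty, the dict passes through untouched, which keeps the common no-stop-sentence path allocation-free.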
@@ -42,14 +42,15 @@ class _ThinkingState:
 class ThinkingBudgetLogitsProcessor(LogitsProcessor):
     """Limit the number of tokens generated in the thinking phase.

-    The processor tracks per-request thinking state and forces a newline token
-    when the budget is reached, followed by the thinking end token on the next step.
+    The processor tracks per-request thinking state and forces the thinking end
+    token when the budget is reached. If a stop sentence is configured, the
+    processor emits the stop sentence first and then the thinking end token.
     Request-specific configuration is provided via logits_processors_args:

         {"thinking_budget": <int>}

-    Requires model_config to provide think_start_id, think_end_id, and line_break_id.
-    If any of these are missing or invalid (-1), the processor will be disabled.
+    Requires model_config to provide think_start_id and think_end_id. If any of
+    these are missing or invalid (-1), the processor will be disabled.
     """

     def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
         self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
         self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
         self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
-        self._enabled = (
-            self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
-        )
+        self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
         if not self._enabled:
             logger.warning(
                 "ThinkingBudgetLogitsProcessor disabled: missing token ids "
-                f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
-                "Ensure model vocab contains <think>, </think> tokens and line_break_id is configured."
+                f"(think_start={think_start_id}, think_end={think_end_id}). "
+                "Ensure model vocab contains <think> and </think> tokens."
             )
         self._states: Dict[str, _ThinkingState] = {}
         self._active_req_ids: list[str] = []
         self._active_budgets: list[int] = []
         self._active_slots: list[int] = []

+    def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
+        started = False
+        ended = False
+        in_thinking = False
+        for token_id in prompt_slice:
+            if token_id == self.think_start_token_id:
+                started = True
+                ended = False
+                in_thinking = True
+            elif token_id == self.think_end_token_id and in_thinking:
+                ended = True
+                in_thinking = False
+        last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
+        return started, ended, 0, last_token_id
+
     def update_state(self, share_inputs: dict) -> None:
         if not self._enabled:
             return
@@ -82,6 +96,7 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
         req_ids = share_inputs["req_ids"]
         logits_processors_args = share_inputs["logits_processors_args"]
         prompt_ids = share_inputs.get("prompt_ids")
+        token_ids_all = share_inputs.get("token_ids_all")
         prompt_lens = share_inputs.get("prompt_lens")
         pre_ids = share_inputs.get("pre_ids")
         next_tokens = share_inputs.get("next_tokens")
@@ -153,6 +168,8 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
             for idx, slot_id in enumerate(candidate_slots):
                 next_token_by_slot[slot_id] = int(next_sel[idx])

+        prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
+
         for idx, slot_id in enumerate(candidate_slots):
             req_id = candidate_req_ids[idx]
             logit_proc_args = candidate_args[idx]
@@ -190,25 +207,28 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
             state.current_step_idx = current_step_idx

             if not state.started and not state.prompt_checked:
-                if prompt_ids is not None and prompt_lens is not None:
+                if prompt_source is not None and prompt_lens is not None:
                     if prompt_lens_np is not None:
                         prompt_len = int(prompt_lens_np[slot_id])
                     else:
                         prompt_len = int(prompt_lens[slot_id])
-                    prompt_slice = prompt_ids[slot_id, :prompt_len]
-                    prompt_slice = prompt_slice.numpy().tolist()
-                    if self.think_start_token_id in prompt_slice:
+                    prompt_slice = prompt_source[slot_id, :prompt_len]
+                    if hasattr(prompt_slice, "numpy"):
+                        prompt_slice = prompt_slice.numpy().tolist()
+                    elif hasattr(prompt_slice, "tolist"):
+                        prompt_slice = prompt_slice.tolist()
+                    else:
+                        prompt_slice = list(prompt_slice)
+                    if prompt_ids is None:
+                        prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
+                    prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
+                        self._scan_prompt_state(prompt_slice)
+                    )
+                    if prompt_started:
                         state.started = True
-                        start_pos = prompt_slice.index(self.think_start_token_id)
-                        tokens_after = prompt_slice[start_pos + 1 :]
-                        if self.think_end_token_id in tokens_after:
-                            end_pos = tokens_after.index(self.think_end_token_id)
-                            state.tokens_after_start = end_pos + 1
-                            state.ended = True
-                        else:
-                            state.tokens_after_start = len(tokens_after)
-                        if prompt_slice:
-                            state.last_token_id = int(prompt_slice[-1])
+                        state.ended = prompt_ended
+                        state.tokens_after_start = prompt_tokens_after_start
+                        state.last_token_id = prompt_last_token_id
                     if current_step_idx is not None and state.last_step_idx is None:
                         state.last_step_idx = current_step_idx
                     state.prompt_checked = True
@@ -304,13 +324,6 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
             if state.tokens_after_start < budget:
                 continue

-            if state.last_token_id != self.line_break_token_id:
-                logits[slot_id, :] = -float("inf")
-                logits[slot_id, self.line_break_token_id] = 0.0
-                state.last_token_id = self.line_break_token_id
-                state.last_step_idx = state.current_step_idx
-                continue
-
             logits[slot_id, :] = -float("inf")
             logits[slot_id, self.think_end_token_id] = 0.0
             state.last_token_id = self.think_end_token_id
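The enforcement step above drops the intermediate forced line break: once the budget is met, the logits row is masked to negative infinity everywhere except the `</think>` id, so sampling (greedy or stochastic) can only emit the thinking-end token. A self-contained sketch of that masking using plain Python lists (`THINK_END_ID = 101` is a placeholder, and real code operates on a paddle tensor rather than a list):

```python
import math

THINK_END_ID = 101  # placeholder for the </think> vocab id

def force_think_end(logits_row):
    """Mask a logits row so only the </think> token can be sampled."""
    masked = [-math.inf] * len(logits_row)
    masked[THINK_END_ID] = 0.0  # the sole finite logit wins any argmax/softmax
    return masked

row = force_think_end([0.3] * 200)
print(row.index(max(row)))  # → 101
```

Because softmax of this row puts probability 1.0 on `</think>`, the forcing works regardless of temperature or top-p settings applied afterward.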
@@ -20,7 +20,19 @@ from fastdeploy.engine import common_engine as common_engine_module
 from fastdeploy.engine import engine as engine_module
 from fastdeploy.engine.args_utils import EngineArgs  # Import EngineArgs
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.input.ernie4_5_processor import (
+    Ernie4_5Processor as ErnieTextDataProcessor,
+)
+from fastdeploy.input.ernie4_5_vl_processor import (
+    Ernie4_5_VLProcessor as ErnieVLDataProcessor,
+)
 from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
+from fastdeploy.input.v1.ernie4_5_processor import (
+    Ernie4_5Processor as V1ErnieTextDataProcessor,
+)
+from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
+    Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
+)
 from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
 from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
 from fastdeploy.scheduler import SchedulerConfig
@@ -52,6 +64,7 @@ class MockModelRunner:
             "req_ids": [f"req_{i}" for i in range(max_num_seqs)],
             "logits_processors_args": [{} for _ in range(max_num_seqs)],
             "prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)),  # Max prompt len 10
+            "token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
             "prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
             "pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
             "step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
@@ -73,6 +86,12 @@ class MockModelRunner:
         self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
             req.prompt_ids, dtype=paddle.int64
         )
+        self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
+            np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
+        )
+        self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
+            req.prompt_ids, dtype=paddle.int64
+        )
         self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
         if req.sampling_params.logits_processors_args:
             self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
@@ -152,7 +171,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         return logits

     def test_thinking_budget_not_reached(self):
-        # Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase
+        # Scenario: Thinking budget is 5, and prompt-side tokens after <think> do not consume budget.
         req_id = "test_req_1"
         prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
         sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
@@ -167,9 +186,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         processor.update_state(mock_runner.share_inputs)
         self.assertTrue(processor._states[req_id].started)
         self.assertFalse(processor._states[req_id].ended)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 3)  # (3, 4, 5) after THINKING_START
+        self.assertEqual(processor._states[req_id].tokens_after_start, 0)

-        # Step 2: Simulate one generation step (budget 5, generated 3 -> 4)
+        # Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)  # Update state before apply
         processed_logits = processor.apply(logits)
@@ -181,9 +200,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         mock_runner.update_request_state(0, mock_req, pre_id=0)  # Update last token

         processor.update_state(mock_runner.share_inputs)  # Update state after generating token
-        self.assertEqual(processor._states[req_id].tokens_after_start, 4)  # (3, 4, 5, 0)
+        self.assertEqual(processor._states[req_id].tokens_after_start, 1)

-        # Step 3: Simulate another generation step (budget 5, generated 4 -> 5)
+        # Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)  # Update state before apply
         processed_logits = processor.apply(logits)
@@ -191,14 +210,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         mock_runner.update_request_state(0, mock_req, pre_id=0)  # Update last token

         processor.update_state(mock_runner.share_inputs)  # Update state after generating token
-        self.assertEqual(processor._states[req_id].tokens_after_start, 5)  # (3, 4, 5, 0, 0)
+        self.assertEqual(processor._states[req_id].tokens_after_start, 2)

         # LogitsProcessor should still not restrict as NEW_LINE_TOKEN is not yet last token

-    def test_thinking_budget_reached_forces_newline(self):
-        # Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end
+    def test_thinking_budget_reached_forces_think_end(self):
+        # Scenario: Budget is 3 and only decode-time tokens count toward the budget.
         req_id = "test_req_2"
-        prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]  # Initial 1 token after start
+        prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
         sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
         mock_req = MockRequest(req_id, prompt_ids, sampling_params)

@@ -206,16 +225,26 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])

         processor = ThinkingBudgetLogitsProcessor(self.fd_config)
-        line_break_id = processor.line_break_token_id
         think_end_id = processor.think_end_token_id

-        # Step 1: Initial state update (1 token after start)
+        # Step 1: Initial state update (prompt-side tokens do not count)
         processor.update_state(mock_runner.share_inputs)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id].tokens_after_start, 0)
         self.assertFalse(processor._states[req_id].ended)
         self.assertEqual(processor._states[req_id].last_token_id, 3)

-        # Step 2: Generate 2nd token (budget 3, generated 1 -> 2)
+        # Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
+        logits = self._get_initial_logits(1)
+        processor.update_state(mock_runner.share_inputs)
+        processed_logits = processor.apply(logits)
+        next_token = mock_runner.generate_next_token(processed_logits)[0]
+        self.assertEqual(next_token, 0)  # Normal generation
+        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
+        processor.update_state(mock_runner.share_inputs)
+        self.assertEqual(processor._states[req_id].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id].last_token_id, 0)
+
+        # Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)
         processed_logits = processor.apply(logits)
@@ -226,42 +255,17 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         self.assertEqual(processor._states[req_id].tokens_after_start, 2)
         self.assertEqual(processor._states[req_id].last_token_id, 0)

-        # Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE.
+        # Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)
         processed_logits = processor.apply(logits)
         next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, 0)  # Normal generation
+        self.assertEqual(next_token, 0)
         mock_runner.update_request_state(0, mock_req, pre_id=next_token)
         processor.update_state(mock_runner.share_inputs)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 3)  # Budget is now met
-        self.assertEqual(processor._states[req_id].last_token_id, 0)
+        self.assertEqual(processor._states[req_id].tokens_after_start, 3)

-        # Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID.
-        logits = self._get_initial_logits(1)
-        processor.update_state(mock_runner.share_inputs)
-        processed_logits = processor.apply(logits)
-
-        # Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0
-        # Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues
-        other_logits = paddle.concat(
-            [
-                processed_logits[0, :line_break_id],
-                processed_logits[0, line_break_id + 1 : VOCAB_SIZE],
-            ],
-            axis=0,
-        )
-        self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item())
-        self.assertEqual(processed_logits[0, line_break_id].item(), 0.0)
-
-        next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, line_break_id)  # Forces NEW_LINE
-        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
-        processor.update_state(mock_runner.share_inputs)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 4)  # Still increment
-        self.assertEqual(processor._states[req_id].last_token_id, line_break_id)
-
-        # Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
+        # Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)
         processed_logits = processor.apply(logits)
@@ -324,9 +328,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         self.assertEqual(next_token, stop_sentence_token_ids[0])

     def test_thinking_budget_no_stop_sentence_defaults(self):
-        # Scenario: No stop sentence, budget reached should force newline.
+        # Scenario: No stop sentence, budget reached should force thinking_end directly.
         req_id = "test_req_no_stop_sentence"
-        prompt_ids = [THINKING_START_TOKEN_ID, 42]  # 1 token after start
+        prompt_ids = [THINKING_START_TOKEN_ID, 42]
         sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
         mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -334,35 +338,42 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])

         processor = ThinkingBudgetLogitsProcessor(self.fd_config)
-        line_break_id = processor.line_break_token_id

         processor.update_state(mock_runner.share_inputs)
         logits = self._get_initial_logits(1)
         processed_logits = processor.apply(logits)
         next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, line_break_id)
+        self.assertEqual(next_token, 0)
+        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
+        processor.update_state(mock_runner.share_inputs)
+
+        logits = self._get_initial_logits(1)
+        processor.update_state(mock_runner.share_inputs)
+        processed_logits = processor.apply(logits)
+        next_token = mock_runner.generate_next_token(processed_logits)[0]
+        self.assertEqual(next_token, THINKING_END_TOKEN_ID)

     def test_thinking_budget_uses_config_token_ids(self):
         # Scenario: Processor should use token ids from model config.
         self.fd_config.model_config.think_start_id = 123
         self.fd_config.model_config.think_end_id = 124
-        self.fd_config.model_config.line_break_id = 125
+        self.fd_config.model_config.line_break_id = -1
         processor = ThinkingBudgetLogitsProcessor(self.fd_config)
         self.assertEqual(processor.think_start_token_id, 123)
         self.assertEqual(processor.think_end_token_id, 124)
-        self.assertEqual(processor.line_break_token_id, 125)
+        self.assertEqual(processor.line_break_token_id, -1)
         self.assertTrue(processor._enabled)

     def test_thinking_budget_disabled_when_token_ids_missing(self):
         # Scenario: Processor should be disabled when token ids are not configured.
         self.fd_config.model_config.think_start_id = -1
         self.fd_config.model_config.think_end_id = -1
-        self.fd_config.model_config.line_break_id = -1
+        self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
         processor = ThinkingBudgetLogitsProcessor(self.fd_config)
         self.assertFalse(processor._enabled)
         self.assertEqual(processor.think_start_token_id, -1)
         self.assertEqual(processor.think_end_token_id, -1)
-        self.assertEqual(processor.line_break_token_id, -1)
+        self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)

         # update_state and apply should be no-op when disabled
         mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
@@ -493,6 +504,28 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         self.assertEqual(state.tokens_after_start, 2)
         self.assertEqual(state.last_token_id, 99)

+    def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
+        req_id = "req_gpu_fallback"
+        mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
+        mock_runner.share_inputs["req_ids"][0] = req_id
+        mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
+        mock_runner.share_inputs["prompt_ids"] = None
+        mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
+            [1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
+        )
+        mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
+        mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
+
+        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
+        processor.update_state(mock_runner.share_inputs)
+
+        state = processor._states[req_id]
+        self.assertTrue(state.prompt_checked)
+        self.assertTrue(state.started)
+        self.assertFalse(state.ended)
+        self.assertEqual(state.tokens_after_start, 0)
+        self.assertEqual(state.last_token_id, 3)
+
     def test_thinking_budget_not_configured(self):
         # Scenario: Processor is active, but request does not provide thinking_budget
         req_id = "test_req_3"
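The fallback exercised by the new test above — deriving the initial thinking state by scanning prompt token ids when `prompt_ids` is unavailable — can be illustrated with a small standalone sketch. This is a simplified model of the scan, not FastDeploy's implementation; the function name and token ids are illustrative:

```python
def scan_prompt_thinking_state(prompt_ids, think_start_id, think_end_id):
    """Derive the initial per-request thinking state from prompt token ids.

    Mirrors what the test asserts: the prompt decides whether thinking has
    started/ended, but prompt-side tokens never consume the thinking budget.
    """
    started = ended = False
    for tok in prompt_ids:
        if tok == think_start_id:
            started, ended = True, False  # a new <think> reopens the section
        elif started and tok == think_end_id:
            ended = True
    return {
        "started": started,
        "ended": ended,
        "tokens_after_start": 0,  # prompt content does not count toward the budget
        "last_token_id": prompt_ids[-1] if prompt_ids else -1,
    }
```

With `think_start_id=100`, a prompt of `[1, 100, 2, 3]` yields `started=True`, `ended=False`, `tokens_after_start=0`, and `last_token_id=3`, matching the assertions in the fallback test.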
@@ -532,20 +565,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         self.assertEqual(processor._states[req_id].tokens_after_start, 0)  # No tokens after start yet
         self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)

-        # Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
-        logits = self._get_initial_logits(1)
-        processor.update_state(mock_runner.share_inputs)
-        processed_logits = processor.apply(logits)
-
-        self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
-        next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
-        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
-        processor.update_state(mock_runner.share_inputs)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 1)  # Still increments
-        self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
-
-        # Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
+        # Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)
         processed_logits = processor.apply(logits)
@@ -572,7 +592,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         processor.update_state(mock_runner.share_inputs)
         self.assertTrue(processor._states[req_id].started)
         self.assertTrue(processor._states[req_id].ended)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 2)  # Tokens 2, THINKING_END after start
+        self.assertEqual(processor._states[req_id].tokens_after_start, 0)

         logits = self._get_initial_logits(1)
         processor.update_state(mock_runner.share_inputs)
@@ -582,7 +602,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
     def test_multiple_requests(self):
         # Scenario: Multiple requests with different thinking states
         req_id_1 = "req_a"
-        prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]  # budget 2, last_token=11, tokens_after_start=2
+        prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
        sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
         mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
@@ -592,7 +612,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)

         req_id_3 = "req_c"
-        prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]  # budget 1, last_token=30, tokens_after_start=1
+        prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
         sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
         mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
@@ -609,14 +629,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         # Verify initial states
         self.assertTrue(processor._states[req_id_1].started)
         self.assertFalse(processor._states[req_id_1].ended)
-        self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
+        self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
         self.assertEqual(processor._states[req_id_1].last_token_id, 11)

         self.assertNotIn(req_id_2, processor._states)  # No budget specified for req_2

         self.assertTrue(processor._states[req_id_3].started)
         self.assertFalse(processor._states[req_id_3].ended)
-        self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
         self.assertEqual(processor._states[req_id_3].last_token_id, 30)

         # Simulate logits for the batch
@@ -624,16 +644,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         processor.update_state(mock_runner.share_inputs)  # Ensure state is updated before apply
         processed_batch_logits = processor.apply(batch_logits)

-        # Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
-        self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
-        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
+        # Req 1: prompt-side content does not consume budget, so first step is normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)

         # Req 2: No thinking budget, normal generation
         self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)

-        # Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
-        self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
-        self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
+        # Req 3: prompt-side content does not consume budget, so first step is normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)

         # Simulate generating next tokens and updating state
         next_tokens = mock_runner.generate_next_token(processed_batch_logits)
@@ -643,20 +661,25 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
         processor.update_state(mock_runner.share_inputs)

         # Verify updated states for next step
-        self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
-        self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
+        self.assertEqual(processor._states[req_id_1].last_token_id, 0)
+        self.assertEqual(processor._states[req_id_3].last_token_id, 0)
+        self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+        self.assertFalse(processor._states[req_id_1].ended)
+        self.assertFalse(processor._states[req_id_3].ended)

         batch_logits = self._get_initial_logits(3)
         processor.update_state(mock_runner.share_inputs)
         processed_batch_logits = processor.apply(batch_logits)

-        # Req 1: last token was NEW_LINE. Should force THINKING_END
-        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
+        # Req 1: budget 2, tokens_after_start 1. Still normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)

         # Req 2: Still normal generation
         self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)

-        # Req 3: last token was NEW_LINE. Should force THINKING_END
+        # Req 3: budget 1 reached after one generated token, should now force THINKING_END.
+        self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
         self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
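The decode-side rule these batched assertions encode — normal generation while the count of generated thinking tokens is below the budget, then a forced stop sentence (if configured) followed by `</think>` — can be sketched in plain Python. This is a simplified, list-based model under the assumption that forcing a token means masking every other logit to `-inf`; it is not FastDeploy's tensor implementation, and the function name is illustrative:

```python
NEG_INF = float("-inf")

def enforce_thinking_budget(logits, tokens_after_start, budget,
                            think_end_id, stop_sentence_ids=None):
    """Return (forced_token or None, possibly masked logits).

    Prompt-side thinking content never counts toward the budget; only
    tokens generated after <think> do. Once the budget is reached, the
    optional stop sentence is emitted token by token, then </think>.
    """
    if tokens_after_start < budget:
        return None, logits  # under budget: normal generation
    overshoot = tokens_after_start - budget
    if stop_sentence_ids and overshoot < len(stop_sentence_ids):
        forced = stop_sentence_ids[overshoot]  # next stop-sentence token
    else:
        forced = think_end_id  # terminate the thinking section
    masked = [NEG_INF] * len(logits)
    masked[forced] = 0.0  # only the forced token survives
    return forced, masked
```

For example, with budget 1 and one generated thinking token, the call forces `think_end_id` immediately, matching the "Req 3" assertions above; with a two-token stop sentence configured, the two stop tokens are forced on consecutive steps before `</think>`.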
@@ -724,7 +747,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         self.assertTrue(updated["think_prompt_checked"])
         self.assertTrue(updated["think_prompt_started"])
         self.assertTrue(updated["think_prompt_ended"])
-        self.assertEqual(updated["think_prompt_tokens_after_start"], 2)
+        self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
         self.assertEqual(updated["think_prompt_last_token_id"], 3)

     def test_v1_process_request_missing_logits_processors_args(self):
@@ -833,6 +856,58 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
         self.assertNotIn(("np", False), processor._tokenize_cache)

+    def test_text_encode_with_cache_lazy_init(self):
+        processor = TextDataProcessor.__new__(TextDataProcessor)
+        call_counter = {"count": 0}
+
+        def _text2ids(text, max_model_len=None, add_special_tokens=False):
+            call_counter["count"] += 1
+            return np.array([51, 52], dtype=np.int64)
+
+        processor.text2ids = _text2ids
+
+        self.assertFalse(hasattr(processor, "_tokenize_cache"))
+        self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
+        self.assertTrue(hasattr(processor, "_tokenize_cache"))
+        self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
+        self.assertEqual(call_counter["count"], 1)
+
+    def test_v1_encode_with_cache_lazy_init(self):
+        processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
+        call_counter = {"count": 0}
+
+        def _text2ids(text, max_model_len=None, add_special_tokens=False):
+            call_counter["count"] += 1
+            return np.array([61, 62], dtype=np.int64)
+
+        processor.text2ids = _text2ids
+
+        self.assertFalse(hasattr(processor, "_tokenize_cache"))
+        self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
+        self.assertTrue(hasattr(processor, "_tokenize_cache"))
+        self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
+        self.assertEqual(call_counter["count"], 1)
+
+    def test_ernie_encode_literal_text_with_cache(self):
+        processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+        processor.tokenizer = SimpleNamespace(
+            tokenize=lambda text: ["token_a", "token_b"],
+            convert_tokens_to_ids=lambda tokens: [71, 72],
+        )
+
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+
+    def test_v1_ernie_encode_literal_text_with_cache(self):
+        processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+        processor.tokenizer = SimpleNamespace(
+            tokenize=lambda text: ["token_c", "token_d"],
+            convert_tokens_to_ids=lambda tokens: [81, 82],
+        )
+
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
+
     def test_text_update_thinking_prompt_state_branches(self):
         processor = TextDataProcessor.__new__(TextDataProcessor)
         processor._think_token_ids = None
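The lazy-initialization behavior the new `encode_with_cache` tests pin down is a common memoization pattern: the cache attribute is created on first call, and the underlying tokenization runs once per distinct input. A minimal standalone sketch (the class and method names are illustrative, not the FastDeploy API):

```python
class CachedEncoderSketch:
    """Creates `_tokenize_cache` lazily so `text2ids` runs once per text."""

    def text2ids(self, text):
        # Stand-in for a real tokenizer call; instances/subclasses override it.
        raise NotImplementedError

    def encode_with_cache(self, text):
        if not hasattr(self, "_tokenize_cache"):
            self._tokenize_cache = {}  # created lazily, as the tests assert
        if text not in self._tokenize_cache:
            self._tokenize_cache[text] = list(self.text2ids(text))
        return self._tokenize_cache[text]
```

A subclass whose `text2ids` counts its invocations would show the cache attribute absent before the first call, present afterwards, and the counter stuck at 1 across repeated encodes of the same text — exactly the shape of the lazy-init tests above.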
@@ -868,7 +943,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         )
         self.assertTrue(with_start_no_end["think_prompt_started"])
         self.assertFalse(with_start_no_end["think_prompt_ended"])
-        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
         self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)

         # Hit the cached branch of _get_think_token_ids
@@ -891,7 +966,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         )
         self.assertTrue(with_start_no_end["think_prompt_started"])
         self.assertFalse(with_start_no_end["think_prompt_ended"])
-        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
         self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)

         # Hit the cached branch of _get_think_token_ids
@@ -903,7 +978,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processor.eos_token_ids = [1]
         processor.update_stop_seq = lambda *args, **kwargs: None
         processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
+        processor._encode_literal_text_with_cache = lambda text: [201, 202]
         processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
         processor.reasoning_parser = None
@@ -924,7 +999,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processed = processor.process_request_dict(request, max_model_len=16)
         self.assertEqual(
             processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
-            [23, 201, 202],
+            [201, 202],
         )
         self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
@@ -934,7 +1009,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processor.eos_token_ids = [1]
         processor.update_stop_seq = lambda *args, **kwargs: None
         processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
+        processor._encode_literal_text_with_cache = lambda text: [301, 302]
         processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
         processor.reasoning_parser = None
@@ -955,7 +1030,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processed = processor.process_request(request, max_model_len=16)
         self.assertEqual(
             processed.logits_processors_args.get("think_stop_sentence_token_ids"),
-            [23, 301, 302],
+            [301, 302],
         )
         self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
@@ -965,7 +1040,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processor.eos_token_ids = [1]
         processor.update_stop_seq = lambda *args, **kwargs: None
         processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
+        processor._encode_literal_text_with_cache = lambda text: [401, 402]
         processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
         processor.reasoning_parser = None
@@ -992,10 +1067,173 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
         processed = processor.process_request_dict(request, max_model_len=16)
         self.assertEqual(
             processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
-            [23, 401, 402],
+            [401, 402],
         )
         self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)

+    def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
+        processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [501, 502]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+
+        request = {
+            "request_id": "req_ernie_text",
+            "eos_token_ids": [1],
+            "prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
+            "prompt": None,
+            "messages": None,
+            "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
+            "bad_words": None,
+            "bad_words_token_ids": None,
+            "max_tokens": 1,
+            "temperature": 1.0,
+            "top_p": 0.9,
+            "response_max_tokens": None,
+            "enable_thinking": True,
+        }
+        with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+            processed = processor.process_request_dict(request, max_model_len=16)
+
+        self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
+        self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
+        self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
+        self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
+
+    def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
+        processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [601, 602]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+
+        request = DummyRequestV1(
+            request_id="req_v1_ernie_text",
+            eos_token_ids=[1],
+            prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
+            prompt=None,
+            messages=None,
+            chat_template_kwargs=None,
+            enable_thinking=True,
+            sampling_params=SimpleNamespace(
+                bad_words=None,
+                bad_words_token_ids=None,
+                max_tokens=1,
+                temperature=1.0,
+                top_p=0.9,
+                repetition_penalty=1.0,
+                frequency_penalty=0.0,
+                presence_penalty=0.0,
+                response_max_tokens=None,
+                n=1,
+                logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
+            ),
+        )
+        with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+            processed = processor.process_request_dict(request, max_model_len=16)
+
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602])
+        self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
+        self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
+
+    def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
+        processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
|
||||||
|
processor._apply_default_parameters = lambda request: request
|
||||||
|
processor.eos_token_ids = [1]
|
||||||
|
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||||
|
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||||
|
processor._encode_literal_text_with_cache = lambda text: [701, 702]
|
||||||
|
processor.tokenizer = DummyTokenizerForTextProcessor()
|
||||||
|
processor.reasoning_parser = None
|
||||||
|
processor._check_mm_limits = lambda *args, **kwargs: None
|
||||||
|
processor.append_completion_tokens = lambda *args, **kwargs: None
|
||||||
|
processor.pack_outputs = lambda outs: outs
|
||||||
|
processor.ernie4_5_processor = SimpleNamespace(
|
||||||
|
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
|
||||||
|
)
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"request_id": "req_ernie_vl",
|
||||||
|
"eos_token_ids": [1],
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"bad_words": None,
|
||||||
|
"bad_words_token_ids": None,
|
||||||
|
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"response_max_tokens": None,
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
|
||||||
|
lambda *args, **kwargs: None,
|
||||||
|
):
|
||||||
|
processed = processor.process_request_dict(request, max_model_len=16)
|
||||||
|
|
||||||
|
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702])
|
||||||
|
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
|
||||||
|
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
|
||||||
|
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
|
||||||
|
|
||||||
|
def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
|
||||||
|
processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
|
||||||
|
processor._apply_default_parameters = lambda request: request
|
||||||
|
processor.eos_token_ids = [1]
|
||||||
|
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||||
|
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||||
|
processor._encode_literal_text_with_cache = lambda text: [801, 802]
|
||||||
|
processor.tokenizer = DummyTokenizerForTextProcessor()
|
||||||
|
processor.reasoning_parser = None
|
||||||
|
processor._check_mm_limits = lambda *args, **kwargs: None
|
||||||
|
processor.append_completion_tokens = lambda *args, **kwargs: None
|
||||||
|
processor.pack_outputs = lambda outs: outs
|
||||||
|
processor.ernie4_5_processor = SimpleNamespace(
|
||||||
|
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
|
||||||
|
)
|
||||||
|
|
||||||
|
request = DummyRequestV1(
|
||||||
|
request_id="req_v1_ernie_vl",
|
||||||
|
eos_token_ids=[1],
|
||||||
|
prompt_token_ids=None,
|
||||||
|
prompt=None,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
chat_template_kwargs=None,
|
||||||
|
enable_thinking=True,
|
||||||
|
completion_token_ids=None,
|
||||||
|
multimodal_data=None,
|
||||||
|
sampling_params=SimpleNamespace(
|
||||||
|
bad_words=None,
|
||||||
|
bad_words_token_ids=None,
|
||||||
|
max_tokens=1,
|
||||||
|
temperature=1.0,
|
||||||
|
top_p=0.9,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
frequency_penalty=0.0,
|
||||||
|
presence_penalty=0.0,
|
||||||
|
response_max_tokens=None,
|
||||||
|
reasoning_max_tokens=None,
|
||||||
|
n=1,
|
||||||
|
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
|
||||||
|
lambda *args, **kwargs: None,
|
||||||
|
):
|
||||||
|
processed = processor.process_request_dict(request, max_model_len=16)
|
||||||
|
|
||||||
|
self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802])
|
||||||
|
self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
|
||||||
|
self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
|
||||||
|
self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()