[Feature] Support ThinkingBudget Logits processor to control thinking content length (#6367)

* feat: add thinking budget logits processor

* add unittest

* fix pre-commit

* add unittest

* docs: clarify operator-level vs logits processor usage and conflict guidance

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
jackyYang6
2026-02-25 14:17:09 +08:00
committed by GitHub
parent 1405d7d5d7
commit a29ee57e15
12 changed files with 1861 additions and 2 deletions
+24 -1
View File
@@ -1950,13 +1950,35 @@ class EngineService:
else len(self.data_processor.tokenizer.vocab)
)
think_start_id = self.data_processor.tokenizer.get_vocab().get("<think>", -1)
if think_start_id >= 0:
self.llm_logger.info(f"Get think_start_id {think_start_id} from vocab.")
else:
self.llm_logger.info("No <think> token found in vocabulary, the model can not do reasoning.")
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
if think_end_id > 0:
if think_end_id >= 0:
self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
else:
self.llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
if line_break_id < 0:
line_break_ids = self.data_processor.tokenizer.encode("\n", add_special_tokens=False)
if isinstance(line_break_ids, dict):
line_break_ids = line_break_ids.get("input_ids")
elif hasattr(line_break_ids, "input_ids"):
line_break_ids = line_break_ids.input_ids
if line_break_ids:
if isinstance(line_break_ids, (list, tuple)):
first = line_break_ids[0]
if isinstance(first, (list, tuple)):
line_break_id = int(first[0]) if first else -1
else:
line_break_id = int(first)
else:
line_break_id = int(line_break_ids)
if line_break_id >= 0:
self.llm_logger.info(f"Get line_break_id {line_break_id} from tokenizer.")
ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
ips = None
@@ -1984,6 +2006,7 @@ class EngineService:
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
f" --think_start_id {think_start_id}"
f" --think_end_id {think_end_id}"
f" --image_patch_id {image_patch_id}"
f" --line_break_id {line_break_id}"