[Feature] mm and thinking model support structred output (#2749)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* mm support structured output

* update code

* update code

* update format

* update code

* update code

* add enable_thinking default

* update code

* add structured_outputs test case

* add ci install xgrammar

* add ci timeout time

* update test for structured_outputs

* update code

* add error traceback info

* update error msg

* update structred output code

* update code

* update code

* update config

* update torch version

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
kevin
2025-09-02 16:21:09 +08:00
committed by GitHub
parent 0e4df5a6f4
commit 1908465542
17 changed files with 1168 additions and 83 deletions
@@ -15,8 +15,13 @@
"""
# from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
BackendBase,
BaseChecker,
LogitsProcessorBase,
)
__all__ = ["get_guided_backend", "schema_checker"]
__all__ = ["get_guided_backend", "schema_checker", "LogitsProcessorBase", "BackendBase", "BaseChecker"]
def get_guided_backend(
@@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor
from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.reasoning import ReasoningParserManager
from fastdeploy.utils import llm_logger
@@ -35,8 +36,9 @@ class LogitsProcessorBase:
None (all state should be managed by subclasses)
"""
def __init__(self):
pass
def __init__(self, enable_reasoning):
self.reasoning_ended = False
self.enable_reasoning = enable_reasoning
def fill_token_bitmask(self, token_bitmask, idx):
"""
@@ -137,8 +139,14 @@ class BackendBase:
self.fd_config = fd_config
self.executor = ThreadPoolExecutor()
self.max_cache_size = 2048
self.reasoning_parser = None
self.hf_tokenizer = self._get_tokenizer_hf()
if self.fd_config.model_config.reasoning_parser:
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
self.fd_config.model_config.reasoning_parser
)
self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)
def _create_processor(self):
"""
@@ -149,70 +157,88 @@ class BackendBase:
"""
raise NotImplementedError
def _json_processor(self, schemata):
def _json_processor(self, schemata, enable_thinking=False):
"""
Process JSON schemata.
Args:
schemata (str): The schemata string.
enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
def _regex_processor(self, schemata):
def _regex_processor(self, schemata, enable_thinking=False):
"""
Process regular expression schemata.
Args:
schemata (str): The schemata string.
enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
def _grammar_processor(self, schemata):
def _grammar_processor(self, schemata, enable_thinking=False):
"""
Process grammar schemata.
Args:
schemata (str): The schemata string.
enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
def _structural_tag_processor(self, schemata):
def _structural_tag_processor(self, schemata, enable_thinking=False):
"""
Process structural tag schemata.
Args:
schemata (str): The schemata string.
enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
def _unsupported_processor_type(self, key_type, schemata):
def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
"""
Process unsupported type.
Args:
key_type (str): The key type string.
schemata (str): The schemata string.
enable_thinking (bool): Whether to enable thinking mode.
"""
raise Exception(f"Unsupported processor type {key_type}.")
def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
def get_reasoning_parser(self):
"""
Get reasoning parser object.
Returns:
ReasoningParser: Reasoning parser object or None
"""
return self.reasoning_parser
def _init_logits_processor(
self,
schemata_key: tuple[str, str],
enable_thinking: bool = False,
) -> LogitsProcessorBase:
"""
init logits processor by type and schemata.
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
enable_thinking (bool): Whether to enable thinking step
Returns:
LogitsProcessorBase: Initialized logits processor instance
@@ -222,18 +248,22 @@ class BackendBase:
"""
key_type, schemata = schemata_key
if key_type == "json":
return self._json_processor(schemata)
return self._json_processor(schemata, enable_thinking)
elif key_type == "regex":
return self._regex_processor(schemata)
return self._regex_processor(schemata, enable_thinking)
elif key_type == "grammar":
return self._grammar_processor(schemata)
return self._grammar_processor(schemata, enable_thinking)
elif key_type == "structural_tag":
return self._structural_tag_processor(schemata)
return self._structural_tag_processor(schemata, enable_thinking)
else:
llm_logger.error(f"Unsupported processor type {key_type}.")
return None
def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
def get_logits_processor(
self,
schemata_key: tuple[str, str],
enable_thinking: bool = False,
) -> tuple[LogitsProcessorBase, bool]:
"""
get logits processor by key from cache or create new one.
@@ -247,8 +277,10 @@ class BackendBase:
"""
value = self.cache.get(schemata_key, None)
if value:
return value.copy(), True
value = self.executor.submit(self._init_logits_processor, schemata_key)
value_copy = value.copy()
value_copy.enable_reasoning = enable_thinking
return value_copy, True
value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
return value, False
def _get_tokenizer_hf(self):
@@ -267,7 +299,6 @@ class BackendBase:
try:
architectures = self.fd_config.model_config.architectures
if not ErnieArchitectures.contains_ernie_arch(architectures):
from transformers import AutoTokenizer, PreTrainedTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(
@@ -0,0 +1,118 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# refer to https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional
import paddle
try:
import triton
import triton.language as tl
except ImportError as err:
raise ImportError("Triton is not installed") from err
@triton.jit
def apply_token_bitmask_inplace_kernel(
logits_ptr,
bitmask_ptr,
indices_ptr,
num_rows,
vocab_size,
logits_strides,
bitmask_strides,
NUM_SMS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Triton kernel for in-place logits masking using bitwise compression.
Processes logits tensor in blocks, applying bitmask to restrict vocabulary access.
Masked positions are set to -inf to ensure zero probability during sampling.
Note:
- Bitmask uses 32:1 compression (1 bit per vocabulary token)
- Optimized for GPU parallel processing with configurable block size
"""
pid = tl.program_id(0)
num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
row_id = work_id // num_blocks
block_offset = (work_id % num_blocks) * BLOCK_SIZE
batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
offsets = block_offset + tl.arange(0, BLOCK_SIZE)
bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
vocab_mask = offsets < vocab_size
packed_bitmask_mask = bitmask_offsets < bitmask_strides
packed_bitmask = tl.load(bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask)
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)
def apply_token_bitmask_inplace_triton(
logits: paddle.Tensor,
bitmask: paddle.Tensor,
vocab_size: Optional[int] = None,
indices: Optional[List[int]] = None,
):
"""Applies vocabulary mask to logits tensor using Triton GPU kernel.
Args:
logits: Input logits tensor of shape [batch_size, vocab_size]
bitmask: Compressed mask tensor (int32) where each bit represents a token
vocab_size: Optional explicit vocabulary size (defaults to auto-detected)
indices: Optional list of batch indices to apply mask to
Note:
Requires CUDA GPU with Triton support
Bitmask must be int32 tensor with shape [batch_size, ceil(vocab_size/32)]
"""
NUM_SMS = paddle.device.cuda.get_device_properties().multi_processor_count
BLOCK_SIZE = 4096
assert bitmask.dtype == paddle.int32, "bitmask must be of type int32"
detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32)
if vocab_size is None:
vocab_size = detected_vocab_size
else:
assert (
vocab_size <= detected_vocab_size
), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}"
num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1
if indices is not None:
indices = paddle.to_tensor(indices, dtype=paddle.int32, place=logits.place)
grid = (NUM_SMS,)
apply_token_bitmask_inplace_kernel[grid](
logits,
bitmask,
indices,
num_rows,
vocab_size,
logits.shape[-1],
bitmask.shape[-1],
NUM_SMS,
BLOCK_SIZE,
num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
num_stages=3,
)
@@ -24,7 +24,7 @@ import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
from fastdeploy.model_executor.guided_decoding import (
BackendBase,
BaseChecker,
LogitsProcessorBase,
@@ -57,7 +57,6 @@ class XGrammarProcessor(LogitsProcessorBase):
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
vocab_size (int): Size of the vocabulary
batch_size (int): Batch size for processing
splitwise_role (str): Role for splitwise processing
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens
@@ -71,13 +70,12 @@ class XGrammarProcessor(LogitsProcessorBase):
override_stop_tokens: Optional[List[int]] = None,
vocab_size: Optional[int] = None,
batch_size: Optional[int] = None,
splitwise_role: str = "mixed",
enable_thinking: bool = False,
):
super().__init__()
super().__init__(enable_reasoning=enable_thinking)
self.max_rollback_tokens = 200
self.vocab_size = vocab_size
self.batch_size = batch_size
self.splitwise_role = splitwise_role
self.compiled_grammar = compiled_grammar
self.terminate_without_stop_token = terminate_without_stop_token
self.override_stop_tokens = override_stop_tokens
@@ -188,7 +186,6 @@ class XGrammarProcessor(LogitsProcessorBase):
override_stop_tokens=self.override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
splitwise_role=self.splitwise_role,
)
@@ -203,7 +200,6 @@ class XGrammarBackend(BackendBase):
vocab_size (int): Size of the vocabulary from config
batch_size (int): Maximum batch size from config
any_whitespace (bool): Whether to allow any whitespace in JSON
splitwise_role (str): Role for splitwise processing
grammar_compiler (GrammarCompiler): Grammar compilation engine
"""
@@ -217,7 +213,6 @@ class XGrammarBackend(BackendBase):
self.batch_size = fd_config.parallel_config.max_num_seqs
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
self.splitwise_role = fd_config.parallel_config.splitwise_role
try:
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
@@ -230,6 +225,7 @@ class XGrammarBackend(BackendBase):
compiled_grammar: CompiledGrammar,
terminate_without_stop_token: bool = False,
override_stop_tokens: Optional[List[int]] = None,
enable_thinking: bool = False,
) -> XGrammarProcessor:
"""
Create a logits processor instance for the given compiled grammar.
@@ -238,6 +234,7 @@ class XGrammarBackend(BackendBase):
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
enable_thinking (bool): Whether to enable thinking mode
Returns:
XGrammarProcessor: Configured grammar processor instance
@@ -248,15 +245,16 @@ class XGrammarBackend(BackendBase):
override_stop_tokens=override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
splitwise_role=self.splitwise_role,
enable_thinking=enable_thinking,
)
def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile JSON schema into a grammar processor.
Args:
schemata (str): JSON schema string to compile
enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -266,14 +264,15 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar)
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile regex pattern into a grammar processor.
Args:
schemata (str): Regex pattern string to compile
enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -283,14 +282,15 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar)
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile grammar (EBNF) into a grammar processor.
Args:
schemata (str): Grammar string in EBNF format
enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -300,9 +300,9 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar)
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile structural tags into a grammar processor.
@@ -327,7 +327,7 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar)
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
class XGrammarChecker(BaseChecker):
@@ -23,9 +23,7 @@ import paddle.nn.functional as F
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
@@ -37,6 +35,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
from fastdeploy.reasoning import ReasoningParser
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -63,6 +62,10 @@ class SamplerProcessor:
self.logits_processor: Dict[int, Optional[Any]] = dict()
self.executor = ThreadPoolExecutor()
self.logits_lock = threading.Lock()
self.reasoning_parser = None
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
self.reasoning_parser = reasoning_parser
def add_logits_processor(
self,
@@ -139,9 +142,14 @@ class SamplerProcessor:
if available_processors is None:
return logits
indices = list(self.logits_processor.keys())
mask_idx = [i for i in indices if i not in skip_idx_list]
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx)
indices = []
for idx, processor in self.logits_processor.items():
if processor is None or idx in skip_idx_list:
continue
if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
indices.append(idx)
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
def _accept_token(self, idx: int, token: int):
"""accept token"""
@@ -151,6 +159,15 @@ class SamplerProcessor:
if self.logits_processor[idx].is_terminated():
return
if (
self.reasoning_parser is not None
and self.logits_processor[idx].enable_reasoning
and not self.logits_processor[idx].reasoning_ended
):
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processor[idx].reasoning_ended = reasoning_ended
return
self.logits_processor[idx].accept_token(token)
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
@@ -206,12 +223,11 @@ class Sampler(nn.Layer):
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
def apply_logits_processor(
self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
self.processor.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
@@ -219,6 +235,10 @@ class Sampler(nn.Layer):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
self.processor.update_output_tokens(next_tokens, skip_idx_list)
def compute_logprobs(
self,
logits: paddle.Tensor,
@@ -307,12 +327,12 @@ class Sampler(nn.Layer):
skip_idx_list: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.processor.apply_token_mask(logits, skip_idx_list)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
logits = self.processor.apply_token_mask(logits, skip_idx_list)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -347,8 +367,6 @@ class Sampler(nn.Layer):
assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
self.processor.update_output_tokens(next_tokens, skip_idx_list)
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
@@ -380,12 +398,15 @@ class SpeculativeSampler(nn.Layer):
"""pre process before running"""
pass
def apply_logits_processor(
self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
pass
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
pass
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
pass
@@ -480,6 +501,14 @@ class MTPSampler(nn.Layer):
"""apply logits processor to sampler"""
pass
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
pass
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
pass
def forward_cuda(
self,
logits: paddle.Tensor,