[Feature] support logits processors (#4515)

* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor

* [chore] fix code style

* [fix] add unit test & fix existing bugs

* [feat] add engine/worker arg --logits-processors

* [fix] redefine user args as logits_processors_args and fix some bugs

* [fix] fix test_sampler

* Update fastdeploy/model_executor/logits_processor/builtin.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/logits_processor/__init__.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/model_executor/test_logits_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix typo

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix bracelet

* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits

* [doc] add docs

* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests

* [fix] fix redundant code

* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
李泳桦
2025-10-29 00:08:53 +08:00
committed by GitHub
parent 24b9505971
commit a012e3608b
18 changed files with 882 additions and 14 deletions
+10
View File
@@ -970,6 +970,9 @@ class PlasAttentionConfig:
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class EarlyStopConfig:
def __init__(
@@ -1071,6 +1074,9 @@ class LoadConfig:
if hasattr(self, key):
setattr(self, key, value)
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
@@ -1339,11 +1345,15 @@ class StructuredOutputsConfig:
self.guided_decoding_backend: Optional[str] = None
# disable any whitespace for guided decoding
self.disable_any_whitespace: bool = True
self.logits_processors: Optional[list[str]] = None
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class FDConfig:
"""
+17
View File
@@ -422,6 +422,16 @@ class EngineArgs:
Flag to specify the dtype of lm_head as FP32. Default is False (Using model default dtype).
"""
logits_processors: Optional[List[str]] = None
"""
A list of FQCNs (Fully Qualified Class Names) of logits processors supported by the service.
A fully qualified class name (FQCN) is a string that uniquely identifies a class within a Python module.
- To enable builtin logits processors, add builtin module paths and class names to the list. Currently support:
- fastdeploy.model_executor.logits_processor:LogitBiasLogitsProcessor
- To enable custom logits processors, add your dotted paths to module and class names to the list.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -687,6 +697,13 @@ class EngineArgs:
default=EngineArgs.lm_head_fp32,
help="Specify the dtype of lm_head weight as float32.",
)
model_group.add_argument(
"--logits-processors",
type=str,
nargs="+",
default=EngineArgs.logits_processors,
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
+2
View File
@@ -535,6 +535,8 @@ class LLMEngine:
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
worker_append_flag = {
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
+21
View File
@@ -103,6 +103,7 @@ class SamplingParams:
bad_words: Optional[List[str]] = None
guided_decoding: Optional[GuidedDecodingParams] = None
bad_words_token_ids: Optional[List[int]] = None
logits_processors_args: Optional[dict[str, Any]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -136,6 +137,7 @@ class SamplingParams:
bad_words=None,
guided_decoding=None,
bad_words_token_ids=None,
logits_processors_args=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
@@ -158,6 +160,7 @@ class SamplingParams:
bad_words=bad_words,
guided_decoding=guided_decoding,
bad_words_token_ids=bad_words_token_ids,
logits_processors_args=logits_processors_args,
)
def __post_init__(self):
@@ -208,6 +211,24 @@ class SamplingParams:
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
# Verify logits processors arguments
if self.logits_processors_args is not None:
if self.logits_processors_args.get("logit_bias") is not None:
logit_bias = self.logits_processors_args.get("logit_bias")
if not isinstance(logit_bias, dict):
raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
# try to cast the dict to the correct type first
try:
cast_logit_bias = {}
for k, v in logit_bias.items():
cast_logit_bias[int(k)] = float(v)
self.logits_processors_args["logit_bias"] = cast_logit_bias
except:
raise TypeError(
"failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
)
@dataclass
class BeamSearchParams:
@@ -266,14 +266,11 @@ class ResourceManagerV1(ResourceManager):
del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_running.dec(1)
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_waiting.inc(1)
main_process_metrics.num_requests_running.dec(1)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -651,8 +648,6 @@ class ResourceManagerV1(ResourceManager):
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
main_process_metrics.num_requests_waiting.dec(1)
main_process_metrics.num_requests_running.inc(1)
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
@@ -461,6 +461,7 @@ class CompletionRequest(BaseModel):
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
logits_processors_args: Optional[Dict] = None
# doc: end-completion-sampling-params
# doc: start-completion-extra-params
@@ -613,6 +614,7 @@ class ChatCompletionRequest(BaseModel):
bad_words_token_ids: Optional[List[int]] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
logits_processors_args: Optional[Dict] = None
# doc: end-chat-completion-sampling-params
# doc: start-chat-completion-extra-params
@@ -19,6 +19,8 @@ from typing import Dict, Optional
import paddle
from fastdeploy.model_executor.logits_processor import LogitsProcessor
@dataclass
class SamplingMetadata:
@@ -54,6 +56,7 @@ class SamplingMetadata:
temp_scaled_logprobs: Optional[paddle.Tensor] = None
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
logits_processors: Optional[list[LogitsProcessor]] = None
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
class SamplerProcessor:
class GuidedDecoding:
"""
SamplingProcessor for guided decoding.
processor for guided decoding.
"""
def __init__(self):
@@ -75,7 +75,7 @@ class SamplerProcessor:
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to SamplerProcessor"""
"""add logits processor to GuidedDecoding"""
with self.logits_lock:
if future is None:
if ids in self.logits_processor:
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
else:
raise NotImplementedError
self.processor = SamplerProcessor()
self.guided_decoding = GuidedDecoding()
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
self.processor.apply_reasoning_parser(reasoning_parser)
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
self.guided_decoding.pre_process(skip_idx_list)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
self.processor.update_output_tokens(next_tokens, skip_idx_list)
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
def compute_logprobs(
self,
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
skip_idx_list: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.processor.apply_token_mask(logits, skip_idx_list)
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -0,0 +1,70 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from importlib import import_module
from .base import LogitsProcessor
from .builtin import LogitBiasLogitsProcessor
def load_class(spec: str):
"""
Load a class from a string spec.
If the spec is in the form 'package.module:ClassName', loads ClassName from the specified module.
If the spec does not contain a colon, it is treated as the name of a builtin class from
'fastdeploy.model_executor.logits_processor'.
Args:
spec (str): The class specifier string.
Returns:
type: The loaded class object.
Raises:
ValueError: If the spec is invalid.
ImportError: If the module cannot be imported.
AttributeError: If the class cannot be found in the module.
"""
try:
if ":" in spec:
module_path, class_name = spec.split(":", 1)
else:
module_path = "fastdeploy.model_executor.logits_processor"
class_name = spec
module = import_module(module_path)
obj = getattr(module, class_name)
return obj
except ValueError as e:
raise ValueError(f"Invalid spec {spec!r}; expected 'module:ClassName'.") from e
except ImportError as e:
raise ImportError(f"Failed to import module {module_path}") from e
except AttributeError as e:
raise AttributeError(f"Module {module_path} does not have attribute {class_name}") from e
def build_logits_processors(fd_config):
logit_procs = []
for fqcn in fd_config.structured_outputs_config.logits_processors or []:
logit_procs.append(load_class(fqcn)(fd_config))
return logit_procs
__all__ = [
"build_logits_processors",
"LogitsProcessor",
"LogitBiasLogitsProcessor",
]
@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from paddle import Tensor
from fastdeploy.config import FDConfig
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, fd_config: FDConfig) -> None:
raise NotImplementedError
@abstractmethod
def update_state(self, share_inputs: dict) -> None:
"""Called when there are new output tokens, prior to each forward pass.
Each field in the `share_inputs` dict typically stores information for all request
slots. It has a `stop_flags` array that indicates whether a slot currently has a
running request (`False` means the slot is active). Therefore, it is recommended to
filter entries by `stop_flags` to keep only data for the current batch.
"""
raise NotImplementedError
@abstractmethod
def apply(self, logits: Tensor) -> Tensor:
"""Apply LogitsProcessor to batch logits tensor.
The updated tensor must be returned but may be modified in-place.
"""
raise NotImplementedError
@@ -0,0 +1,68 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.logits_processor.base import LogitsProcessor
class LogitBiasLogitsProcessor(LogitsProcessor):
"""
Maintains per-request logit biases and applies them to logits.
"""
def __init__(self, fd_config: FDConfig):
self.device = paddle.device.get_device()
self.dtype = fd_config.model_config.dtype
self.batch_ids: list[int] = []
self.token_ids: list[int] = []
self.biases: list[float] = []
def update_state(self, share_inputs: dict):
"""Build per-step logit-bias state from request slots and move it to device."""
# Retrive inference states from share_inputs
stop_flags = share_inputs["stop_flags"]
logits_processors_args = share_inputs["logits_processors_args"]
logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
# Get bias states for each request
self.batch_ids = []
self.token_ids: list[int] = []
self.biases: list[float] = []
for batch_id, logit_proc_args in enumerate(logits_processors_args):
tok_id_bias_map = logit_proc_args.get("logit_bias") or {}
self.batch_ids.extend([batch_id] * len(tok_id_bias_map))
self.token_ids.extend(tok_id_bias_map.keys())
self.biases.extend(tok_id_bias_map.values())
return
def apply(self, logits: paddle.Tensor) -> paddle.Tensor:
"""Apply logit bias to logits: [batch_size, vocab_size]"""
# Skip if no bias is applied
if len(self.biases) == 0:
return logits
# Make bias indices and bias tensor
bias_indices = (
paddle.tensor(self.batch_ids, dtype="int32").to(self.device),
paddle.tensor(self.token_ids, dtype="int32").to(self.device),
)
bias_tensor = paddle.tensor(self.biases, device=self.device, dtype=self.dtype)
logits[bias_indices] += bias_tensor
return logits
+14
View File
@@ -87,6 +87,7 @@ from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
@@ -624,6 +625,9 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
# For logits processors
self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
@@ -1188,6 +1192,11 @@ class GPUModelRunner(ModelRunnerBase):
)
self.share_inputs["image_features"] = None
# For logits processors
self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config)
self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)]
logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}")
def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -1265,6 +1274,7 @@ class GPUModelRunner(ModelRunnerBase):
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
logits_processors=self.share_inputs["logits_processors"],
share_inputs=self.share_inputs,
)
@@ -1993,6 +2003,10 @@ class GPUModelRunner(ModelRunnerBase):
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
# 1.1 Update state of logits processor
for proc in self.sampling_metadata.logits_processors:
proc.update_state(self.share_inputs)
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
+8
View File
@@ -686,6 +686,14 @@ def parse_args():
help="Override configuration for the pooler.",
)
parser.add_argument(
"--logits-processors",
type=str,
nargs="+",
default=[],
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
)
args = parser.parse_args()
return args