[Feature] support logits processors (#4515)

* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor

* [chore] fix code style

* [fix] add unit test & fix existing bugs

* [feat] add engine/worker arg --logits-processors

* [fix] redefine user args as logits_processors_args and fix some bugs

* [fix] fix test_sampler

* Update fastdeploy/model_executor/logits_processor/builtin.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/logits_processor/__init__.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/model_executor/test_logits_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix typo

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix bracelet

* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits

* [doc] add docs

* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests

* [fix] fix redundant code

* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
李泳桦
2025-10-29 00:08:53 +08:00
committed by GitHub
parent 24b9505971
commit a012e3608b
18 changed files with 882 additions and 14 deletions
+200
View File
@@ -0,0 +1,200 @@
# Logits Processors
## Overview
A **Logits Processor (LP)** sits between *model output logits* and the *sampler* (top-k/top-p/temperature…). It applies pluggable transformations to logits **before** sampling (e.g., weighting, masking, penalties, biases).
## Key Features
- **Server-level registration**: declare available processors at startup via `--logits-processors`. The declaration order is the execution order.
- **Per-request control**: enable and configure processors via the `logits_processors_args` field in the request body.
- **Built-in processor**: commonly used processors are provided, e.g., `LogitBiasLogitsProcessor`, which can be loaded directly by class name.
- **Extensible interface**: a standard `LogitsProcessor` interface is provided for user-defined processors, which can be loaded by FQCN `module.path:ClassName`.
## Usage
### Online Service
#### 1. Start the service (register logits processors)
Register processors with `--logits-processors` when starting the service. For a built-in processor like `LogitBiasLogitsProcessor`, pass the class name directly:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/model \
--port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
--logits-processors LogitBiasLogitsProcessor
```
#### 2. Send a request (enable and configure as needed)
Use the `logits_processors_args` field in the REST request body to enable and configure processors. Example with `LogitBiasLogitsProcessor`, which adds a bias to specified tokens. It accepts a `logit_bias` dictionary mapping *token_id**bias value*:
```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d '{
"messages": [{"role":"user", "content":"Who is Lu Xun?"}],
"logits_processors_args": {
"logit_bias": {"128": 5.0, "50256": -10.0}
}
}'
```
When using the OpenAI Python SDK, pass `logits_processors_args` through `extra_body`:
```python
import openai
client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Who is Lu Xun?"}],
extra_body={
"logits_processors_args": {
"logit_bias": {"128": 5.0, "50256": -10.0}
}
}
)
```
### Offline Inference
For offline inference, pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params`.`logits_processors_args` to enable and pass arguments to the corresponding processors.
```python
from fastdeploy import LLM, SamplingParams
llm = LLM(
model="path/to/model",
engine_worker_queue_port=8282,
cache_queue_port=8383,
logits_processors=['LogitBiasLogitsProcessor'],
)
messages = [{"role": "user", "content": "Who is Lu Xun?"}]
sampling_params = SamplingParams(
top_p=0.95,
max_tokens=128,
logits_processors_args={"logit_bias": {128: 5.0, 50256: -10.0}},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
## Custom Logits Processor
### 1. Define your own LogitsProcessor class
Inherit the `fastdeploy.openai.logits_processor.LogitsProcessor` class and implement the `update_state()` and `apply()` methods.
- **`update_state()` is used to update the logits processor state.** The input is the inference backends runtime state `share_inputs`, and it returns nothing. You need to extract useful information from the runtime state to update the logits processors internal state.
- For example, in the following example, we retrieve the current batchs `logits_processors_args` from `share_inputs`, and then bulk-modify the enablement status of the logits processor for the current batch;
- When writing your class, you should predefine the parameter names for your logits processor, such as adding a request parameter `enable_your_logits_processor` to control whether your logits processor is enabled for a request;
- **`apply()` is used to actually modify the logits tensor.** Before `apply()` runs, the model will call `update_state()` to refresh the logits processor state. Therefore, ensure your `update_state()` correctly updates the state variables used by the logits processor.
- In the following example, we use `self.enabled`to determine whether each request in the current batch enables your logits processor, and adjust the logits tensor dynamically.
```python
from paddle import Tensor
from fastdeploy.config import FDConfig
from fastdeploy.openai.logits_processor import LogitsProcessor
class YourLogitsProcessor(LogitsProcessor):
def __init__(self, fd_config: FDConfig) -> None:
# Initialize your state variables here, for example obtain dtype, device, etc.
# from fd_config. You can freely set the state variables you need, and update them
# during each step of inference update_state()
self.enabled = None
return
def update_state(self, share_inputs: dict) -> None:
"""Called when there are new output tokens, prior to each forward pass.
Each field in the `share_inputs` dict typically stores information for all request
slots. It has a `stop_flags` array that indicates whether a slot currently has a
running request (`False` means the slot is active). Therefore, it is recommended to
filter entries by `stop_flags` to keep only data for the current batch.
"""
stop_flags = share_inputs["stop_flags"]
logits_processors_args = share_inputs["logits_processors_args"]
logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
# Update your state variables here to facilitate dynamically
# adjusting your logits processor behavior at each step of inference.
# The latest state should be read and used in the apply() method
self.enabled = [a.enable_your_logits_processor for a in logits_processors_args]
return
def apply(self, logits: Tensor) -> Tensor:
"""Apply LogitsProcessor to batch logits tensor.
The updated tensor must be returned but may be modified in-place.
"""
for i, e in enumerate(self.enabled):
# Implement your core logits transformation here, and return the modified logits tensor
logits[i] = ...
return logits
```
### 2. Use your logits processor via online service
#### 2.2. Start the service (register your logits processor)
When registering a custom processor, pass its **FQCN** (`module.path:ClassName`) to `--logits-processors`:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/model \
--port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
--logits-processors your.dotted.path.to.module:YourLogitsProcessor
```
#### 2.2. Send a request (enable and configure as needed)
Enable your processor per request via `logits_processors_args`:
```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d '{
"messages": [{"role":"user", "content":"Who is Lu Xun?"}],
"logits_processors_args": {
"enable_your_logits_processor": true
}
}'
```
Using the OpenAI Python SDK:
```python
import openai
client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Who is Lu Xun?"}],
extra_body={
"logits_processors_args": {
"enable_your_logits_processor": True
}
}
)
```
### 3. Use your logits processor via offline inference
For offline inference, pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. To specify your custom logits processor, pass its FQCN (`module.path:ClassName`). When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params`.`logits_processors_args` to enable and pass arguments to the corresponding processors.
```python
from fastdeploy import LLM, SamplingParams
llm = LLM(
model="path/to/model",
engine_worker_queue_port=8282,
cache_queue_port=8383,
logits_processors=['your.dotted.path.to.module:YourLogitsProcessor'],
)
messages = [{"role": "user", "content": "Who is Lu Xun?"}]
sampling_params = SamplingParams(
top_p=0.95,
max_tokens=128,
logits_processors_args={"enable_your_logits_processor": True},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
+200
View File
@@ -0,0 +1,200 @@
# Logits Processors
## 概述
Logits ProcessorLP)位于“模型输出 logits → 采样器(top-k/top-p/temperature…)” 之间,用于在采样前对 logits 做可插拔的变换(加权、屏蔽、惩罚、偏置等)。
## 关键特性
- **服务级注册**:启动时用 `--logits-processors` 声明可用处理器,其中的声明顺序即 logits 处理器的执行顺序
- **请求级控制**:请求体通过 `logits_processors_args` 字段按需启用并传参
- **内置处理器**:提供常用处理器,如:`LogitBiasLogitsProcessor`,可直接按类名加载
- **可扩展接口**:提供 `LogitsProcessor` 类的标准接口,支持用户基于此接口编写自定义处理器,并按 FQCN 加载:`module.path:ClassName`
## 使用方法
### 在线服务
#### 1. 启动服务(注册 logits 处理器)
在启动服务时,通过 `--logits-processors` 参数注册处理器。如果应用内置的 logits 处理器,以 `LogitBiasLogitsProcessor` 为例,直接传入类名即可:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/model \
--port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
--logits-processors LogitBiasLogitsProcessor
```
#### 2. 发送请求(按需启用并传参)
通过 RESTful API 发送请求时,通过 `logits_processors_args` 字段启用并传参,**不同的 logits 处理器需要不同的参数**。以 `LogitBiasLogitsProcessor` 为例,该处理器用于对指定 token 添加偏置。它接收 `logit_bias` 参数,为一个 dict 字典,表示 token id 到偏置值的映射。
```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d \
'{
"messages": [{"role":"user", "content":"今天天气真好"}],
"logits_processors_args": {
"logit_bias": {"128": 5.0, "50256": -10.0},
}
}'
```
通过 OpenAI Python SDK 发送请求时,通过 `extra_body` 参数传入 `logits_processor_args` 字段启用并传参。
```python
import openai
client = openai.Client(base_url=f"http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "今天天气真好"}],
extra_body={
"logits_processors_args": {
"logit_bias": {"128": 5.0, "50256": -10.0},
}
}
)
```
### 离线推理
离线调用场景,在初始化 `LLM` 实例时传入 `logits_processors` 参数,类型为 `list[str]`。在调用离线 `chat()``generate()` 接口生成文本时,通过 `sampling_params`.`logits_processors_args` 传入 logits 处理器参数,启用并传参给对应处理器。
```python
from fastdeploy import LLM, SamplingParams
llm = LLM(
model="path/to/model",
engine_worker_queue_port=8282,
cache_queue_port=8383,
logits_processors=['LogitBiasLogitsProcessor'],
)
messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
top_p=0.95,
max_tokens=128,
logits_processors_args={"logit_bias": {128: 5.0, 50256: -10.0}},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
## 自定义 Logits Processor
### 1. 定义自己的 LogitsProcessor 类
继承 `fastdeploy.openai.logits_processor.LogitsProcessor` 类,实现 `update_state()``apply()` 方法。
- **`update_state()` 用于更新 logits 处理器状态。** 输入为推理后端的推理状态 `share_inputs`,无需返回值。你需要从推理状态中提取对 logits 处理器状态更新的有用信息。
- 例如,在下面的示例中,我们从 `share_inputs` 中取出当前 batch 的 `logits_processors_args`,然后批量修改当前 batch 的 logits 处理器启用状态;
- 你需要在编写类时事先约定好你的 logits 处理器参数名,例如添加请求参数 `enable_your_logits_processor`,用于控制请求是否启用你的 logits 处理器;
- **`apply()` 用于实际修改 logits 张量。** 在 apply() 执行前,模型会调用 update_state() 方法,更新 logits 处理器状态。因此,请确保你的 update_state() 实现正确更新了 logits 处理器状态变量。
- 在下面的示例中,我们通过 `self.enabled` 判断当前 batch 各请求是否启用你的 logits 处理器,并动态调整 logits 张量。
```python
from paddle import Tensor
from fastdeploy.config import FDConfig
from fastdeploy.openai.logits_processor import LogitsProcessor
class YourLogitsProcessor(LogitsProcessor):
def __init__(self, fd_config: FDConfig) -> None:
# 在这里初始化你的状态变量,例如从 fd_config 取出 dtype, device 等信息
# 你可以自由设定需要存储的状态变量,并在每一步推理中通过 update_state() 方法更新
self.enabled = None
return
def update_state(self, share_inputs: dict) -> None:
"""Called when there are new output tokens, prior to each forward pass.
Each field in the `share_inputs` dict typically stores information for all request
slots. It has a `stop_flags` array that indicates whether a slot currently has a
running request (`False` means the slot is active). Therefore, it is recommended to
filter entries by `stop_flags` to keep only data for the current batch.
"""
stop_flags = share_inputs["stop_flags"]
logits_processors_args = share_inputs["logits_processors_args"]
logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
# 在这里更新你的状态变量,便于在每一步推理中动态调整你的 logits 处理器行为
# 最新的状态应该在 apply() 方法中读取并使用
self.enabled = [a.enable_your_logits_processor for a in logits_processors_args]
return
def apply(self, logits: Tensor) -> Tensor:
"""Apply LogitsProcessor to batch logits tensor.
The updated tensor must be returned but may be modified in-place.
"""
for i, e in enumerate(self.enabled):
# 在这里实现你的核心 logits 处理逻辑,并返回修改后的 logits 张量
logits[i] = ...
return logits
```
### 2. 通过在线服务使用自己的 logits 处理器
#### 2.1. 启动服务(注册自己的 logits 处理器)
在启动服务时,通过 `--logits-processors` 参数注册你的处理器。在传入自定义处理器时,需要传入 FQCNFully Qualified Class Name),即 `module.path:ClassName`
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/model \
--port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
--logits-processors your.dotted.path.to.module:YourLogitsProcessor
```
#### 2.2. 发送请求(按需启用并传参)
通过 RESTful API 发送请求时,通过 `logits_processors_args` 字段启用并传参:
```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d \
'{
"messages": [{"role":"user", "content":"今天天气真好"}],
"logits_processors_args": {
"enable_your_logits_processor": true
}
}'
```
通过 OpenAI Python SDK 发送请求时,通过 `extra_body` 参数传入 `logits_processor_args` 字段启用并传参:
```python
import openai
client = openai.Client(base_url=f"http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "今天天气真好"}],
extra_body={
"logits_processors_args": {
"enable_your_logits_processor": True
}
}
)
```
### 3. 通过离线调用使用自己的 logits 处理器
在初始化 `LLM` 实例时传入 `logits_processors` 参数,类型为 `list[str]`。在传入自定义处理器时,需要传入 FQCNFully Qualified Class Name),即 `module.path:ClassName`。在调用离线 `chat()``generate()` 接口生成文本时,通过 `sampling_params`.`logits_processors_args` 传入 logits 处理器参数,启用并传参给对应处理器。
```python
from fastdeploy import LLM, SamplingParams
llm = LLM(
model="path/to/model",
engine_worker_queue_port=8282,
cache_queue_port=8383,
logits_processors=['your.dotted.path.to.module:YourLogitsProcessor'],
)
messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
top_p=0.95,
max_tokens=128,
logits_processors_args={"enable_your_logits_processor": True},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
+10
View File
@@ -970,6 +970,9 @@ class PlasAttentionConfig:
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class EarlyStopConfig:
def __init__(
@@ -1071,6 +1074,9 @@ class LoadConfig:
if hasattr(self, key):
setattr(self, key, value)
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
@@ -1339,11 +1345,15 @@ class StructuredOutputsConfig:
self.guided_decoding_backend: Optional[str] = None
# disable any whitespace for guided decoding
self.disable_any_whitespace: bool = True
self.logits_processors: Optional[list[str]] = None
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)
def __str__(self) -> str:
return json.dumps({key: value for key, value in self.__dict__.items()})
class FDConfig:
"""
+17
View File
@@ -422,6 +422,16 @@ class EngineArgs:
Flag to specify the dtype of lm_head as FP32. Default is False (Using model default dtype).
"""
logits_processors: Optional[List[str]] = None
"""
A list of FQCNs (Fully Qualified Class Names) of logits processors supported by the service.
A fully qualified class name (FQCN) is a string that uniquely identifies a class within a Python module.
- To enable builtin logits processors, add builtin module paths and class names to the list. Currently support:
- fastdeploy.model_executor.logits_processor:LogitBiasLogitsProcessor
- To enable custom logits processors, add your dotted paths to module and class names to the list.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -687,6 +697,13 @@ class EngineArgs:
default=EngineArgs.lm_head_fp32,
help="Specify the dtype of lm_head weight as float32.",
)
model_group.add_argument(
"--logits-processors",
type=str,
nargs="+",
default=EngineArgs.logits_processors,
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
+2
View File
@@ -535,6 +535,8 @@ class LLMEngine:
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
worker_append_flag = {
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
+21
View File
@@ -103,6 +103,7 @@ class SamplingParams:
bad_words: Optional[List[str]] = None
guided_decoding: Optional[GuidedDecodingParams] = None
bad_words_token_ids: Optional[List[int]] = None
logits_processors_args: Optional[dict[str, Any]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -136,6 +137,7 @@ class SamplingParams:
bad_words=None,
guided_decoding=None,
bad_words_token_ids=None,
logits_processors_args=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
@@ -158,6 +160,7 @@ class SamplingParams:
bad_words=bad_words,
guided_decoding=guided_decoding,
bad_words_token_ids=bad_words_token_ids,
logits_processors_args=logits_processors_args,
)
def __post_init__(self):
@@ -208,6 +211,24 @@ class SamplingParams:
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
# Verify logits processors arguments
if self.logits_processors_args is not None:
if self.logits_processors_args.get("logit_bias") is not None:
logit_bias = self.logits_processors_args.get("logit_bias")
if not isinstance(logit_bias, dict):
raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
# try to cast the dict to the correct type first
try:
cast_logit_bias = {}
for k, v in logit_bias.items():
cast_logit_bias[int(k)] = float(v)
self.logits_processors_args["logit_bias"] = cast_logit_bias
except:
raise TypeError(
"failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
)
@dataclass
class BeamSearchParams:
@@ -266,14 +266,11 @@ class ResourceManagerV1(ResourceManager):
del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_running.dec(1)
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_waiting.inc(1)
main_process_metrics.num_requests_running.dec(1)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -651,8 +648,6 @@ class ResourceManagerV1(ResourceManager):
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
main_process_metrics.num_requests_waiting.dec(1)
main_process_metrics.num_requests_running.inc(1)
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
@@ -461,6 +461,7 @@ class CompletionRequest(BaseModel):
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
logits_processors_args: Optional[Dict] = None
# doc: end-completion-sampling-params
# doc: start-completion-extra-params
@@ -613,6 +614,7 @@ class ChatCompletionRequest(BaseModel):
bad_words_token_ids: Optional[List[int]] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
logits_processors_args: Optional[Dict] = None
# doc: end-chat-completion-sampling-params
# doc: start-chat-completion-extra-params
@@ -19,6 +19,8 @@ from typing import Dict, Optional
import paddle
from fastdeploy.model_executor.logits_processor import LogitsProcessor
@dataclass
class SamplingMetadata:
@@ -54,6 +56,7 @@ class SamplingMetadata:
temp_scaled_logprobs: Optional[paddle.Tensor] = None
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
logits_processors: Optional[list[LogitsProcessor]] = None
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
class SamplerProcessor:
class GuidedDecoding:
"""
SamplingProcessor for guided decoding.
processor for guided decoding.
"""
def __init__(self):
@@ -75,7 +75,7 @@ class SamplerProcessor:
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to SamplerProcessor"""
"""add logits processor to GuidedDecoding"""
with self.logits_lock:
if future is None:
if ids in self.logits_processor:
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
else:
raise NotImplementedError
self.processor = SamplerProcessor()
self.guided_decoding = GuidedDecoding()
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
self.processor.apply_reasoning_parser(reasoning_parser)
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
self.guided_decoding.pre_process(skip_idx_list)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
self.processor.update_output_tokens(next_tokens, skip_idx_list)
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
def compute_logprobs(
self,
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
skip_idx_list: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.processor.apply_token_mask(logits, skip_idx_list)
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -0,0 +1,70 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from importlib import import_module
from .base import LogitsProcessor
from .builtin import LogitBiasLogitsProcessor
def load_class(spec: str):
"""
Load a class from a string spec.
If the spec is in the form 'package.module:ClassName', loads ClassName from the specified module.
If the spec does not contain a colon, it is treated as the name of a builtin class from
'fastdeploy.model_executor.logits_processor'.
Args:
spec (str): The class specifier string.
Returns:
type: The loaded class object.
Raises:
ValueError: If the spec is invalid.
ImportError: If the module cannot be imported.
AttributeError: If the class cannot be found in the module.
"""
try:
if ":" in spec:
module_path, class_name = spec.split(":", 1)
else:
module_path = "fastdeploy.model_executor.logits_processor"
class_name = spec
module = import_module(module_path)
obj = getattr(module, class_name)
return obj
except ValueError as e:
raise ValueError(f"Invalid spec {spec!r}; expected 'module:ClassName'.") from e
except ImportError as e:
raise ImportError(f"Failed to import module {module_path}") from e
except AttributeError as e:
raise AttributeError(f"Module {module_path} does not have attribute {class_name}") from e
def build_logits_processors(fd_config):
logit_procs = []
for fqcn in fd_config.structured_outputs_config.logits_processors or []:
logit_procs.append(load_class(fqcn)(fd_config))
return logit_procs
__all__ = [
"build_logits_processors",
"LogitsProcessor",
"LogitBiasLogitsProcessor",
]
@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from paddle import Tensor
from fastdeploy.config import FDConfig
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, fd_config: FDConfig) -> None:
raise NotImplementedError
@abstractmethod
def update_state(self, share_inputs: dict) -> None:
"""Called when there are new output tokens, prior to each forward pass.
Each field in the `share_inputs` dict typically stores information for all request
slots. It has a `stop_flags` array that indicates whether a slot currently has a
running request (`False` means the slot is active). Therefore, it is recommended to
filter entries by `stop_flags` to keep only data for the current batch.
"""
raise NotImplementedError
@abstractmethod
def apply(self, logits: Tensor) -> Tensor:
"""Apply LogitsProcessor to batch logits tensor.
The updated tensor must be returned but may be modified in-place.
"""
raise NotImplementedError
@@ -0,0 +1,68 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.logits_processor.base import LogitsProcessor
class LogitBiasLogitsProcessor(LogitsProcessor):
"""
Maintains per-request logit biases and applies them to logits.
"""
def __init__(self, fd_config: FDConfig):
self.device = paddle.device.get_device()
self.dtype = fd_config.model_config.dtype
self.batch_ids: list[int] = []
self.token_ids: list[int] = []
self.biases: list[float] = []
def update_state(self, share_inputs: dict):
"""Build per-step logit-bias state from request slots and move it to device."""
# Retrive inference states from share_inputs
stop_flags = share_inputs["stop_flags"]
logits_processors_args = share_inputs["logits_processors_args"]
logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
# Get bias states for each request
self.batch_ids = []
self.token_ids: list[int] = []
self.biases: list[float] = []
for batch_id, logit_proc_args in enumerate(logits_processors_args):
tok_id_bias_map = logit_proc_args.get("logit_bias") or {}
self.batch_ids.extend([batch_id] * len(tok_id_bias_map))
self.token_ids.extend(tok_id_bias_map.keys())
self.biases.extend(tok_id_bias_map.values())
return
def apply(self, logits: paddle.Tensor) -> paddle.Tensor:
"""Apply logit bias to logits: [batch_size, vocab_size]"""
# Skip if no bias is applied
if len(self.biases) == 0:
return logits
# Make bias indices and bias tensor
bias_indices = (
paddle.tensor(self.batch_ids, dtype="int32").to(self.device),
paddle.tensor(self.token_ids, dtype="int32").to(self.device),
)
bias_tensor = paddle.tensor(self.biases, device=self.device, dtype=self.dtype)
logits[bias_indices] += bias_tensor
return logits
+14
View File
@@ -87,6 +87,7 @@ from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
@@ -624,6 +625,9 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
# For logits processors
self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
@@ -1188,6 +1192,11 @@ class GPUModelRunner(ModelRunnerBase):
)
self.share_inputs["image_features"] = None
# For logits processors
self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config)
self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)]
logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}")
def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -1265,6 +1274,7 @@ class GPUModelRunner(ModelRunnerBase):
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
logits_processors=self.share_inputs["logits_processors"],
share_inputs=self.share_inputs,
)
@@ -1993,6 +2003,10 @@ class GPUModelRunner(ModelRunnerBase):
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
# 1.1 Update state of logits processor
for proc in self.sampling_metadata.logits_processors:
proc.update_state(self.share_inputs)
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
+8
View File
@@ -686,6 +686,14 @@ def parse_args():
help="Override configuration for the pooler.",
)
parser.add_argument(
"--logits-processors",
type=str,
nargs="+",
default=[],
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
)
args = parser.parse_args()
return args
@@ -91,6 +91,7 @@ def llm(model_path):
cache_queue_port=FD_CACHE_QUEUE_PORT,
max_model_len=32768,
quantization="wint8",
logits_processors=["LogitBiasLogitsProcessor"],
)
# Wait for the port to be open
@@ -284,6 +285,45 @@ def test_seed(llm):
pytest.fail("Prompt generation test failed")
def test_logits_processors(llm):
"""
Test LogitBiasLogitsProcessor: token with extremely large logit bias should always be greedy-sampled
"""
messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
top_p=0.0,
max_tokens=128,
)
outputs = llm.chat(messages, sampling_params)
print("generated text:", outputs[0].outputs.text)
original_generated_text = outputs[0].outputs.text
# test request with logit bias
token_id_with_exlarge_bias = 123
messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
top_p=0.0,
max_tokens=128,
logits_processors_args={"logit_bias": {token_id_with_exlarge_bias: 100000}},
)
outputs = llm.chat(messages, sampling_params)
print("generated text:", outputs[0].outputs.text)
print("generated token ids:", outputs[0].outputs.token_ids)
print("expected token id:", token_id_with_exlarge_bias)
assert all(x == token_id_with_exlarge_bias for x in outputs[0].outputs.token_ids[:-1])
# test request without logit bias
messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
top_p=0.0,
max_tokens=128,
)
outputs = llm.chat(messages, sampling_params)
print("generated text:", outputs[0].outputs.text)
current_generated_text = outputs[0].outputs.text
assert current_generated_text == original_generated_text
if __name__ == "__main__":
"""
Main entry point for the test script.
+7
View File
@@ -60,6 +60,7 @@ def _create_default_sampling_metadata(
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
min_p=paddle.randn([batch_size]),
seed=paddle.to_tensor([[2025]]),
logits_processors=None,
)
if max_num_logprobs is not None:
fake_sampling_metadata.max_num_logprobs = max_num_logprobs
@@ -89,6 +90,9 @@ def get_baseline_logprobs(logits, sampling_metadata, logprobs_mode, token_ids):
apply_penalty_multi_scores,
)
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -109,6 +113,9 @@ def get_baseline_logprobs(logits, sampling_metadata, logprobs_mode, token_ids):
apply_penalty_multi_scores,
)
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -0,0 +1,162 @@
import random
import unittest
from unittest.mock import Mock
import paddle
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.logits_processor.builtin import LogitBiasLogitsProcessor
class TestLogitsProcessor(unittest.TestCase):
def setUp(self):
self.vocab_size = 10
self.max_num_seqs = 16
self.dtype = "float32"
self.share_inputs = {
"stop_flags": paddle.tensor([True for _ in range(self.max_num_seqs)]),
"logits_processors_args": [{} for _ in range(self.max_num_seqs)],
}
def create_request(self, **kwargs):
"""Create a mock request with specified logit bias"""
request = Mock(spec=Request)
for k, v in kwargs.items():
setattr(request, k, v)
return request
def create_logits(self):
return paddle.randn([self.get_batch_size(), self.vocab_size], dtype=self.dtype)
def add_request(self, req):
self.share_inputs["stop_flags"][req.idx] = False
self.share_inputs["logits_processors_args"][req.idx]["logit_bias"] = req.logit_bias
def del_request(self, req):
self.share_inputs["stop_flags"][req.idx] = True
self.share_inputs["logits_processors_args"][req.idx] = {}
def get_batch_size(self):
return self.max_num_seqs - sum(self.share_inputs["stop_flags"])
def test_logit_bias_logit_processor(self):
fd_config = Mock()
fd_config.model_config.dtype = self.dtype
logits_processor = LogitBiasLogitsProcessor(fd_config)
print("Phase 1: Empty batch")
logits = self.create_logits()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
self.assertTrue(paddle.all(processed_logits == logits), "Logits should remain unchanged with empty batch")
print("Phase 2: Add first request")
request1 = self.create_request(
request_id="req1", idx=0, logit_bias={random.randint(0, self.vocab_size - 1): random.random() - 0.5}
)
self.add_request(request1)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias", {})
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
f"Logits should be modified with req1 biases\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("Phase 3: Add second request with multiple tokens to apply bias")
request2 = self.create_request(
request_id="req2",
idx=1,
logit_bias=dict(
zip(random.choices(range(0, self.vocab_size), k=3), [random.random() - 0.5 for _ in range(3)])
),
)
self.add_request(request2)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should be modified with req1 and req2 biases\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("Phase 4: Remove first request")
self.del_request(request1)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should only have biases from request2 after removal\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("Phase 5: Add third request with no logit bias")
request3 = self.create_request(request_id="req3", idx=0, logit_bias=None)
self.add_request(request3)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should remain unchanged with request having no bias\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("All test phases completed successfully!")
if __name__ == "__main__":
unittest.main()