mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00

[Feature] support logits processors (#4515)

* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor
* [chore] fix code style
* [fix] add unit test & fix existing bugs
* [feat] add engine/worker arg --logits-processors
* [fix] redefine user args as logits_processors_args and fix some bugs
* [fix] fix test_sampler
* Update fastdeploy/model_executor/logits_processor/builtin.py
* Update fastdeploy/model_executor/logits_processor/__init__.py
* Update tests/model_executor/test_logits_processor.py
* [fix] fix typo
* Update fastdeploy/engine/sampling_params.py
* [fix] fix bracelet
* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits
* [doc] add docs
* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests
* [fix] fix redundant code
* [feat] skip apply() if no bias is specified

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

This commit is contained in:
@@ -0,0 +1,200 @@

# Logits Processors

## Overview

A **Logits Processor (LP)** sits between the *model output logits* and the *sampler* (top-k/top-p/temperature, etc.). It applies pluggable transformations to the logits **before** sampling (e.g., weighting, masking, penalties, biases).
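The decoding loop can be pictured as: run the model, pass the logits through each registered processor in order, then sample. A minimal framework-agnostic sketch of that pipeline (the names below are illustrative, not FastDeploy internals):

```python
# Minimal sketch of where logits processors sit in the decoding loop.
def sample_next_token(logits, processors, sampler):
    for proc in processors:          # processors run in declaration order
        logits = proc.apply(logits)  # each one transforms the logits in turn
    return sampler(logits)           # only then does the sampler pick a token

class AddBias:
    """Toy processor: add a fixed bias to one token id."""
    def __init__(self, token_id, bias):
        self.token_id, self.bias = token_id, bias
    def apply(self, logits):
        logits = list(logits)
        logits[self.token_id] += self.bias
        return logits

greedy = lambda logits: max(range(len(logits)), key=logits.__getitem__)
token = sample_next_token([1.0, 2.0, 0.5], [AddBias(2, 5.0)], greedy)
# With a +5.0 bias on token 2, greedy sampling now picks token 2.
```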
## Key Features

- **Server-level registration**: declare the available processors at startup via `--logits-processors`. The declaration order is the execution order.
- **Per-request control**: enable and configure processors via the `logits_processors_args` field in the request body.
- **Built-in processors**: commonly used processors are provided, e.g., `LogitBiasLogitsProcessor`, which can be loaded directly by class name.
- **Extensible interface**: a standard `LogitsProcessor` interface is provided for user-defined processors, which can be loaded by FQCN (`module.path:ClassName`).
## Usage

### Online Service

#### 1. Start the service (register logits processors)

Register processors with `--logits-processors` when starting the service. For a built-in processor like `LogitBiasLogitsProcessor`, pass the class name directly:

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
    --logits-processors LogitBiasLogitsProcessor
```
#### 2. Send a request (enable and configure as needed)

Use the `logits_processors_args` field in the REST request body to enable and configure processors; different processors take different arguments. For example, `LogitBiasLogitsProcessor` adds a bias to specified tokens and accepts a `logit_bias` dictionary mapping *token_id* → *bias value*:

```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d '{
  "messages": [{"role":"user", "content":"Who is Lu Xun?"}],
  "logits_processors_args": {
    "logit_bias": {"128": 5.0, "50256": -10.0}
  }
}'
```
When using the OpenAI Python SDK, pass `logits_processors_args` through `extra_body`:

```python
import openai

client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Who is Lu Xun?"}],
    extra_body={
        "logits_processors_args": {
            "logit_bias": {"128": 5.0, "50256": -10.0}
        }
    }
)
```
### Offline Inference

For offline inference, pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params.logits_processors_args` to enable and pass arguments to the corresponding processors.

```python
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="path/to/model",
    engine_worker_queue_port=8282,
    cache_queue_port=8383,
    logits_processors=['LogitBiasLogitsProcessor'],
)

messages = [{"role": "user", "content": "Who is Lu Xun?"}]
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=128,
    logits_processors_args={"logit_bias": {128: 5.0, 50256: -10.0}},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
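Conceptually, `LogitBiasLogitsProcessor` adds each configured bias to the matching token's logit before sampling, raising or lowering that token's probability. A tiny stand-alone sketch of the idea (not the FastDeploy implementation):

```python
def apply_logit_bias(row, logit_bias):
    # row: one request's logits, indexed by token id.
    # logit_bias: {token_id: bias}, as passed via logits_processors_args.
    out = list(row)
    for token_id, bias in logit_bias.items():
        out[token_id] += bias  # positive bias favors the token, negative suppresses it
    return out

row = [0.0, 0.0, 1.0, 0.0]
biased = apply_logit_bias(row, {0: 5.0, 2: -10.0})
# token 0 rises to 5.0, token 2 drops to -9.0
```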
## Custom Logits Processor

### 1. Define your own LogitsProcessor class

Inherit from the `fastdeploy.model_executor.logits_processor.LogitsProcessor` class and implement the `update_state()` and `apply()` methods.

- **`update_state()` updates the logits processor's state.** Its input is the inference backend's runtime state `share_inputs`, and it returns nothing. Extract whatever information you need from the runtime state to update the processor's internal state.
  - In the example below, we retrieve the current batch's `logits_processors_args` from `share_inputs` and then bulk-update the processor's enablement status for the batch;
  - When writing your class, predefine the parameter names for your processor, e.g., add a request parameter `enable_your_logits_processor` to control whether the processor is enabled for a request.
- **`apply()` actually modifies the logits tensor.** Before `apply()` runs, the model calls `update_state()` to refresh the processor's state, so make sure your `update_state()` correctly updates the state variables that `apply()` reads.
  - In the example below, we use `self.enabled` to determine which requests in the current batch enable the processor, and adjust the logits tensor accordingly.
```python
from paddle import Tensor

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.logits_processor import LogitsProcessor


class YourLogitsProcessor(LogitsProcessor):

    def __init__(self, fd_config: FDConfig) -> None:
        # Initialize your state variables here, for example obtain dtype, device, etc.
        # from fd_config. You can freely set the state variables you need, and update
        # them at each inference step via update_state().
        self.enabled = None

    def update_state(self, share_inputs: dict) -> None:
        """Called when there are new output tokens, prior to each forward pass.

        Each field in the `share_inputs` dict typically stores information for all request
        slots. It has a `stop_flags` array that indicates whether a slot currently has a
        running request (`False` means the slot is active). Therefore, it is recommended to
        filter entries by `stop_flags` to keep only data for the current batch.
        """
        stop_flags = share_inputs["stop_flags"]
        logits_processors_args = share_inputs["logits_processors_args"]
        logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
        # Update your state variables here so you can dynamically adjust the
        # processor's behavior at each inference step. The latest state should
        # be read and used in the apply() method.
        self.enabled = [a.enable_your_logits_processor for a in logits_processors_args]

    def apply(self, logits: Tensor) -> Tensor:
        """Apply LogitsProcessor to batch logits tensor.

        The updated tensor must be returned but may be modified in-place.
        """
        for i, enabled in enumerate(self.enabled):
            if enabled:
                # Implement your core logits transformation here
                logits[i] = ...
        return logits
```
### 2. Use your logits processor via online service

#### 2.1. Start the service (register your logits processor)

When registering a custom processor, pass its **FQCN** (`module.path:ClassName`) to `--logits-processors`:

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
    --logits-processors your.dotted.path.to.module:YourLogitsProcessor
```
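Resolving an FQCN is essentially a split on `:` followed by an importlib lookup, mirroring the `load_class` helper added in this commit. A minimal sketch of how such a loader could work (the function name here is illustrative):

```python
from importlib import import_module

def load_by_fqcn(spec):
    # "pkg.module:ClassName" -> import pkg.module, then fetch ClassName from it
    module_path, class_name = spec.split(":", 1)
    return getattr(import_module(module_path), class_name)

cls = load_by_fqcn("collections:OrderedDict")
# returns the OrderedDict class from the standard library
```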
#### 2.2. Send a request (enable and configure as needed)

Enable your processor per request via `logits_processors_args`:

```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d '{
  "messages": [{"role":"user", "content":"Who is Lu Xun?"}],
  "logits_processors_args": {
    "enable_your_logits_processor": true
  }
}'
```
Using the OpenAI Python SDK:

```python
import openai

client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Who is Lu Xun?"}],
    extra_body={
        "logits_processors_args": {
            "enable_your_logits_processor": True
        }
    }
)
```
### 3. Use your logits processor via offline inference

For offline inference, pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. To specify your custom logits processor, pass its FQCN (`module.path:ClassName`). When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params.logits_processors_args` to enable and pass arguments to the corresponding processors.

```python
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="path/to/model",
    engine_worker_queue_port=8282,
    cache_queue_port=8383,
    logits_processors=['your.dotted.path.to.module:YourLogitsProcessor'],
)

messages = [{"role": "user", "content": "Who is Lu Xun?"}]
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=128,
    logits_processors_args={"enable_your_logits_processor": True},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
@@ -0,0 +1,200 @@

# Logits Processors

## Overview

A Logits Processor (LP) sits between the model's output logits and the sampler (top-k/top-p/temperature, etc.), applying pluggable transformations to the logits before sampling (weighting, masking, penalties, biases, and so on).

## Key Features

- **Server-level registration**: declare the available processors at startup with `--logits-processors`; the declaration order is the processors' execution order.
- **Per-request control**: enable and configure processors on demand via the `logits_processors_args` field in the request body.
- **Built-in processors**: commonly used processors such as `LogitBiasLogitsProcessor` are provided and can be loaded directly by class name.
- **Extensible interface**: a standard `LogitsProcessor` interface is provided; user-defined processors built on it can be loaded by FQCN: `module.path:ClassName`.
## Usage

### Online Service

#### 1. Start the service (register logits processors)

Register processors with the `--logits-processors` argument when starting the service. To use a built-in logits processor such as `LogitBiasLogitsProcessor`, pass the class name directly:

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
    --logits-processors LogitBiasLogitsProcessor
```
#### 2. Send a request (enable and configure as needed)

When sending a request through the RESTful API, enable processors and pass their arguments via the `logits_processors_args` field; **different logits processors take different arguments**. For example, `LogitBiasLogitsProcessor` adds a bias to specified tokens; it accepts a `logit_bias` argument, a dict mapping token ids to bias values.

```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d \
'{
  "messages": [{"role":"user", "content":"今天天气真好"}],
  "logits_processors_args": {
    "logit_bias": {"128": 5.0, "50256": -10.0}
  }
}'
```
When sending a request through the OpenAI Python SDK, pass the `logits_processors_args` field via the `extra_body` parameter:

```python
import openai

client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "今天天气真好"}],
    extra_body={
        "logits_processors_args": {
            "logit_bias": {"128": 5.0, "50256": -10.0}
        }
    }
)
```
### Offline Inference

For offline inference, pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params.logits_processors_args` to enable and pass arguments to the corresponding processors.

```python
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="path/to/model",
    engine_worker_queue_port=8282,
    cache_queue_port=8383,
    logits_processors=['LogitBiasLogitsProcessor'],
)

messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=128,
    logits_processors_args={"logit_bias": {128: 5.0, 50256: -10.0}},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
## Custom Logits Processor

### 1. Define your own LogitsProcessor class

Inherit from the `fastdeploy.model_executor.logits_processor.LogitsProcessor` class and implement the `update_state()` and `apply()` methods.

- **`update_state()` updates the logits processor's state.** Its input is the inference backend's runtime state `share_inputs`, and it returns nothing. Extract whatever information is useful for updating the processor's state from the runtime state.
  - In the example below, we retrieve the current batch's `logits_processors_args` from `share_inputs` and then bulk-update the processor's enablement status for the batch;
  - When writing your class, agree on the parameter names for your processor up front, e.g., add a request parameter `enable_your_logits_processor` to control whether a request enables your processor.
- **`apply()` actually modifies the logits tensor.** Before `apply()` runs, the model calls `update_state()` to refresh the processor's state, so make sure your `update_state()` implementation correctly updates the state variables.
  - In the example below, we use `self.enabled` to determine which requests in the current batch enable the processor, and adjust the logits tensor accordingly.
```python
from paddle import Tensor

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.logits_processor import LogitsProcessor


class YourLogitsProcessor(LogitsProcessor):

    def __init__(self, fd_config: FDConfig) -> None:
        # Initialize your state variables here, e.g., obtain dtype, device, etc.
        # from fd_config. You can freely choose which state variables to store and
        # update them at each inference step via update_state().
        self.enabled = None

    def update_state(self, share_inputs: dict) -> None:
        """Called when there are new output tokens, prior to each forward pass.

        Each field in the `share_inputs` dict typically stores information for all request
        slots. It has a `stop_flags` array that indicates whether a slot currently has a
        running request (`False` means the slot is active). Therefore, it is recommended to
        filter entries by `stop_flags` to keep only data for the current batch.
        """
        stop_flags = share_inputs["stop_flags"]
        logits_processors_args = share_inputs["logits_processors_args"]
        logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
        # Update your state variables here so you can dynamically adjust the
        # processor's behavior at each inference step. The latest state should
        # be read and used in the apply() method.
        self.enabled = [a.enable_your_logits_processor for a in logits_processors_args]

    def apply(self, logits: Tensor) -> Tensor:
        """Apply LogitsProcessor to batch logits tensor.

        The updated tensor must be returned but may be modified in-place.
        """
        for i, enabled in enumerate(self.enabled):
            if enabled:
                # Implement your core logits transformation here
                logits[i] = ...
        return logits
```
### 2. Use your logits processor via online service

#### 2.1. Start the service (register your logits processor)

Register your processor with the `--logits-processors` argument when starting the service. For a custom processor, pass its FQCN (Fully Qualified Class Name), i.e., `module.path:ClassName`.

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --port 8180 --metrics-port 8181 --engine-worker-queue-port 8182 --cache-queue-port 8183 \
    --logits-processors your.dotted.path.to.module:YourLogitsProcessor
```
#### 2.2. Send a request (enable and configure as needed)

When sending a request through the RESTful API, enable your processor and pass its arguments via the `logits_processors_args` field:

```bash
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" -H "Content-Type: application/json" -d \
'{
  "messages": [{"role":"user", "content":"今天天气真好"}],
  "logits_processors_args": {
    "enable_your_logits_processor": true
  }
}'
```
When sending a request through the OpenAI Python SDK, pass the `logits_processors_args` field via the `extra_body` parameter:

```python
import openai

client = openai.Client(base_url="http://0.0.0.0:8180/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "今天天气真好"}],
    extra_body={
        "logits_processors_args": {
            "enable_your_logits_processor": True
        }
    }
)
```
### 3. Use your logits processor via offline inference

Pass the `logits_processors` argument (type `list[str]`) when initializing the `LLM` instance. For a custom processor, pass its FQCN (Fully Qualified Class Name), i.e., `module.path:ClassName`. When generating text via the offline `chat()` or `generate()` APIs, provide the logits-processor parameters through `sampling_params.logits_processors_args` to enable and pass arguments to the corresponding processors.

```python
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="path/to/model",
    engine_worker_queue_port=8282,
    cache_queue_port=8383,
    logits_processors=['your.dotted.path.to.module:YourLogitsProcessor'],
)

messages = [{"role": "user", "content": "鲁迅是谁"}]
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=128,
    logits_processors_args={"enable_your_logits_processor": True},
)
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs.text)
```
@@ -970,6 +970,9 @@ class PlasAttentionConfig:
        """
        return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})

    def __str__(self) -> str:
        return json.dumps({key: value for key, value in self.__dict__.items()})


class EarlyStopConfig:
    def __init__(
@@ -1071,6 +1074,9 @@ class LoadConfig:
            if hasattr(self, key):
                setattr(self, key, value)

    def __str__(self) -> str:
        return json.dumps({key: value for key, value in self.__dict__.items()})


class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""
@@ -1339,11 +1345,15 @@ class StructuredOutputsConfig:
        self.guided_decoding_backend: Optional[str] = None
        # disable any whitespace for guided decoding
        self.disable_any_whitespace: bool = True
        self.logits_processors: Optional[list[str]] = None

        for key, value in args.items():
            if hasattr(self, key) and value != "None":
                setattr(self, key, value)

    def __str__(self) -> str:
        return json.dumps({key: value for key, value in self.__dict__.items()})


class FDConfig:
    """
@@ -422,6 +422,16 @@ class EngineArgs:
    Flag to specify the dtype of lm_head as FP32. Default is False (Using model default dtype).
    """

    logits_processors: Optional[List[str]] = None
    """
    A list of FQCNs (Fully Qualified Class Names) of logits processors supported by the service.
    A fully qualified class name (FQCN) is a string that uniquely identifies a class within a Python module.

    - To enable builtin logits processors, add builtin module paths and class names to the list. Currently support:
        - fastdeploy.model_executor.logits_processor:LogitBiasLogitsProcessor
    - To enable custom logits processors, add your dotted paths to module and class names to the list.
    """

    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.
@@ -687,6 +697,13 @@ class EngineArgs:
            default=EngineArgs.lm_head_fp32,
            help="Specify the dtype of lm_head weight as float32.",
        )
        model_group.add_argument(
            "--logits-processors",
            type=str,
            nargs="+",
            default=EngineArgs.logits_processors,
            help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
        )

        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -535,6 +535,8 @@ class LLMEngine:
            f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
            f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
        )
        if self.cfg.structured_outputs_config.logits_processors is not None:
            arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"

        worker_append_flag = {
            "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
@@ -103,6 +103,7 @@ class SamplingParams:
    bad_words: Optional[List[str]] = None
    guided_decoding: Optional[GuidedDecodingParams] = None
    bad_words_token_ids: Optional[List[int]] = None
    logits_processors_args: Optional[dict[str, Any]] = None

    @classmethod
    def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -136,6 +137,7 @@ class SamplingParams:
        bad_words=None,
        guided_decoding=None,
        bad_words_token_ids=None,
        logits_processors_args=None,
    ) -> SamplingParams:
        """Create instance from command line arguments"""
        return cls(
@@ -158,6 +160,7 @@ class SamplingParams:
            bad_words=bad_words,
            guided_decoding=guided_decoding,
            bad_words_token_ids=bad_words_token_ids,
            logits_processors_args=logits_processors_args,
        )

    def __post_init__(self):
@@ -208,6 +211,24 @@ class SamplingParams:
        if not 0 <= self.seed <= 922337203685477580:
            raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

        # Verify logits processors arguments
        if self.logits_processors_args is not None:
            if self.logits_processors_args.get("logit_bias") is not None:
                logit_bias = self.logits_processors_args.get("logit_bias")
                if not isinstance(logit_bias, dict):
                    raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
                elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
                    # try to cast the dict to the correct type first
                    try:
                        cast_logit_bias = {}
                        for k, v in logit_bias.items():
                            cast_logit_bias[int(k)] = float(v)
                        self.logits_processors_args["logit_bias"] = cast_logit_bias
                    except:
                        raise TypeError(
                            "failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
                        )


@dataclass
class BeamSearchParams:
@@ -266,14 +266,11 @@ class ResourceManagerV1(ResourceManager):
                    del self.req_dict[preempted_req.request_id]
                    self._free_blocks(preempted_req)
                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                    main_process_metrics.num_requests_running.dec(1)
                else:
                    self._free_blocks(preempted_req)
                    preempted_req.cached_block_num = 0
                    self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                    main_process_metrics.num_requests_waiting.inc(1)
                    main_process_metrics.num_requests_running.dec(1)
                preempted_reqs.append(preempted_req)
                scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -651,8 +648,6 @@ class ResourceManagerV1(ResourceManager):
                        request, self.config.cache_config.block_size, request.num_computed_tokens
                    )
                    request.status = RequestStatus.RUNNING
                    main_process_metrics.num_requests_waiting.dec(1)
                    main_process_metrics.num_requests_running.inc(1)
                    if self.config.scheduler_config.splitwise_role == "mixed":
                        allocated_position = self.get_available_position()
                        request.idx = allocated_position
@@ -461,6 +461,7 @@ class CompletionRequest(BaseModel):
    include_stop_str_in_output: Optional[bool] = False
    bad_words: Optional[List[str]] = None
    bad_words_token_ids: Optional[List[int]] = None
    logits_processors_args: Optional[Dict] = None
    # doc: end-completion-sampling-params

    # doc: start-completion-extra-params
@@ -613,6 +614,7 @@ class ChatCompletionRequest(BaseModel):
    bad_words_token_ids: Optional[List[int]] = None
    repetition_penalty: Optional[float] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    logits_processors_args: Optional[Dict] = None
    # doc: end-chat-completion-sampling-params

    # doc: start-chat-completion-extra-params
@@ -19,6 +19,8 @@ from typing import Dict, Optional

import paddle

from fastdeploy.model_executor.logits_processor import LogitsProcessor


@dataclass
class SamplingMetadata:
@@ -54,6 +56,7 @@ class SamplingMetadata:
    temp_scaled_logprobs: Optional[paddle.Tensor] = None
    top_p_normalized_logprobs: Optional[paddle.Tensor] = None
    share_inputs: Optional[Dict[str, paddle.Tensor]] = None
    logits_processors: Optional[list[LogitsProcessor]] = None
    # Add for HPU post-processing
    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
    return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)


class SamplerProcessor:
class GuidedDecoding:
    """
    SamplingProcessor for guided decoding.
    processor for guided decoding.
    """

    def __init__(self):
@@ -75,7 +75,7 @@ class SamplerProcessor:
        future: Optional[Any] = None,
        prefill_tokens: List[int] = [],
    ):
        """add logits processor to SamplerProcessor"""
        """add logits processor to GuidedDecoding"""
        with self.logits_lock:
            if future is None:
                if ids in self.logits_processor:
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
        else:
            raise NotImplementedError

        self.processor = SamplerProcessor()
        self.guided_decoding = GuidedDecoding()
        self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
        # Can only be created when fd_config.early_stopper_config.enable_early_stop = True
        if (
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):

    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
        """set reasoning parser"""
        self.processor.apply_reasoning_parser(reasoning_parser)
        self.guided_decoding.apply_reasoning_parser(reasoning_parser)

    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
        """apply logits processor to sampler"""
        self.processor.add_logits_processor(ids, future, prefill_tokens)
        self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)

    def pre_process(self, skip_idx_list: List[int] = []):
        """pre process before running"""
        self.processor.pre_process(skip_idx_list)
        self.guided_decoding.pre_process(skip_idx_list)

    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
        """post process after running"""
        self.processor.update_output_tokens(next_tokens, skip_idx_list)
        self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)

    def compute_logprobs(
        self,
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
        skip_idx_list: List[int] = [],
    ) -> SamplerOutput:
        """ """
        logits = self.processor.apply_token_mask(logits, skip_idx_list)
        logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)

        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
        elif self.logprobs_mode == "raw_logits":
            raw_logprobs = logits.clone()

        for proc in sampling_metadata.logits_processors or []:
            logits = proc.apply(logits)

        logits = apply_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            sampling_metadata.prompt_ids,
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from importlib import import_module

from .base import LogitsProcessor
from .builtin import LogitBiasLogitsProcessor


def load_class(spec: str):
    """
    Load a class from a string spec.

    If the spec is in the form 'package.module:ClassName', loads ClassName from the specified module.
    If the spec does not contain a colon, it is treated as the name of a builtin class from
    'fastdeploy.model_executor.logits_processor'.

    Args:
        spec (str): The class specifier string.

    Returns:
        type: The loaded class object.

    Raises:
        ValueError: If the spec is invalid.
        ImportError: If the module cannot be imported.
        AttributeError: If the class cannot be found in the module.
    """
    try:
        if ":" in spec:
            module_path, class_name = spec.split(":", 1)
        else:
            module_path = "fastdeploy.model_executor.logits_processor"
            class_name = spec
        module = import_module(module_path)
        return getattr(module, class_name)
    except ValueError as e:
        raise ValueError(f"Invalid spec {spec!r}; expected 'module:ClassName'.") from e
    except ImportError as e:
        raise ImportError(f"Failed to import module {module_path}") from e
    except AttributeError as e:
        raise AttributeError(f"Module {module_path} does not have attribute {class_name}") from e


def build_logits_processors(fd_config):
    """Instantiate every processor declared in the server config, in declaration order."""
    logit_procs = []
    for fqcn in fd_config.structured_outputs_config.logits_processors or []:
        logit_procs.append(load_class(fqcn)(fd_config))
    return logit_procs


__all__ = [
    "build_logits_processors",
    "LogitsProcessor",
    "LogitBiasLogitsProcessor",
]
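The `'module:ClassName'` spec convention above can be illustrated standalone by resolving specs against a stdlib module (`collections` stands in here for the builtin-processor package; this is a sketch, not the FastDeploy function itself):

```python
from collections import OrderedDict
from importlib import import_module


def load_class_sketch(spec: str, default_module: str = "collections"):
    # Same colon convention as load_class: 'module:ClassName', or a bare
    # 'ClassName' that resolves against a default module.
    if ":" in spec:
        module_path, class_name = spec.split(":", 1)
    else:
        module_path = default_module
        class_name = spec
    return getattr(import_module(module_path), class_name)


fq = load_class_sketch("collections:OrderedDict")  # fully qualified spec
bare = load_class_sketch("OrderedDict")            # bare name, default module
# both resolve to collections.OrderedDict
```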
@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from abc import ABC, abstractmethod

from paddle import Tensor

from fastdeploy.config import FDConfig


class LogitsProcessor(ABC):
    @abstractmethod
    def __init__(self, fd_config: FDConfig) -> None:
        raise NotImplementedError

    @abstractmethod
    def update_state(self, share_inputs: dict) -> None:
        """Called when there are new output tokens, prior to each forward pass.

        Each field in the `share_inputs` dict typically stores information for all request
        slots. Its `stop_flags` array indicates whether a slot currently holds a
        running request (`False` means the slot is active), so it is recommended to
        filter entries by `stop_flags` to keep only data for the current batch.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: Tensor) -> Tensor:
        """Apply the LogitsProcessor to a batch logits tensor.

        The updated tensor must be returned; it may be modified in-place.
        """
        raise NotImplementedError
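A hypothetical subclass shows the expected shape of an implementation. Plain Python stand-ins replace `FDConfig` and `paddle.Tensor` so the sketch is self-contained; `ForbidTokenProcessor` is illustrative, not part of FastDeploy:

```python
from abc import ABC, abstractmethod


class LogitsProcessorSketch(ABC):
    """Stand-in for the abstract base above, minus the paddle/FDConfig types."""

    @abstractmethod
    def update_state(self, share_inputs: dict) -> None: ...

    @abstractmethod
    def apply(self, logits): ...


class ForbidTokenProcessor(LogitsProcessorSketch):
    """Hypothetical processor: masks one token id for every active slot."""

    def __init__(self, forbidden_id: int):
        self.forbidden_id = forbidden_id
        self.num_active = 0

    def update_state(self, share_inputs: dict) -> None:
        # Filter by stop_flags: False means the slot holds a running request.
        self.num_active = sum(1 for f in share_inputs["stop_flags"] if not f)

    def apply(self, logits):
        # logits: [batch_size][vocab_size] nested lists in this sketch.
        for row in range(self.num_active):
            logits[row][self.forbidden_id] = float("-inf")
        return logits  # returned, possibly modified in-place


proc = ForbidTokenProcessor(forbidden_id=2)
proc.update_state({"stop_flags": [False, True, False]})  # slots 0 and 2 active
out = proc.apply([[0.0] * 4 for _ in range(2)])
# out == [[0.0, 0.0, -inf, 0.0], [0.0, 0.0, -inf, 0.0]]
```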
@@ -0,0 +1,68 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.logits_processor.base import LogitsProcessor


class LogitBiasLogitsProcessor(LogitsProcessor):
    """
    Maintains per-request logit biases and applies them to logits.
    """

    def __init__(self, fd_config: FDConfig):
        self.device = paddle.device.get_device()
        self.dtype = fd_config.model_config.dtype
        self.batch_ids: list[int] = []
        self.token_ids: list[int] = []
        self.biases: list[float] = []

    def update_state(self, share_inputs: dict):
        """Build per-step logit-bias state from the active request slots."""

        # Retrieve inference states from share_inputs
        stop_flags = share_inputs["stop_flags"]
        logits_processors_args = share_inputs["logits_processors_args"]
        logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]

        # Collect bias states for each active request
        self.batch_ids = []
        self.token_ids = []
        self.biases = []
        for batch_id, logit_proc_args in enumerate(logits_processors_args):
            tok_id_bias_map = logit_proc_args.get("logit_bias") or {}
            self.batch_ids.extend([batch_id] * len(tok_id_bias_map))
            self.token_ids.extend(tok_id_bias_map.keys())
            self.biases.extend(tok_id_bias_map.values())

    def apply(self, logits: paddle.Tensor) -> paddle.Tensor:
        """Apply logit bias to logits: [batch_size, vocab_size]"""
        # Skip if no bias is specified
        if len(self.biases) == 0:
            return logits

        # Build bias index and value tensors on the target device
        bias_indices = (
            paddle.to_tensor(self.batch_ids, dtype="int32").to(self.device),
            paddle.to_tensor(self.token_ids, dtype="int32").to(self.device),
        )
        bias_tensor = paddle.to_tensor(self.biases, dtype=self.dtype).to(self.device)
        logits[bias_indices] += bias_tensor
        return logits
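The coordinate-list layout above (parallel `batch_ids`/`token_ids`/`biases` lists indexing a scatter-add into the logits matrix) can be illustrated with plain Python lists standing in for paddle tensors; this is a sketch of the indexing scheme, not the paddle code:

```python
def apply_logit_bias(logits, batch_ids, token_ids, biases):
    """Add biases[i] to logits[batch_ids[i]][token_ids[i]] for every i."""
    for row, col, bias in zip(batch_ids, token_ids, biases):
        logits[row][col] += bias
    return logits


# Two active requests: slot 0 biases token 3 by +5.0; slot 1 biases tokens 1 and 2.
logits = [[0.0] * 4 for _ in range(2)]
apply_logit_bias(logits, batch_ids=[0, 1, 1], token_ids=[3, 1, 2], biases=[5.0, -1.0, 2.0])
# logits is now [[0.0, 0.0, 0.0, 5.0], [0.0, -1.0, 2.0, 0.0]]
```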
@@ -87,6 +87,7 @@ from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
@@ -624,6 +625,9 @@ class GPUModelRunner(ModelRunnerBase):
            else:
                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

            # For logits processors
            self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}

            if len(multi_vision_inputs["images_lst"]) > 0:
                self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)

@@ -1188,6 +1192,11 @@ class GPUModelRunner(ModelRunnerBase):
        )
        self.share_inputs["image_features"] = None

        # For logits processors
        self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config)
        self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)]
        logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}")

    def _prepare_inputs(self) -> None:
        """Prepare the model inputs"""
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -1265,6 +1274,7 @@ class GPUModelRunner(ModelRunnerBase):
            stop_flags=self.share_inputs["stop_flags"],
            temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
            top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
            logits_processors=self.share_inputs["logits_processors"],
            share_inputs=self.share_inputs,
        )

@@ -1993,6 +2003,10 @@ class GPUModelRunner(ModelRunnerBase):
        self._prepare_inputs()
        self.sampler.pre_process(skip_idx_list)

        # 1.1 Update state of logits processors
        for proc in self.sampling_metadata.logits_processors:
            proc.update_state(self.share_inputs)

        # NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
        # when there is data on other runners, the current runner is required to execute part of the model.
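The runner wiring above follows a two-phase, per-step contract: the runner calls `update_state(share_inputs)` before each forward pass, and the sampler later calls `apply(logits)` on the compacted batch. A minimal sketch with a hypothetical counting processor (plain lists stand in for tensors):

```python
class ActiveCountProcessor:
    """Hypothetical processor: adds the active-slot count to every logit,
    demonstrating the update_state -> apply call order."""

    def __init__(self):
        self.num_active = 0

    def update_state(self, share_inputs: dict) -> None:
        # Phase 1 (runner): refresh per-step state before the forward pass.
        self.num_active = sum(1 for f in share_inputs["stop_flags"] if not f)

    def apply(self, logits):
        # Phase 2 (sampler): transform the [batch, vocab] logits.
        return [[x + self.num_active for x in row] for row in logits]


proc = ActiveCountProcessor()
proc.update_state({"stop_flags": [False, True, False, True]})  # 2 active slots
out = proc.apply([[0.0, 1.0], [2.0, 3.0]])
# out == [[2.0, 3.0], [4.0, 5.0]]
```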
@@ -686,6 +686,14 @@ def parse_args():
        help="Override configuration for the pooler.",
    )

    parser.add_argument(
        "--logits-processors",
        type=str,
        nargs="+",
        default=[],
        help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
    )

    args = parser.parse_args()
    return args
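A standalone parser shows how `nargs="+"` collects the space-separated FQCN list (this mirrors only the flag above, not the full worker entry point; `my_pkg.procs:MyProcessor` is a hypothetical custom processor):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--logits-processors",
    type=str,
    nargs="+",  # one or more space-separated class specs
    default=[],
)
args = parser.parse_args(
    ["--logits-processors", "LogitBiasLogitsProcessor", "my_pkg.procs:MyProcessor"]
)
# args.logits_processors == ["LogitBiasLogitsProcessor", "my_pkg.procs:MyProcessor"]
```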
@@ -91,6 +91,7 @@ def llm(model_path):
        cache_queue_port=FD_CACHE_QUEUE_PORT,
        max_model_len=32768,
        quantization="wint8",
        logits_processors=["LogitBiasLogitsProcessor"],
    )

    # Wait for the port to be open
@@ -284,6 +285,45 @@ def test_seed(llm):
        pytest.fail("Prompt generation test failed")


def test_logits_processors(llm):
    """
    Test LogitBiasLogitsProcessor: a token with an extremely large logit bias should always be greedy-sampled.
    """
    messages = [{"role": "user", "content": "鲁迅是谁"}]
    sampling_params = SamplingParams(
        top_p=0.0,
        max_tokens=128,
    )
    outputs = llm.chat(messages, sampling_params)
    print("generated text:", outputs[0].outputs.text)
    original_generated_text = outputs[0].outputs.text

    # test a request with logit bias
    token_id_with_large_bias = 123
    messages = [{"role": "user", "content": "鲁迅是谁"}]
    sampling_params = SamplingParams(
        top_p=0.0,
        max_tokens=128,
        logits_processors_args={"logit_bias": {token_id_with_large_bias: 100000}},
    )
    outputs = llm.chat(messages, sampling_params)
    print("generated text:", outputs[0].outputs.text)
    print("generated token ids:", outputs[0].outputs.token_ids)
    print("expected token id:", token_id_with_large_bias)
    assert all(x == token_id_with_large_bias for x in outputs[0].outputs.token_ids[:-1])

    # test a request without logit bias
    messages = [{"role": "user", "content": "鲁迅是谁"}]
    sampling_params = SamplingParams(
        top_p=0.0,
        max_tokens=128,
    )
    outputs = llm.chat(messages, sampling_params)
    print("generated text:", outputs[0].outputs.text)
    current_generated_text = outputs[0].outputs.text
    assert current_generated_text == original_generated_text


if __name__ == "__main__":
    """
    Main entry point for the test script.
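The same per-request control is exposed through the `logits_processors_args` field of the request body. A hypothetical payload fragment is sketched below; aside from `logits_processors_args` (shown in the test above), the field names follow common chat-completions conventions and may differ from the exact FastDeploy serving schema:

```python
# Hypothetical request body biasing token id 123 for one request.
payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 128,
    "logits_processors_args": {"logit_bias": {123: 100000}},
}
# token id 123 would then be biased on every decoding step of this request
```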
@@ -60,6 +60,7 @@ def _create_default_sampling_metadata(
        eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
        min_p=paddle.randn([batch_size]),
        seed=paddle.to_tensor([[2025]]),
        logits_processors=None,
    )
    if max_num_logprobs is not None:
        fake_sampling_metadata.max_num_logprobs = max_num_logprobs
@@ -89,6 +90,9 @@ def get_baseline_logprobs(logits, sampling_metadata, logprobs_mode, token_ids):
        apply_penalty_multi_scores,
    )

    for proc in sampling_metadata.logits_processors or []:
        logits = proc.apply(logits)

    logits = apply_penalty_multi_scores(
        sampling_metadata.pre_token_ids,
        sampling_metadata.prompt_ids,
@@ -109,6 +113,9 @@ def get_baseline_logprobs(logits, sampling_metadata, logprobs_mode, token_ids):
        apply_penalty_multi_scores,
    )

    for proc in sampling_metadata.logits_processors or []:
        logits = proc.apply(logits)

    logits = apply_penalty_multi_scores(
        sampling_metadata.pre_token_ids,
        sampling_metadata.prompt_ids,
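The hunks above chain every processor over the logits ahead of the penalty kernels. The ordering contract (declaration order is execution order) can be sketched with plain functions standing in for processor `apply()` methods:

```python
def add_one(logits):
    return [x + 1.0 for x in logits]


def negate(logits):
    return [-x for x in logits]


logits = [0.0, 2.0]
for apply_proc in (add_one, negate):  # declaration order == execution order
    logits = apply_proc(logits)
# (0+1, 2+1) then negated -> [-1.0, -3.0]
```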
@@ -0,0 +1,162 @@
import random
import unittest
from unittest.mock import Mock

import paddle

from fastdeploy.engine.request import Request
from fastdeploy.model_executor.logits_processor.builtin import LogitBiasLogitsProcessor


class TestLogitsProcessor(unittest.TestCase):

    def setUp(self):
        self.vocab_size = 10
        self.max_num_seqs = 16
        self.dtype = "float32"
        self.share_inputs = {
            "stop_flags": paddle.to_tensor([True for _ in range(self.max_num_seqs)]),
            "logits_processors_args": [{} for _ in range(self.max_num_seqs)],
        }

    def create_request(self, **kwargs):
        """Create a mock request with the specified logit bias"""
        request = Mock(spec=Request)
        for k, v in kwargs.items():
            setattr(request, k, v)
        return request

    def create_logits(self):
        return paddle.randn([self.get_batch_size(), self.vocab_size], dtype=self.dtype)

    def add_request(self, req):
        self.share_inputs["stop_flags"][req.idx] = False
        self.share_inputs["logits_processors_args"][req.idx]["logit_bias"] = req.logit_bias

    def del_request(self, req):
        self.share_inputs["stop_flags"][req.idx] = True
        self.share_inputs["logits_processors_args"][req.idx] = {}

    def get_batch_size(self):
        return self.max_num_seqs - sum(self.share_inputs["stop_flags"])

    def test_logit_bias_logit_processor(self):
        fd_config = Mock()
        fd_config.model_config.dtype = self.dtype
        logits_processor = LogitBiasLogitsProcessor(fd_config)

        print("Phase 1: Empty batch")
        logits = self.create_logits()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        self.assertTrue(paddle.all(processed_logits == logits), "Logits should remain unchanged with an empty batch")

        print("Phase 2: Add first request")
        request1 = self.create_request(
            request_id="req1", idx=0, logit_bias={random.randint(0, self.vocab_size - 1): random.random() - 0.5}
        )
        self.add_request(request1)
        logits = self.create_logits()
        original_logits = logits.clone()
        expected_logits = logits.clone()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        batch_id = 0
        for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
            if not flag:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected_logits[batch_id, token_id] += bias
                batch_id += 1
        self.assertTrue(
            paddle.all(processed_logits == expected_logits),
            "Logits should be modified with req1 biases\n"
            f"original: {original_logits}\n"
            f"processed: {processed_logits}\n"
            f"expected: {expected_logits}\n"
            f"diff: {processed_logits - expected_logits}",
        )

        print("Phase 3: Add second request with multiple tokens to apply bias")
        request2 = self.create_request(
            request_id="req2",
            idx=1,
            logit_bias=dict(
                zip(random.choices(range(0, self.vocab_size), k=3), [random.random() - 0.5 for _ in range(3)])
            ),
        )
        self.add_request(request2)
        logits = self.create_logits()
        original_logits = logits.clone()
        expected_logits = logits.clone()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        batch_id = 0
        for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
            if not flag:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected_logits[batch_id, token_id] += bias
                batch_id += 1
        self.assertTrue(
            paddle.all(processed_logits == expected_logits),
            "Logits should be modified with req1 and req2 biases\n"
            f"original: {original_logits}\n"
            f"processed: {processed_logits}\n"
            f"expected: {expected_logits}\n"
            f"diff: {processed_logits - expected_logits}",
        )

        print("Phase 4: Remove first request")
        self.del_request(request1)
        logits = self.create_logits()
        original_logits = logits.clone()
        expected_logits = logits.clone()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        batch_id = 0
        for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
            if not flag:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected_logits[batch_id, token_id] += bias
                batch_id += 1
        self.assertTrue(
            paddle.all(processed_logits == expected_logits),
            "Logits should only have biases from request2 after removal\n"
            f"original: {original_logits}\n"
            f"processed: {processed_logits}\n"
            f"expected: {expected_logits}\n"
            f"diff: {processed_logits - expected_logits}",
        )

        print("Phase 5: Add third request with no logit bias")
        request3 = self.create_request(request_id="req3", idx=0, logit_bias=None)
        self.add_request(request3)
        logits = self.create_logits()
        original_logits = logits.clone()
        expected_logits = logits.clone()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        batch_id = 0
        for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
            if not flag:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected_logits[batch_id, token_id] += bias
                batch_id += 1
        self.assertTrue(
            paddle.all(processed_logits == expected_logits),
            "Logits should remain unchanged for a request with no bias\n"
            f"original: {original_logits}\n"
            f"processed: {processed_logits}\n"
            f"expected: {expected_logits}\n"
            f"diff: {processed_logits - expected_logits}",
        )

        print("All test phases completed successfully!")


if __name__ == "__main__":
    unittest.main()