mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[Feature][Sampling] Extend top-k_top-p sampling to all backends and unify greedy decoding with top_k=1 (#6894)
* update sampling * fix * fix * fix mtp * fix test
This commit is contained in:
@@ -23,18 +23,60 @@ Sampling strategies are used to determine how to select the next token from the
|
||||
|
||||
During deployment, you can choose the sampling algorithm by setting the environment variable `FD_SAMPLING_CLASS`. Available values are `base`, `base_non_truncated`, `air`, or `rejection`.
|
||||
|
||||
**Algorithms Supporting Only Top-p Sampling**
|
||||
|
||||
* `base` (default): Directly normalizes using the `top_p` value, favoring tokens with greater probabilities.
|
||||
* `base_non_truncated`: Strictly follows the Top-p sampling logic, first selecting the smallest set that reaches the cumulative probability of `top_p`, then normalizing these selected elements.
|
||||
* `air`: This algorithm is inspired by [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and supports Top-p sampling.
|
||||
|
||||
**Algorithms Supporting Top-p and Top-k_Top-p Sampling**
|
||||
|
||||
* `rejection`: This algorithm is inspired by [flashinfer](https://github.com/flashinfer-ai/flashinfer) and allows flexible settings for `top_k` and `top_p` parameters for Top-p or Top-k_Top-p sampling.
|
||||
|
||||
## Configuration Method
|
||||
|
||||
### Greedy Sampling
|
||||
|
||||
1. During deployment, set the environment variable to select the sampling algorithm, default is base:
|
||||
|
||||
```bash
|
||||
export FD_SAMPLING_CLASS=rejection # base, base_non_truncated, or air
|
||||
```
|
||||
|
||||
2. When sending a request, specify the following parameters:
|
||||
|
||||
* Example request with curl:
|
||||
|
||||
```bash
|
||||
|
||||
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "How old are you"}
|
||||
],
|
||||
"top_k": 1
|
||||
}'
|
||||
# or "top_p": 0.0
|
||||
```
|
||||
|
||||
* Example request with Python:
|
||||
|
||||
```python
|
||||
import openai
|
||||
host = "0.0.0.0"
|
||||
port = "8170"
|
||||
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="null",
|
||||
messages=[
|
||||
        {"role": "system", "content": "I'm a helpful AI assistant."},
        {"role": "user", "content": "How old are you"},
    ],
|
||||
stream=True,
|
||||
    top_k=1  # or top_p=0.0
|
||||
)
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta:
|
||||
print(chunk.choices[0].delta.content, end='')
|
||||
print('\n')
|
||||
```
|
||||
|
||||
### Top-p Sampling
|
||||
|
||||
1. During deployment, set the environment variable to select the sampling algorithm, default is base:
|
||||
|
||||
@@ -21,18 +21,61 @@
|
||||
|
||||
在部署时,可以通过设置环境变量 `FD_SAMPLING_CLASS` 来选择采样算法。可选择的值有 `base`, `base_non_truncated`, `air`或 `rejection`。
|
||||
|
||||
**仅支持 Top-p Sampling 的算法**
|
||||
|
||||
* `base`(default):直接使用 `top_p` 的值进行归一化,倾向于采样概率更大的token。
|
||||
* `base_non_truncated`:严格按照 Top-p 采样的逻辑执行,首先选择使累积概率达到 `top_p` 的最小集合,然后对这些选择的元素进行归一化。
|
||||
* `air`:该算法参考 [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM)的实现,支持 Top-p 采样。
|
||||
|
||||
**支持 Top-p 和 Top-k_top-p 采样的算法**
|
||||
|
||||
* `rejection`:该算法参考 [flashinfer](https://github.com/flashinfer-ai/flashinfer) 的实现,支持灵活设置 `top_k` 和 `top_p` 参数进行 Top-p 或 Top-k_top-p 采样。
|
||||
|
||||
## 配置方式
|
||||
|
||||
### Greedy 采样
|
||||
|
||||
1. 在部署时,设置环境变量以选择采样算法,默认为base:
|
||||
|
||||
```bash
|
||||
export FD_SAMPLING_CLASS=rejection # base, base_non_truncated, or air
|
||||
```
|
||||
|
||||
2. 在发送请求时,指定以下参数:
|
||||
|
||||
* 使用 curl 命令发送用户请求示例如下:
|
||||
|
||||
```bash
|
||||
|
||||
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "How old are you"}
|
||||
],
|
||||
"top_k": 1
|
||||
}'
|
||||
# 或 "top_p": 0.0
|
||||
```
|
||||
|
||||
* 使用 python 脚本发送用户请求示例如下:
|
||||
|
||||
```python
|
||||
import openai
|
||||
host = "0.0.0.0"
|
||||
port = "8170"
|
||||
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="null",
|
||||
messages=[
|
||||
{"role": "system", "content": "I'm a helpful AI assistant."},
|
||||
{"role": "user", "content": "把李白的静夜思改写为现代诗"},
|
||||
],
|
||||
stream=True,
|
||||
    top_k=1  # 或 top_p=0.0
|
||||
)
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta:
|
||||
print(chunk.choices[0].delta.content, end='')
|
||||
print('\n')
|
||||
```
|
||||
|
||||
### Top-p 采样
|
||||
|
||||
1. 在部署时,设置环境变量以选择采样算法,默认为base:
|
||||
|
||||
@@ -148,6 +148,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request.set("top_k", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
request.set("top_k", 1)
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
@@ -237,6 +238,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request["top_k"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
request["top_k"] = 1
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
|
||||
@@ -291,6 +291,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
request["top_k"] = 1
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
|
||||
@@ -256,6 +256,7 @@ class PaddleOCRVLProcessor(TextProcessor):
|
||||
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
request["top_k"] = 1
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
|
||||
@@ -382,6 +382,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.set("top_k", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
request.set("top_k", 1)
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
@@ -482,6 +483,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["top_k"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
request["top_k"] = 1
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
|
||||
@@ -156,6 +156,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request.set("top_k", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
request.set("top_k", 1)
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
@@ -251,6 +252,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request.sampling_params.top_k = 1
|
||||
if request.sampling_params.top_p < _SAMPLING_EPS:
|
||||
request.sampling_params.top_p = _SAMPLING_EPS
|
||||
request.sampling_params.top_k = 1
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
|
||||
@@ -292,6 +292,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
request.enable_thinking = model_status == "think_start"
|
||||
if request.sampling_params.top_p is not None and request.sampling_params.top_p < _SAMPLING_EPS:
|
||||
request.sampling_params.top_p = _SAMPLING_EPS
|
||||
request.sampling_params.top_k = 1
|
||||
if request.sampling_params.response_max_tokens is not None and request.enable_thinking is False:
|
||||
request.sampling_params.max_tokens = min(
|
||||
request.sampling_params.response_max_tokens, request.sampling_params.max_tokens
|
||||
|
||||
@@ -256,6 +256,7 @@ class PaddleOCRVLProcessor(TextProcessor):
|
||||
|
||||
if request.sampling_params.top_p is not None and request.sampling_params.top_p < _SAMPLING_EPS:
|
||||
request.sampling_params.top_p = _SAMPLING_EPS
|
||||
request.sampling_params.top_k = 1
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
|
||||
@@ -368,6 +368,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.set("top_k", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
request.set("top_k", 1)
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
@@ -475,6 +476,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.sampling_params.top_k = 1
|
||||
if request.sampling_params.top_p < _SAMPLING_EPS:
|
||||
request.sampling_params.top_p = _SAMPLING_EPS
|
||||
request.sampling_params.top_k = 1
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
|
||||
@@ -24,6 +24,8 @@ from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_gcu():
|
||||
from fastdeploy.model_executor.ops.gcu import top_p_sampling as gcu_top_p_sampling
|
||||
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
_DETERMINISTIC_RNG_SEED = 42
|
||||
|
||||
|
||||
@@ -50,7 +52,6 @@ def top_k_top_p_sampling(
|
||||
used to specify the top_p corresponding to each query.
|
||||
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||
used to specify the top_k corresponding to each query.
|
||||
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||
top_k_list(list|None, optional): CPU-side mirror of top_k as a Python list,
|
||||
used for fast host-side checks (e.g. all-greedy detection) without GPU sync.
|
||||
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||
@@ -59,7 +60,6 @@ def top_k_top_p_sampling(
|
||||
used to specify the random seed for each query.
|
||||
seed(int, optional): the random seed. Default is -1,
|
||||
k(int): the number of top_k scores/ids to be returned. Default is 0.
|
||||
Only used when FD_SAMPLING_CLASS is `air`.
|
||||
mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
|
||||
If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
|
||||
Only used when FD_SAMPLING_CLASS is `air` or `base`.
|
||||
@@ -80,45 +80,26 @@ def top_k_top_p_sampling(
|
||||
if envs.FD_DETERMINISTIC_MODE:
|
||||
_reset_cuda_generator_for_determinism()
|
||||
|
||||
# Greedy decoding fast-path: top_k=1 is equivalent to argmax.
|
||||
# In non-rejection sampling modes, top_k is ignored by the backend,
|
||||
# so we must handle it explicitly.
|
||||
all_greedy = False
|
||||
if top_k_list is not None:
|
||||
all_greedy = all(k == 1 for k in top_k_list)
|
||||
elif top_k is not None:
|
||||
all_greedy = bool(paddle.all(top_k == 1))
|
||||
|
||||
if all_greedy:
|
||||
ids = paddle.argmax(x, axis=-1, keepdim=True)
|
||||
return None, ids
|
||||
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
|
||||
elif top_p_class == "rejection":
|
||||
if top_p_class == "rejection":
|
||||
ids = rejection_top_p_sampling(x, top_p, top_k, top_k_list, seed, order)
|
||||
_ = None
|
||||
elif top_p_class == "base_non_truncated":
|
||||
if topp_seed is not None:
|
||||
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
|
||||
topp_seed_device.copy_(topp_seed, False)
|
||||
_, ids = paddle.tensor.top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed_device,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode="non-truncated",
|
||||
)
|
||||
else:
|
||||
if current_platform.is_gcu():
|
||||
_, ids = gcu_top_p_sampling(x, top_p)
|
||||
elif current_platform.is_dcu():
|
||||
from fastdeploy.model_executor.layers.backends import native_top_p_sampling
|
||||
if top_k_list and any(x > 0 for x in top_k_list):
|
||||
try:
|
||||
if current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
top_k_renorm_probs,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
|
||||
x = top_k_renorm_probs(x, top_k)
|
||||
except ImportError:
|
||||
logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.")
|
||||
|
||||
_, ids = native_top_p_sampling(x, top_p)
|
||||
else:
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
|
||||
|
||||
elif top_p_class == "base_non_truncated":
|
||||
if topp_seed is not None:
|
||||
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
|
||||
topp_seed_device.copy_(topp_seed, False)
|
||||
@@ -129,20 +110,30 @@ def top_k_top_p_sampling(
|
||||
topp_seed=topp_seed_device,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode="truncated",
|
||||
mode="non-truncated",
|
||||
)
|
||||
# Mixed batch: override top_k=1 rows with argmax.
|
||||
# Shape guard: in overlap/speculative paths, x may be padded (e.g. [8,V])
|
||||
# while top_k remains per-request (e.g. [2,1]). Skip when shapes disagree.
|
||||
if not all_greedy and top_k is not None and top_k.shape[0] == ids.shape[0]:
|
||||
has_greedy = (top_k_list is not None and any(k == 1 for k in top_k_list)) or (
|
||||
top_k_list is None and bool(paddle.any(top_k == 1))
|
||||
)
|
||||
if has_greedy:
|
||||
argmax_ids = paddle.argmax(x, axis=-1, keepdim=True)
|
||||
greedy_mask = top_k == 1
|
||||
ids = paddle.where(greedy_mask, argmax_ids, ids)
|
||||
else:
|
||||
if current_platform.is_gcu():
|
||||
_, ids = gcu_top_p_sampling(x, top_p)
|
||||
elif current_platform.is_dcu():
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
native_top_p_sampling,
|
||||
)
|
||||
|
||||
_, ids = native_top_p_sampling(x, top_p)
|
||||
else:
|
||||
if topp_seed is not None:
|
||||
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
|
||||
topp_seed_device.copy_(topp_seed, False)
|
||||
_, ids = paddle.tensor.top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed_device,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode="truncated",
|
||||
)
|
||||
return _, ids
|
||||
|
||||
|
||||
|
||||
@@ -820,7 +820,9 @@ class SpeculativeSampler(nn.Layer):
|
||||
token_num_output_cpu,
|
||||
increment_value,
|
||||
)
|
||||
_, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
|
||||
_, target_tokens = top_k_top_p_sampling(
|
||||
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
|
||||
)
|
||||
elif self.verify_strategy == VerifyStrategy.GREEDY:
|
||||
# GREEDY: deterministic argmax in target_tokens, no candidates needed
|
||||
target_tokens = paddle.argmax(probs, axis=-1)
|
||||
@@ -1071,7 +1073,9 @@ class SpeculativeSampler(nn.Layer):
|
||||
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
|
||||
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
|
||||
)
|
||||
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
|
||||
_, sampled_token_ids = top_k_top_p_sampling(
|
||||
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
|
||||
)
|
||||
|
||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||
probs,
|
||||
|
||||
@@ -448,6 +448,15 @@ def test_top_k_top_p_sampling_resets_cuda_rng_in_deterministic_mode(monkeypatch)
|
||||
mock_reset.assert_called_once()
|
||||
|
||||
|
||||
def mixed_mock(probs, *a, **k):
|
||||
ids = paddle.argmax(probs, axis=-1, keepdim=True)
|
||||
    # 1 non-zero token left after renorm → greedy (argmax); otherwise → return 99
|
||||
non_zero_count = (probs > 0).sum(axis=-1, keepdim=True)
|
||||
sampled = non_zero_count > 1
|
||||
ids = paddle.where(sampled, paddle.to_tensor([[99]], dtype="int64"), ids)
|
||||
return None, ids
|
||||
|
||||
|
||||
def test_top_k_1_returns_argmax(monkeypatch):
|
||||
"""top_k=1 should produce argmax results regardless of FD_DETERMINISTIC_MODE."""
|
||||
import sys
|
||||
@@ -480,7 +489,7 @@ def test_top_k_1_returns_argmax(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
sampling_mod.paddle.tensor,
|
||||
"top_p_sampling",
|
||||
lambda *a, **k: (None, paddle.to_tensor([[99], [99]], dtype="int64")),
|
||||
mixed_mock,
|
||||
)
|
||||
|
||||
_, ids_mixed = sampling_mod.top_k_top_p_sampling(probs, top_p, top_k_mixed, top_k_list_mixed)
|
||||
@@ -519,7 +528,7 @@ def test_top_k_1_returns_argmax_without_deterministic_mode(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
sampling_mod.paddle.tensor,
|
||||
"top_p_sampling",
|
||||
lambda *a, **k: (None, paddle.to_tensor([[99], [99]], dtype="int64")),
|
||||
mixed_mock,
|
||||
)
|
||||
|
||||
_, ids_mixed = sampling_mod.top_k_top_p_sampling(probs, top_p, top_k_mixed, top_k_list_mixed)
|
||||
|
||||
@@ -146,7 +146,7 @@ def test_model_against_baseline(
|
||||
|
||||
# Get baseline suffix from config
|
||||
model_config = hugging_face_model_param_map.get(model_name_or_path, {})
|
||||
baseline_suffix = model_config.get("baseline_suffix", "tp2-dev-0317")
|
||||
baseline_suffix = model_config.get("baseline_suffix", "tp2-dev-0311")
|
||||
baseline_filename = f"{model_name_or_path}-{baseline_suffix}"
|
||||
|
||||
if base_path:
|
||||
|
||||
Reference in New Issue
Block a user