[Feature][Sampling] Extend top-k_top-p sampling to all backends and unify greedy decoding with top_k=1 (#6894)

* update sampling

* fix

* fix

* fix mtp

* fix test
sunxin
2026-03-19 16:43:10 +08:00
committed by GitHub
parent 2b84a4276e
commit 33e01f22a8
14 changed files with 165 additions and 64 deletions
@@ -23,18 +23,60 @@ Sampling strategies are used to determine how to select the next token from the
During deployment, you can choose the sampling algorithm by setting the environment variable `FD_SAMPLING_CLASS`. Available values are `base`, `base_non_truncated`, `air`, or `rejection`.
**Algorithms Supporting Only Top-p Sampling**
* `base` (default): Directly normalizes using the `top_p` value, favoring tokens with greater probabilities.
* `base_non_truncated`: Strictly follows the Top-p sampling logic, first selecting the smallest set that reaches the cumulative probability of `top_p`, then normalizing these selected elements.
* `air`: This algorithm is inspired by [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and supports Top-p sampling.
**Algorithms Supporting Top-p and Top-k_Top-p Sampling**
* `rejection`: This algorithm is inspired by [flashinfer](https://github.com/flashinfer-ai/flashinfer) and allows flexible settings for `top_k` and `top_p` parameters for Top-p or Top-k_Top-p sampling.
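The combined Top-k_Top-p filter (independent of the rejection-based kernel) amounts to two truncation passes over the probability vector. A minimal numpy sketch of that filtering, purely illustrative and not FastDeploy's API:

```python
import numpy as np

def top_k_top_p_filter(probs, k, p):
    """Keep the k most probable tokens, then apply Top-p truncation
    to the renormalized remainder. Returns the final distribution."""
    probs = np.asarray(probs, dtype=np.float64)
    # Top-k: zero out everything outside the k most probable tokens.
    if k > 0:
        keep = np.argpartition(probs, -k)[-k:]
        mask = np.zeros_like(probs)
        mask[keep] = 1.0
        probs = probs * mask / (probs * mask).sum()
    # Top-p: keep the smallest descending-probability prefix whose
    # cumulative mass reaches p, then renormalize.
    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, p) + 1
    mask = np.zeros_like(probs)
    mask[order[:cutoff]] = 1.0
    probs = probs * mask
    return probs / probs.sum()

dist = top_k_top_p_filter([0.5, 0.3, 0.1, 0.1], k=3, p=0.8)
# dist concentrates all mass on the two most probable tokens
```

Sampling then draws from `dist`; with `top_k=1` the filter degenerates to a one-hot distribution, which is why greedy decoding can be unified under this interface.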
## Configuration Method
### Greedy Sampling
1. During deployment, set the environment variable to select the sampling algorithm (the default is `base`):
```bash
export FD_SAMPLING_CLASS=rejection # base, base_non_truncated, or air
```
2. When sending a request, specify the following parameters:
* Example request with curl:
```bash
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "How old are you"}
],
"top_k": 1
}'
# or "top_p": 0.0
```
* Example request with Python:
```python
import openai
host = "0.0.0.0"
port = "8170"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
model="null",
messages=[
{"role": "system", "content": "I'm a helpful AI assistant."},
],
stream=True,
top_k=1 # or "top_p": 0.0
)
for chunk in response:
if chunk.choices[0].delta:
print(chunk.choices[0].delta.content, end='')
print('\n')
```
### Top-p Sampling
1. During deployment, set the environment variable to select the sampling algorithm (the default is `base`):
@@ -21,18 +21,61 @@
During deployment, you can choose the sampling algorithm by setting the environment variable `FD_SAMPLING_CLASS`. Available values are `base`, `base_non_truncated`, `air`, or `rejection`.
**Algorithms Supporting Only Top-p Sampling**
* `base` (default): Directly normalizes using the `top_p` value, favoring tokens with greater probabilities.
* `base_non_truncated`: Strictly follows the Top-p sampling logic, first selecting the smallest set that reaches the cumulative probability of `top_p`, then normalizing these selected elements.
* `air`: This algorithm is inspired by [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and supports Top-p sampling.
**Algorithms Supporting Top-p and Top-k_top-p Sampling**
* `rejection`: This algorithm is inspired by [flashinfer](https://github.com/flashinfer-ai/flashinfer) and allows flexible settings for the `top_k` and `top_p` parameters for Top-p or Top-k_top-p sampling.
## Configuration Method
### Greedy Sampling
1. During deployment, set the environment variable to select the sampling algorithm (the default is `base`):
```bash
export FD_SAMPLING_CLASS=rejection # base, base_non_truncated, or air
```
2. When sending a request, specify the following parameters:
* Example request with curl:
```bash
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "How old are you"}
],
"top_k": 1
}'
# or "top_p": 0.0
```
* Example request with Python:
```python
import openai
host = "0.0.0.0"
port = "8170"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
model="null",
messages=[
{"role": "system", "content": "I'm a helpful AI assistant."},
{"role": "user", "content": "把李白的静夜思改写为现代诗"},
],
stream=True,
top_k=1 # or "top_p": 0.0
)
for chunk in response:
if chunk.choices[0].delta:
print(chunk.choices[0].delta.content, end='')
print('\n')
```
### Top-p Sampling
1. During deployment, set the environment variable to select the sampling algorithm (the default is `base`):
@@ -148,6 +148,7 @@ class Ernie4_5Processor(BaseDataProcessor):
request.set("top_k", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
request.set("top_k", 1)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
@@ -237,6 +238,7 @@ class Ernie4_5Processor(BaseDataProcessor):
request["top_k"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
request["top_k"] = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
@@ -291,6 +291,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request["enable_thinking"] = model_status == "think_start"
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
request["top_k"] = 1
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
@@ -256,6 +256,7 @@ class PaddleOCRVLProcessor(TextProcessor):
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
request["top_k"] = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
@@ -382,6 +382,7 @@ class DataProcessor(BaseDataProcessor):
request.set("top_k", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
request.set("top_k", 1)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
@@ -482,6 +483,7 @@ class DataProcessor(BaseDataProcessor):
request["top_k"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
request["top_k"] = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
@@ -156,6 +156,7 @@ class Ernie4_5Processor(BaseDataProcessor):
request.set("top_k", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
request.set("top_k", 1)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
@@ -251,6 +252,7 @@ class Ernie4_5Processor(BaseDataProcessor):
request.sampling_params.top_k = 1
if request.sampling_params.top_p < _SAMPLING_EPS:
request.sampling_params.top_p = _SAMPLING_EPS
request.sampling_params.top_k = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
@@ -292,6 +292,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request.enable_thinking = model_status == "think_start"
if request.sampling_params.top_p is not None and request.sampling_params.top_p < _SAMPLING_EPS:
request.sampling_params.top_p = _SAMPLING_EPS
request.sampling_params.top_k = 1
if request.sampling_params.response_max_tokens is not None and request.enable_thinking is False:
request.sampling_params.max_tokens = min(
request.sampling_params.response_max_tokens, request.sampling_params.max_tokens
@@ -256,6 +256,7 @@ class PaddleOCRVLProcessor(TextProcessor):
if request.sampling_params.top_p is not None and request.sampling_params.top_p < _SAMPLING_EPS:
request.sampling_params.top_p = _SAMPLING_EPS
request.sampling_params.top_k = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
@@ -368,6 +368,7 @@ class DataProcessor(BaseDataProcessor):
request.set("top_k", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
request.set("top_k", 1)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
@@ -475,6 +476,7 @@ class DataProcessor(BaseDataProcessor):
request.sampling_params.top_k = 1
if request.sampling_params.top_p < _SAMPLING_EPS:
request.sampling_params.top_p = _SAMPLING_EPS
request.sampling_params.top_k = 1
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
@@ -24,6 +24,8 @@ from fastdeploy.platforms import current_platform
if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import top_p_sampling as gcu_top_p_sampling
from paddleformers.utils.log import logger
_DETERMINISTIC_RNG_SEED = 42
@@ -50,7 +52,6 @@ def top_k_top_p_sampling(
used to specify the top_p corresponding to each query.
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
used to specify the top_k corresponding to each query.
Only used when FD_SAMPLING_CLASS is `rejection`.
top_k_list(list|None, optional): CPU-side mirror of top_k as a Python list,
used for fast host-side checks (e.g. all-greedy detection) without GPU sync.
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
@@ -59,7 +60,6 @@ def top_k_top_p_sampling(
used to specify the random seed for each query.
seed(int, optional): the random seed. Default is -1.
k(int): the number of top_k scores/ids to be returned. Default is 0.
Only used when FD_SAMPLING_CLASS is `air`.
mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
Only used when FD_SAMPLING_CLASS is `air` or `base`.
@@ -80,45 +80,26 @@ def top_k_top_p_sampling(
if envs.FD_DETERMINISTIC_MODE:
_reset_cuda_generator_for_determinism()
# Greedy decoding fast-path: top_k=1 is equivalent to argmax.
# In non-rejection sampling modes, top_k is ignored by the backend,
# so we must handle it explicitly.
all_greedy = False
if top_k_list is not None:
all_greedy = all(k == 1 for k in top_k_list)
elif top_k is not None:
all_greedy = bool(paddle.all(top_k == 1))
if all_greedy:
ids = paddle.argmax(x, axis=-1, keepdim=True)
return None, ids
if top_p_class == "air":
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
elif top_p_class == "rejection":
if top_p_class == "rejection":
ids = rejection_top_p_sampling(x, top_p, top_k, top_k_list, seed, order)
_ = None
elif top_p_class == "base_non_truncated":
if topp_seed is not None:
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
topp_seed_device.copy_(topp_seed, False)
_, ids = paddle.tensor.top_p_sampling(
x,
top_p,
threshold=threshold,
topp_seed=topp_seed_device,
seed=seed,
k=k,
mode="non-truncated",
)
else:
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, top_p)
elif current_platform.is_dcu():
from fastdeploy.model_executor.layers.backends import native_top_p_sampling
if top_k_list and any(x > 0 for x in top_k_list):
try:
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
top_k_renorm_probs,
)
else:
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
x = top_k_renorm_probs(x, top_k)
except ImportError:
logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.")
_, ids = native_top_p_sampling(x, top_p)
else:
if top_p_class == "air":
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
elif top_p_class == "base_non_truncated":
if topp_seed is not None:
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
topp_seed_device.copy_(topp_seed, False)
@@ -129,20 +110,30 @@ def top_k_top_p_sampling(
topp_seed=topp_seed_device,
seed=seed,
k=k,
mode="truncated",
mode="non-truncated",
)
# Mixed batch: override top_k=1 rows with argmax.
# Shape guard: in overlap/speculative paths, x may be padded (e.g. [8,V])
# while top_k remains per-request (e.g. [2,1]). Skip when shapes disagree.
if not all_greedy and top_k is not None and top_k.shape[0] == ids.shape[0]:
has_greedy = (top_k_list is not None and any(k == 1 for k in top_k_list)) or (
top_k_list is None and bool(paddle.any(top_k == 1))
)
if has_greedy:
argmax_ids = paddle.argmax(x, axis=-1, keepdim=True)
greedy_mask = top_k == 1
ids = paddle.where(greedy_mask, argmax_ids, ids)
else:
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, top_p)
elif current_platform.is_dcu():
from fastdeploy.model_executor.layers.backends import (
native_top_p_sampling,
)
_, ids = native_top_p_sampling(x, top_p)
else:
if topp_seed is not None:
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
topp_seed_device.copy_(topp_seed, False)
_, ids = paddle.tensor.top_p_sampling(
x,
top_p,
threshold=threshold,
topp_seed=topp_seed_device,
seed=seed,
k=k,
mode="truncated",
)
return _, ids
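The mixed-batch override above replaces rows that requested `top_k=1` with the argmax token through a `paddle.where` mask, while other rows keep their sampled token. The same masking pattern as a self-contained numpy sketch (illustrative values, not the Paddle kernel):

```python
import numpy as np

probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.2, 0.3]])
sampled_ids = np.array([[2], [2]])   # pretend output of the top-p sampler
top_k = np.array([[1], [0]])         # row 0 requested greedy decoding

# Rows with top_k == 1 take the argmax token; others keep the sample.
argmax_ids = np.argmax(probs, axis=-1, keepdims=True)
greedy_mask = top_k == 1
ids = np.where(greedy_mask, argmax_ids, sampled_ids)
# row 0 -> argmax token 1, row 1 -> sampled token 2
```

The shape guard in the real code exists because padded overlap/speculative batches can make `probs` and `top_k` disagree on the leading dimension, in which case the override is skipped.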
@@ -820,7 +820,9 @@ class SpeculativeSampler(nn.Layer):
token_num_output_cpu,
increment_value,
)
_, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
_, target_tokens = top_k_top_p_sampling(
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
)
elif self.verify_strategy == VerifyStrategy.GREEDY:
# GREEDY: deterministic argmax in target_tokens, no candidates needed
target_tokens = paddle.argmax(probs, axis=-1)
@@ -1071,7 +1073,9 @@ class SpeculativeSampler(nn.Layer):
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
_, sampled_token_ids = top_k_top_p_sampling(
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
@@ -448,6 +448,15 @@ def test_top_k_top_p_sampling_resets_cuda_rng_in_deterministic_mode(monkeypatch)
mock_reset.assert_called_once()
def mixed_mock(probs, *a, **k):
ids = paddle.argmax(probs, axis=-1, keepdim=True)
# More than one non-zero token left after renorm -> sampled path, return 99;
# exactly one non-zero token left -> greedy, return the argmax id.
non_zero_count = (probs > 0).sum(axis=-1, keepdim=True)
sampled = non_zero_count > 1
ids = paddle.where(sampled, paddle.to_tensor([[99]], dtype="int64"), ids)
return None, ids
def test_top_k_1_returns_argmax(monkeypatch):
"""top_k=1 should produce argmax results regardless of FD_DETERMINISTIC_MODE."""
import sys
@@ -480,7 +489,7 @@ def test_top_k_1_returns_argmax(monkeypatch):
monkeypatch.setattr(
sampling_mod.paddle.tensor,
"top_p_sampling",
lambda *a, **k: (None, paddle.to_tensor([[99], [99]], dtype="int64")),
mixed_mock,
)
_, ids_mixed = sampling_mod.top_k_top_p_sampling(probs, top_p, top_k_mixed, top_k_list_mixed)
@@ -519,7 +528,7 @@ def test_top_k_1_returns_argmax_without_deterministic_mode(monkeypatch):
monkeypatch.setattr(
sampling_mod.paddle.tensor,
"top_p_sampling",
lambda *a, **k: (None, paddle.to_tensor([[99], [99]], dtype="int64")),
mixed_mock,
)
_, ids_mixed = sampling_mod.top_k_top_p_sampling(probs, top_p, top_k_mixed, top_k_list_mixed)
@@ -146,7 +146,7 @@ def test_model_against_baseline(
# Get baseline suffix from config
model_config = hugging_face_model_param_map.get(model_name_or_path, {})
baseline_suffix = model_config.get("baseline_suffix", "tp2-dev-0317")
baseline_suffix = model_config.get("baseline_suffix", "tp2-dev-0311")
baseline_filename = f"{model_name_or_path}-{baseline_suffix}"
if base_path: