[BugFix][DataProcessor] Force top_k=1 for greedy decoding when temperature=0 (#6748)

* [BugFix] Force top_k=1 for greedy decoding when temperature=0

When temperature is set to 0 (greedy decoding), clamping the temperature
to a small epsilon is insufficient: the sampling kernel may still pick
non-top-1 tokens. Explicitly set top_k=1 in all processors to guarantee
argmax behavior.

Additionally, add an argmax fast-path in top_k_top_p_sampling() under
FD_DETERMINISTIC_MODE to handle non-rejection sampling backends, which
ignore the top_k parameter.
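To see why an epsilon temperature alone cannot guarantee argmax, consider near-tied logits: dividing by even an aggressive epsilon leaves substantial probability mass on the runner-up token, so the sampler can still pick it. A minimal NumPy sketch with illustrative values (the real code operates on Paddle tensors):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.999, -5.0])  # near-tie between tokens 0 and 1
eps = 1e-2                             # a "small epsilon" temperature
p = softmax(logits / eps)
# Near-tied logits survive even aggressive scaling:
# token 1 still keeps roughly 47% of the probability mass.
greedy = int(np.argmax(logits))        # exact greedy pick: always token 0
```

With top_k=1 (or a direct argmax) no randomness remains, so the top-1 token is returned regardless of how close the runner-up is.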

* Extract greedy decoding from FD_DETERMINISTIC_MODE guard

top_k=1 → argmax is a correctness optimization, not behavior specific to
deterministic mode. Remove the FD_DETERMINISTIC_MODE guard so the
all-greedy fast-path and the mixed-batch override apply unconditionally.
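The all-greedy check prefers the host-side top_k_list mirror so the common case avoids a device-to-host sync, and only falls back to reducing the tensor when no list is available. A NumPy sketch of that two-path check (`detect_all_greedy` is a hypothetical helper name; the diff inlines this logic):

```python
import numpy as np

def detect_all_greedy(top_k_list=None, top_k=None):
    """Hypothetical helper mirroring the diff's two-path all-greedy check."""
    if top_k_list is not None:
        # CPU-side mirror: a plain Python list, no device sync needed.
        return all(k == 1 for k in top_k_list)
    if top_k is not None:
        # Fallback: reduce the device tensor (in Paddle this forces a host sync).
        return bool(np.all(top_k == 1))
    return False
```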

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update test_torch_model.py

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Author: gongweibao
Committed: 2026-03-18 17:36:43 +08:00 (via GitHub)
Parent: 9b117aafac
Commit: fb6c56dfd5
10 changed files with 143 additions and 7 deletions
@@ -51,6 +51,8 @@ def top_k_top_p_sampling(
        top_k(Tensor|None, optional): A 1-D Tensor with type int64,
            used to specify the top_k corresponding to each query.
            Only used when FD_SAMPLING_CLASS is `rejection`.
        top_k_list(list|None, optional): CPU-side mirror of top_k as a Python list,
            used for fast host-side checks (e.g. all-greedy detection) without GPU sync.
        threshold(Tensor|None, optional): A 1-D Tensor of type float32, float16, or bfloat16,
            used to avoid sampling low-score tokens.
        topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
@@ -78,6 +80,19 @@ def top_k_top_p_sampling(
    if envs.FD_DETERMINISTIC_MODE:
        _reset_cuda_generator_for_determinism()
    # Greedy decoding fast-path: top_k=1 is equivalent to argmax.
    # In non-rejection sampling modes, top_k is ignored by the backend,
    # so we must handle it explicitly.
    all_greedy = False
    if top_k_list is not None:
        all_greedy = all(k == 1 for k in top_k_list)
    elif top_k is not None:
        all_greedy = bool(paddle.all(top_k == 1))
    if all_greedy:
        ids = paddle.argmax(x, axis=-1, keepdim=True)
        return None, ids
    if top_p_class == "air":
        _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
    elif top_p_class == "rejection":
@@ -116,6 +131,18 @@ def top_k_top_p_sampling(
            k=k,
            mode="truncated",
        )
    # Mixed batch: override top_k=1 rows with argmax.
    # Shape guard: in overlap/speculative paths, x may be padded (e.g. [8,V])
    # while top_k remains per-request (e.g. [2,1]). Skip when shapes disagree.
    if not all_greedy and top_k is not None and top_k.shape[0] == ids.shape[0]:
        has_greedy = (top_k_list is not None and any(k == 1 for k in top_k_list)) or (
            top_k_list is None and bool(paddle.any(top_k == 1))
        )
        if has_greedy:
            argmax_ids = paddle.argmax(x, axis=-1, keepdim=True)
            greedy_mask = top_k == 1
            ids = paddle.where(greedy_mask, argmax_ids, ids)
    return _, ids
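For a mixed batch, the per-row override can be reproduced with a `where` over the greedy mask: rows whose top_k is 1 take the argmax token, all other rows keep the backend's sampled ids. A NumPy sketch with made-up values (the diff does the same with paddle.where):

```python
import numpy as np

logits = np.array([[0.1, 3.0, 0.2],
                   [2.0, 0.5, 0.3]])
top_k = np.array([[1], [50]])       # row 0 is greedy, row 1 is sampled
sampled_ids = np.array([[2], [0]])  # pretend backend sampler output

# Override greedy rows with the exact argmax pick.
argmax_ids = np.argmax(logits, axis=-1, keepdims=True)
ids = np.where(top_k == 1, argmax_ids, sampled_ids)
# row 0 is overridden to its argmax (token 1); row 1 is untouched (token 0)
```

The broadcasted mask makes this a single elementwise select, so the override costs one extra argmax over the batch rather than any per-row branching.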