mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[BugFix][DataProcessor] Force top_k=1 for greedy decoding when temperature=0 (#6748)
* [BugFix] Force top_k=1 for greedy decoding when temperature=0 When temperature is set to 0 (greedy decoding), only setting temperature to a small epsilon is insufficient — the sampling kernel may still pick non-top-1 tokens. Explicitly set top_k=1 in all processors to guarantee argmax behavior. Additionally, add argmax fast-path in top_k_top_p_sampling() under FD_DETERMINISTIC_MODE to handle non-rejection sampling backends that ignore top_k parameter. * Extract greedy decoding from FD_DETERMINISTIC_MODE guard top_k=1 → argmax is a correctness optimization, not deterministic-specific. Remove the FD_DETERMINISTIC_MODE guard so all-greedy fast-path and mixed-batch override work unconditionally. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Update test_torch_model.py --------- Co-authored-by: gongweibao <gognweibao@baidu.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -51,6 +51,8 @@ def top_k_top_p_sampling(
         top_k(Tensor|None, optional): A 1-D Tensor with type int64,
             used to specify the top_k corresponding to each query.
             Only used when FD_SAMPLING_CLASS is `rejection`.
+        top_k_list(list|None, optional): CPU-side mirror of top_k as a Python list,
+            used for fast host-side checks (e.g. all-greedy detection) without GPU sync.
         threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
             used to avoid sampling low score tokens.
         topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
@@ -78,6 +80,19 @@ def top_k_top_p_sampling(
     if envs.FD_DETERMINISTIC_MODE:
         _reset_cuda_generator_for_determinism()

+    # Greedy decoding fast-path: top_k=1 is equivalent to argmax.
+    # In non-rejection sampling modes, top_k is ignored by the backend,
+    # so we must handle it explicitly.
+    all_greedy = False
+    if top_k_list is not None:
+        all_greedy = all(k == 1 for k in top_k_list)
+    elif top_k is not None:
+        all_greedy = bool(paddle.all(top_k == 1))
+
+    if all_greedy:
+        ids = paddle.argmax(x, axis=-1, keepdim=True)
+        return None, ids
+
     if top_p_class == "air":
         _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
     elif top_p_class == "rejection":
@@ -116,6 +131,18 @@ def top_k_top_p_sampling(
                 k=k,
                 mode="truncated",
             )
+        # Mixed batch: override top_k=1 rows with argmax.
+        # Shape guard: in overlap/speculative paths, x may be padded (e.g. [8,V])
+        # while top_k remains per-request (e.g. [2,1]). Skip when shapes disagree.
+        if not all_greedy and top_k is not None and top_k.shape[0] == ids.shape[0]:
+            has_greedy = (top_k_list is not None and any(k == 1 for k in top_k_list)) or (
+                top_k_list is None and bool(paddle.any(top_k == 1))
+            )
+            if has_greedy:
+                argmax_ids = paddle.argmax(x, axis=-1, keepdim=True)
+                greedy_mask = top_k == 1
+                ids = paddle.where(greedy_mask, argmax_ids, ids)

     return _, ids
Reference in New Issue
Block a user