[BugFix][DataProcessor] Force top_k=1 for greedy decoding when temperature=0 (#6748)

* [BugFix] Force top_k=1 for greedy decoding when temperature=0

When temperature is set to 0 (greedy decoding), merely replacing it with
a small epsilon is insufficient: the sampling kernel may still pick a
token other than the top-1. Explicitly set top_k=1 in all processors to
guarantee argmax behavior.
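
As a toy illustration of why an epsilon temperature alone is not enough, here is a hedged sketch of a categorical sampler (hypothetical, not the FastDeploy kernel): with a tiny temperature the runner-up token keeps a small but nonzero probability, while top_k=1 truncates the candidate set to exactly the argmax.

```python
import math
import random

def top_k_sample(logits, temperature, top_k):
    """Toy sampler: temperature scaling, top-k truncation, categorical draw.

    Illustrative only; real kernels work on GPU tensors.
    """
    scaled = [l / temperature for l in logits]
    if top_k > 0:
        # Keep only the indices of the top_k largest scaled logits.
        kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    else:
        kept = list(range(len(scaled)))  # top_k == 0 means "no truncation"
    m = max(scaled[i] for i in kept)
    weights = [(i, math.exp(scaled[i] - m)) for i in kept]
    z = sum(w for _, w in weights)
    r = random.random() * z
    acc = 0.0
    for i, w in weights:
        acc += w
        if r <= acc:
            return i
    return kept[-1]

logits = [2.0, 1.9, 0.5]
# Small-epsilon temperature, no top-k: the runner-up token retains
# nonzero probability, so draws can disagree with the argmax.
draws_eps = {top_k_sample(logits, temperature=0.05, top_k=0) for _ in range(2000)}
# top_k=1: the candidate set is exactly {argmax}; temperature is irrelevant.
draws_greedy = {top_k_sample(logits, temperature=1.0, top_k=1) for _ in range(2000)}
print(draws_greedy)  # {0}
```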

Additionally, add an argmax fast path in top_k_top_p_sampling() under
FD_DETERMINISTIC_MODE to handle non-rejection-sampling backends that
ignore the top_k parameter.
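
A minimal sketch of such a fast path, assuming NumPy arrays of shape [batch, vocab]; the function name top_k_top_p_sampling comes from this commit, but the body below is illustrative, not the FastDeploy implementation:

```python
import numpy as np

def top_k_top_p_sampling(probs, top_k, top_p):
    """Illustrative fast path: if every request in the batch has
    top_k == 1, sampling degenerates to a pure argmax, so the result
    is correct even on backends whose kernel ignores top_k."""
    if np.all(top_k == 1):
        return probs.argmax(axis=-1)
    # The regular top-k/top-p sampling kernel would run here.
    raise NotImplementedError("non-greedy path omitted in this sketch")

probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.1, 0.3]])
out = top_k_top_p_sampling(probs, np.array([1, 1]), np.array([1.0, 1.0]))
print(out)  # [1 0]
```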

* Extract greedy decoding from FD_DETERMINISTIC_MODE guard

Collapsing top_k=1 to argmax is a correctness guarantee, not a
deterministic-mode-specific optimization. Remove the FD_DETERMINISTIC_MODE
guard so the all-greedy fast path and the mixed-batch override apply
unconditionally.
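
The mixed-batch override can be sketched as follows (hypothetical names and shapes; the real FastDeploy code differs): rows whose top_k is 1 get their sampled token replaced by the row argmax, with no deterministic-mode check.

```python
import numpy as np

def override_greedy_rows(probs, top_k, sampled):
    # Unconditional override: any row requesting top_k == 1 must return
    # its argmax, even if the sampling kernel drew a different token.
    greedy = top_k == 1
    out = sampled.copy()
    out[greedy] = probs[greedy].argmax(axis=-1)
    return out

probs = np.array([[0.1, 0.8, 0.1],   # greedy row (top_k=1)
                  [0.3, 0.3, 0.4]])  # sampled row (top_k=50)
sampled = np.array([2, 1])           # pretend kernel output
fixed = override_greedy_rows(probs, np.array([1, 50]), sampled)
print(fixed)  # [1 1]
```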

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update test_torch_model.py

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Commit fb6c56dfd5 (parent 9b117aafac)
Authored by gongweibao, 2026-03-18 17:36:43 +08:00; committed by GitHub
10 changed files with 143 additions and 7 deletions
@@ -143,8 +143,9 @@ class Ernie4_5Processor(BaseDataProcessor):
         else:
             request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
         if request.get("temperature") < _SAMPLING_EPS:
-            # zero temperature is equivalent to greedy sampling
+            # zero temperature means greedy decoding: set top_k=1 to force argmax
             request.set("temperature", 1)
+            request.set("top_k", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
         if self.reasoning_parser:
@@ -231,8 +232,9 @@ class Ernie4_5Processor(BaseDataProcessor):
         else:
             request["max_tokens"] = min(max_tokens, request["max_tokens"])
         if request.get("temperature") < _SAMPLING_EPS:
-            # zero temperature is equivalent to greedy sampling
+            # zero temperature means greedy decoding: set top_k=1 to force argmax
             request["temperature"] = 1
+            request["top_k"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS