mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix][DataProcessor] Force top_k=1 for greedy decoding when temperature=0 (#6748)
* [BugFix] Force top_k=1 for greedy decoding when temperature=0 When temperature is set to 0 (greedy decoding), only setting temperature to a small epsilon is insufficient — the sampling kernel may still pick non-top-1 tokens. Explicitly set top_k=1 in all processors to guarantee argmax behavior. Additionally, add argmax fast-path in top_k_top_p_sampling() under FD_DETERMINISTIC_MODE to handle non-rejection sampling backends that ignore top_k parameter. * Extract greedy decoding from FD_DETERMINISTIC_MODE guard top_k=1 → argmax is a correctness optimization, not deterministic-specific. Remove the FD_DETERMINISTIC_MODE guard so all-greedy fast-path and mixed-batch override work unconditionally. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Update test_torch_model.py --------- Co-authored-by: gongweibao <gognweibao@baidu.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -143,8 +143,9 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
else:
|
||||
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
# zero temperature means greedy decoding: set top_k=1 to force argmax
|
||||
request.set("temperature", 1)
|
||||
request.set("top_k", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser:
|
||||
@@ -231,8 +232,9 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
else:
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
# zero temperature means greedy decoding: set top_k=1 to force argmax
|
||||
request["temperature"] = 1
|
||||
request["top_k"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user