mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Optimize draft logprob (#5842)
* optimize draft logprob * fix ut
This commit is contained in:
@@ -122,6 +122,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
||||
cfg.speculative_config.method = "mtp" if speculative_decoding else None
|
||||
cfg.speculative_config.num_speculative_tokens = 1
|
||||
cfg.model_config.enable_logprob = use_logprobs
|
||||
cfg.speculative_config.enable_draft_logprob = True
|
||||
|
||||
processor = TokenProcessor.__new__(TokenProcessor)
|
||||
processor.cfg = cfg
|
||||
@@ -139,6 +140,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
||||
processor.number_of_output_tokens = 0
|
||||
processor.prefill_result_status = {}
|
||||
processor.use_logprobs = use_logprobs
|
||||
processor.enable_draft_logprob = cfg.speculative_config.enable_draft_logprob
|
||||
processor.num_draft_tokens = 0
|
||||
processor.num_accepted_tokens = 0
|
||||
processor.num_emitted_tokens = 0
|
||||
|
||||
@@ -53,6 +53,7 @@ class _DummyCfg:
|
||||
self.speculative_config = types.SimpleNamespace(
|
||||
method=speculative_method,
|
||||
num_speculative_tokens=2,
|
||||
enable_draft_logprob=True,
|
||||
)
|
||||
self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob)
|
||||
self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")
|
||||
|
||||
Reference in New Issue
Block a user