[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)

* support mtp overlap in pd-split mode with insert_task overlap
This commit is contained in:
freeliuzc
2026-04-13 19:41:17 +08:00
committed by GitHub
parent 5ddd1af756
commit 31e2a8bbad
6 changed files with 351 additions and 122 deletions
+158 -1
View File
@@ -14,7 +14,7 @@
import unittest
from dataclasses import dataclass
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import paddle
@@ -591,5 +591,162 @@ class TestSleepWakeupBehavior(unittest.TestCase):
mock_print_memory.assert_not_called()
def _sync_async_set_value(tgt, src):
    """Blocking replacement for ``async_set_value`` so tests need no CUDA.

    The value derived from *src* is copied into *tgt* through a full-slice
    assignment (``tgt[:] = ...``), so *tgt* is updated in place.

    Args:
        tgt: Destination buffer. A ``MagicMock`` destination (an untracked
            ``share_inputs`` field the tests never assert on) is ignored.
        src: Value to store. Lists and numpy arrays are reshaped to
            ``tgt.shape``; objects exposing ``.numpy()`` are converted through
            it; everything else (including plain scalars) is assigned as-is.
    """
    import numpy as np
    from unittest.mock import MagicMock

    if isinstance(tgt, MagicMock):
        # Untracked field: there is no real buffer behind it to write into.
        return

    is_plain_scalar = isinstance(src, (int, float, bool))
    if not is_plain_scalar and isinstance(src, (list, np.ndarray)):
        # Sequence-like sources must be made to match the destination layout.
        value = np.array(src).reshape(tgt.shape)
    elif not is_plain_scalar and hasattr(src, "numpy"):
        # Tensor-like sources are converted via their own .numpy() method.
        value = src.numpy()
    else:
        value = src
    tgt[:] = value
@patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
class TestInsertTasksV1SplitwiseSuffix(unittest.TestCase):
    """Exercise insert_tasks_v1 with splitwise_role="decode" and SpecMethod.SUFFIX.

    The class-level ``patch`` wraps every ``test*`` method, swapping the
    runner module's ``async_set_value`` for the synchronous stand-in and
    passing the resulting mock to each test as its extra positional argument.
    """

    def _make_share_inputs(self, bsz=4, max_draft=6):
        """Build a share_inputs stub backed by real data only where tests look.

        Keys listed in the backing store hold numpy arrays / plain Python
        values; reading any other key yields a fresh ``MagicMock``. Writing a
        key stores it, so later reads return the stored value.
        """

        def int32_column():
            return np.zeros((bsz, 1), dtype=np.int32)

        # Entries the tests (or the code under test) need as real values.
        store = {
            "seq_lens_encoder": int32_column(),
            "draft_tokens": np.zeros((bsz, max_draft), dtype=np.int64),
            "seq_lens_this_time_buffer": int32_column(),
            "req_ids": [""] * bsz,
            "preempted_idx": int32_column(),
            "num_running_requests": 0,
            "running_requests_ids": [],
        }

        class _ShareInputsStub:
            def get_index_by_batch_id(self, batch_id):
                # Identity mapping: batch id and slot index coincide in tests.
                return batch_id

            def __getitem__(self, key):
                return store[key] if key in store else MagicMock()

            def __setitem__(self, key, value):
                store[key] = value

        return _ShareInputsStub()

    def _make_runner(self, bsz=4, num_spec_tokens=3):
        """Create a GPUModelRunner without running __init__ and seed its state.

        Only the attributes assigned below are set; the instance is allocated
        with ``__new__`` so no constructor side effects occur.
        """
        # Imported lazily, as in the rest of this test class.
        from fastdeploy.spec_decode import SpecMethod
        from fastdeploy.worker.gpu_model_runner import GPUModelRunner

        config = Mock()
        config.scheduler_config.splitwise_role = "decode"
        config.routing_replay_config.enable_routing_replay = False

        state = {
            "enable_mm": False,
            "is_pooling_model": False,
            "speculative_decoding": True,
            "spec_method": SpecMethod.SUFFIX,
            "speculative_config": Mock(num_speculative_tokens=num_spec_tokens),
            "deterministic_logger": None,
            "routing_replay_manager": Mock(),
            "prompt_logprobs_reqs": {},
            "in_progress_prompt_logprobs": {},
            "forward_batch_reqs_list": [None] * bsz,
            "_cached_launch_token_num": -1,
            "_cached_real_bsz": 0,
            "exist_prefill_flag": True,
            "proposer": Mock(),
            "sampler": Mock(),
            "model_config": Mock(eos_tokens_lens=1),
            "share_inputs": self._make_share_inputs(bsz=bsz, max_draft=num_spec_tokens + 2),
            "fd_config": config,
            "scheduler_config": config.scheduler_config,
        }

        runner = GPUModelRunner.__new__(GPUModelRunner)
        for attr_name, attr_value in state.items():
            setattr(runner, attr_name, attr_value)
        return runner

    def _make_prefill_request(self, idx, draft_token_ids):
        """Return a Mock request of PREFILL type carrying *draft_token_ids*."""
        from fastdeploy.engine.request import RequestType

        request = Mock()
        request.configure_mock(
            task_type=Mock(value=RequestType.PREFILL.value),
            idx=idx,
            request_id=f"req_{idx}",
            prompt_token_ids=[10, 20, 30],
            output_token_ids=[99],
            draft_token_ids=draft_token_ids,
            pooling_params=None,
            guided_json=None,
            guided_regex=None,
            structural_tag=None,
            guided_grammar=None,
            prefill_start_index=0,
            prefill_end_index=3,
            multimodal_inputs=None,
            get=Mock(return_value=None),
            eos_token_ids=[2],
            block_tables=[],
        )
        return request

    def test_draft_tokens_and_seq_lens_written(self, _mock_asv):
        """First two draft tokens land in row 0 and seq_lens_this_time_buffer is 2."""
        runner = self._make_runner(num_spec_tokens=3)
        request = self._make_prefill_request(idx=0, draft_token_ids=[101, 202, 303])

        runner.insert_tasks_v1([request], num_running_requests=1)

        drafts = runner.share_inputs["draft_tokens"]
        for column, expected in enumerate((101, 202)):
            self.assertEqual(drafts[0, column], expected)
        self.assertEqual(runner.share_inputs["seq_lens_this_time_buffer"][0, 0], 2)

    def test_exist_prefill_flag_cleared(self, _mock_asv):
        """exist_prefill_flag starts True in the fixture and must end up falsy."""
        runner = self._make_runner()
        request = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])

        runner.insert_tasks_v1([request], num_running_requests=1)

        self.assertFalse(runner.exist_prefill_flag)

    def test_cached_launch_token_num_incremented(self, _mock_asv):
        """Cached launch counters grow by one step / one request when already set."""
        runner = self._make_runner(num_spec_tokens=3)
        runner._cached_launch_token_num = 10
        runner._cached_real_bsz = 2
        request = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])

        runner.insert_tasks_v1([request], num_running_requests=1)

        # Expected step: token_num_one_step = num_speculative_tokens + 1 = 4.
        self.assertEqual(runner._cached_launch_token_num, 14)
        self.assertEqual(runner._cached_real_bsz, 3)

    def test_cached_launch_token_num_skipped_when_negative_one(self, _mock_asv):
        """A cached launch token count of -1 must be left untouched."""
        runner = self._make_runner(num_spec_tokens=3)
        runner._cached_launch_token_num = -1
        request = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])

        runner.insert_tasks_v1([request], num_running_requests=1)

        self.assertEqual(runner._cached_launch_token_num, -1)

    def test_raises_when_fewer_than_two_draft_tokens(self, _mock_asv):
        """A request carrying a single draft token is rejected with ValueError."""
        runner = self._make_runner()
        request = self._make_prefill_request(idx=0, draft_token_ids=[42])

        self.assertRaises(ValueError, runner.insert_tasks_v1, [request], num_running_requests=1)
# Run the unittest runner only when this module is executed directly as a
# script; importing it (e.g. via a test collector) must not start a run.
if __name__ == "__main__":
    unittest.main()