[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)
* support mtp overlap in pd-split mode with insert_task overlap
@@ -14,7 +14,7 @@
 import unittest
 from dataclasses import dataclass
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import numpy as np
 import paddle
@@ -591,5 +591,162 @@ class TestSleepWakeupBehavior(unittest.TestCase):
        mock_print_memory.assert_not_called()

def _sync_async_set_value(tgt, src):
    """Synchronous stand-in for async_set_value used in tests (no CUDA required).

    Writes to real numpy arrays; silently skips Mock objects (untracked share_inputs
    fields whose values we do not assert on).
    """
    from unittest.mock import MagicMock

    import numpy as np

    if isinstance(tgt, MagicMock):
        return  # untracked field, nothing to write
    if isinstance(src, (int, float, bool)):
        tgt[:] = src
    elif isinstance(src, (list, np.ndarray)):
        tgt[:] = np.array(src).reshape(tgt.shape)
    elif hasattr(src, "numpy"):
        tgt[:] = src.numpy()
    else:
        tgt[:] = src
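
# Illustrative sanity check of the stand-in above (not executed by the suite; the
# buffer name is made up). It uses only numpy and unittest.mock and walks the three
# branches of _sync_async_set_value:
#
#     buf = np.zeros((2, 3), dtype=np.int64)
#     _sync_async_set_value(buf, 7)                    # scalar -> broadcast into buf
#     _sync_async_set_value(buf, [1, 2, 3, 4, 5, 6])   # list -> reshaped to buf.shape
#     _sync_async_set_value(MagicMock(), [1, 2])       # Mock target -> silently skipped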
class TestInsertTasksV1SplitwiseSuffix(unittest.TestCase):
    """Tests for insert_tasks_v1 splitwise_role='decode' + SpecMethod.SUFFIX branch."""

    def _make_share_inputs(self, bsz=4, max_draft=6):
        """Mock-backed share_inputs; only keys we assert on hold real numpy arrays."""
        import numpy as np

        # Keys whose values we want to inspect after the call
        tracked = {
            "seq_lens_encoder": np.zeros((bsz, 1), dtype=np.int32),
            "draft_tokens": np.zeros((bsz, max_draft), dtype=np.int64),
            "seq_lens_this_time_buffer": np.zeros((bsz, 1), dtype=np.int32),
            "req_ids": [""] * bsz,
            "preempted_idx": np.zeros((bsz, 1), dtype=np.int32),
            "num_running_requests": 0,
            "running_requests_ids": [],
        }

        class _SI:
            def get_index_by_batch_id(self, batch_id):
                return batch_id

            def __getitem__(self, key):
                # Return real array for tracked keys; Mock for everything else
                if key in tracked:
                    return tracked[key]
                return MagicMock()

            def __setitem__(self, key, value):
                tracked[key] = value

        return _SI()
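
    # Example of the tracked/untracked split above (illustrative only, not executed):
    #
    #     si = self._make_share_inputs(bsz=2)
    #     si["draft_tokens"][0, 0] = 7     # tracked key: real numpy array, asserted on later
    #     si["block_tables"][0] = [...]    # untracked key: fresh MagicMock, write is ignored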
    def _make_runner(self, bsz=4, num_spec_tokens=3):
        from unittest.mock import Mock

        from fastdeploy.spec_decode import SpecMethod
        from fastdeploy.worker.gpu_model_runner import GPUModelRunner

        runner = GPUModelRunner.__new__(GPUModelRunner)
        runner.enable_mm = False
        runner.is_pooling_model = False
        runner.speculative_decoding = True
        runner.spec_method = SpecMethod.SUFFIX
        runner.speculative_config = Mock(num_speculative_tokens=num_spec_tokens)
        runner.deterministic_logger = None
        runner.routing_replay_manager = Mock()
        runner.prompt_logprobs_reqs = {}
        runner.in_progress_prompt_logprobs = {}
        runner.forward_batch_reqs_list = [None] * bsz
        runner._cached_launch_token_num = -1
        runner._cached_real_bsz = 0
        runner.exist_prefill_flag = True
        runner.proposer = Mock()
        runner.sampler = Mock()
        runner.model_config = Mock(eos_tokens_lens=1)
        runner.share_inputs = self._make_share_inputs(bsz=bsz, max_draft=num_spec_tokens + 2)

        fd_config = Mock()
        fd_config.scheduler_config.splitwise_role = "decode"
        fd_config.routing_replay_config.enable_routing_replay = False
        runner.fd_config = fd_config
        runner.scheduler_config = fd_config.scheduler_config
        return runner
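
    # Note: the runner above is built with GPUModelRunner.__new__, which bypasses
    # __init__, so only the attributes these tests exercise need to be populated.
    # The same pattern in isolation (names purely illustrative):
    #
    #     obj = SomeHeavyClass.__new__(SomeHeavyClass)  # no __init__ side effects
    #     obj.field_used_by_method = value
    #     obj.method_under_test(...)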
    def _make_prefill_request(self, idx, draft_token_ids):
        from unittest.mock import Mock

        from fastdeploy.engine.request import RequestType

        req = Mock()
        req.task_type = Mock(value=RequestType.PREFILL.value)
        req.idx = idx
        req.request_id = f"req_{idx}"
        req.prompt_token_ids = [10, 20, 30]
        req.output_token_ids = [99]
        req.draft_token_ids = draft_token_ids
        req.pooling_params = None
        req.guided_json = None
        req.guided_regex = None
        req.structural_tag = None
        req.guided_grammar = None
        req.prefill_start_index = 0
        req.prefill_end_index = 3
        req.multimodal_inputs = None
        req.get = Mock(return_value=None)
        req.eos_token_ids = [2]
        req.block_tables = []
        return req
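
    # The tests below always pass at least two draft_token_ids per prefill request;
    # test_raises_when_fewer_than_two_draft_tokens checks that a single draft token
    # makes insert_tasks_v1 raise ValueError in this decode/SUFFIX branch, e.g.:
    #
    #     req = self._make_prefill_request(idx=0, draft_token_ids=[42])  # too short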

    @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
    def test_draft_tokens_and_seq_lens_written(self, _mock_asv):
        """draft_tokens[0:2] and seq_lens_this_time_buffer=2 are written."""
        runner = self._make_runner(num_spec_tokens=3)
        req = self._make_prefill_request(idx=0, draft_token_ids=[101, 202, 303])
        runner.insert_tasks_v1([req], num_running_requests=1)

        self.assertEqual(runner.share_inputs["draft_tokens"][0, 0], 101)
        self.assertEqual(runner.share_inputs["draft_tokens"][0, 1], 202)
        self.assertEqual(runner.share_inputs["seq_lens_this_time_buffer"][0, 0], 2)

    @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
    def test_exist_prefill_flag_cleared(self, _mock_asv):
        runner = self._make_runner()
        req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])
        runner.insert_tasks_v1([req], num_running_requests=1)
        self.assertFalse(runner.exist_prefill_flag)

    @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
    def test_cached_launch_token_num_incremented(self, _mock_asv):
        runner = self._make_runner(num_spec_tokens=3)
        runner._cached_launch_token_num = 10
        runner._cached_real_bsz = 2
        req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])
        runner.insert_tasks_v1([req], num_running_requests=1)
        # token_num_one_step = num_speculative_tokens + 1 = 4
        self.assertEqual(runner._cached_launch_token_num, 14)
        self.assertEqual(runner._cached_real_bsz, 3)

    @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
    def test_cached_launch_token_num_skipped_when_negative_one(self, _mock_asv):
        runner = self._make_runner(num_spec_tokens=3)
        runner._cached_launch_token_num = -1
        req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2])
        runner.insert_tasks_v1([req], num_running_requests=1)
        self.assertEqual(runner._cached_launch_token_num, -1)

    @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value)
    def test_raises_when_fewer_than_two_draft_tokens(self, _mock_asv):
        runner = self._make_runner()
        req = self._make_prefill_request(idx=0, draft_token_ids=[42])
        with self.assertRaises(ValueError):
            runner.insert_tasks_v1([req], num_running_requests=1)


if __name__ == "__main__":
    unittest.main()
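
# To run just the new test class (the module path below is illustrative; use the
# actual location of this file in the repository):
#
#     python -m unittest path.to.this_module.TestInsertTasksV1SplitwiseSuffix -v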