mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
test_abort (#6743)
This commit is contained in:
@@ -786,7 +786,6 @@ class EngineService:
|
||||
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
|
||||
for task in tasks:
|
||||
task.metrics.engine_get_req_time = time.time()
|
||||
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
||||
@@ -851,7 +850,6 @@ class EngineService:
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
|
||||
for task in tasks:
|
||||
task.metrics.engine_get_req_time = time.time()
|
||||
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
||||
@@ -1178,19 +1176,8 @@ class EngineService:
|
||||
if status_value is not None and status_value == RequestStatus.ABORT.value:
|
||||
req_id = data["request_id"]
|
||||
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
|
||||
self.resource_manager.abort_req_ids_set.add(req_id)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if req_id in self.resource_manager.requests:
|
||||
req = self.resource_manager.requests[req_id]
|
||||
task = self.resource_manager._prepare_preempt_task(req)
|
||||
self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz))
|
||||
self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}")
|
||||
else:
|
||||
self.scheduler._recycle(req_id)
|
||||
self.llm_logger.info(
|
||||
f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler"
|
||||
)
|
||||
self.resource_manager.abort_req_ids_set.remove(req_id)
|
||||
self.resource_manager.add_abort_req_ids(req_id)
|
||||
continue
|
||||
err_msg = None
|
||||
try:
|
||||
|
||||
@@ -61,6 +61,7 @@ class RequestType(Enum):
|
||||
DECODE = 1
|
||||
PREEMPTED = 2
|
||||
EXTEND = 3
|
||||
ABORT = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -87,6 +87,15 @@ class ScheduledExtendBlocksTask:
|
||||
task_type: RequestType = RequestType.EXTEND
|
||||
|
||||
|
||||
@dataclass
class ScheduledAbortTask:
    """Scheduler task instructing the worker to abort a running request.

    NOTE(review): the previous docstring ("Task for allocating new blocks to
    skip") was copy-pasted from another task type and did not describe this
    class; it is an abort directive, not a block-allocation task.
    """

    # Batch-slot index of the request being aborted.
    idx: int
    # Identifier of the request to abort.
    request_id: str
    # Discriminator consumed by the worker's task dispatch.
    task_type: RequestType = RequestType.ABORT
|
||||
|
||||
|
||||
class SignalConsumer:
|
||||
"""
|
||||
A class that consumes a signal value up to a specified limit.
|
||||
@@ -180,6 +189,8 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.using_extend_tables_req_id = set()
|
||||
self.reuse_block_num_map = dict()
|
||||
self.abort_req_ids_set = set()
|
||||
self.waiting_abort_req_id_set = set()
|
||||
self.to_be_aborted_req_id_set = set()
|
||||
|
||||
# need block nums
|
||||
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
|
||||
@@ -246,6 +257,9 @@ class ResourceManagerV1(ResourceManager):
|
||||
def _prepare_preempt_task(self, request):
|
||||
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
|
||||
|
||||
def _prepare_abort_task(self, request):
    """Build a ScheduledAbortTask carrying *request*'s slot index and id."""
    abort_task = ScheduledAbortTask(
        idx=request.idx,
        request_id=request.request_id,
    )
    return abort_task
|
||||
|
||||
def reschedule_preempt_task(self, request_id, process_func=None):
|
||||
with self.lock:
|
||||
llm_logger.debug(f"reschedule {request_id} into waiting queue")
|
||||
@@ -259,6 +273,27 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.waiting.appendleft(request)
|
||||
self.to_be_rescheduled_request_id_set.remove(request_id)
|
||||
|
||||
def recycle_abort_task(self, request_id):
    """Release the batch slot held by an aborted request.

    Acts only when *request_id* is both pending abortion and still tracked
    in ``self.requests``: clears its task slot, raises the stop flag, and
    drops every bookkeeping entry for the id. Thread-safe via ``self.lock``.
    """
    with self.lock:
        if request_id not in self.to_be_aborted_req_id_set:
            return
        if request_id not in self.requests:
            return
        slot = self.requests[request_id].idx
        self.tasks_list[slot] = None  # clear the slot
        self.stop_flags[slot] = True  # set the stop flag
        self.requests.pop(request_id)
        self.req_dict.pop(request_id)
        self.to_be_aborted_req_id_set.remove(request_id)
|
||||
|
||||
def _trigger_abort(self, request_id, scheduled_reqs):
    """Queue an abort task for *request_id* and free its cache blocks.

    Marks the request preempted, returns its KV cache blocks to the pool,
    appends an abort task to *scheduled_reqs*, and moves the id from the
    waiting-abort set into the to-be-aborted set so the token processor can
    later recycle its slot.
    """
    if request_id in self.requests:
        req = self.requests[request_id]
        req.status = RequestStatus.PREEMPTED
        req.num_computed_tokens = 0
        # Return the request's KV cache blocks to the pool.
        self._free_blocks(req)
        req.cached_block_num = 0
        scheduled_reqs.append(self._prepare_abort_task(req))
        self.to_be_aborted_req_id_set.add(request_id)
    # NOTE(review): removal is done unconditionally here; the caller only
    # invokes this when request_id is in waiting_abort_req_id_set, and an
    # untracked id left in the set would be retried forever — confirm this
    # matches the original indentation, which the source view lost.
    self.waiting_abort_req_id_set.remove(request_id)
|
||||
|
||||
def _info_each_block(self):
|
||||
"""
|
||||
print each req block
|
||||
@@ -693,6 +728,13 @@ class ResourceManagerV1(ResourceManager):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_abort_req_ids(self, req_ids):
    """Register one id or a list of ids as waiting to be aborted.

    Accepts either a single request id or a list of them; all end up in
    ``waiting_abort_req_id_set``. Thread-safe via ``self.lock``.
    """
    pending = req_ids if isinstance(req_ids, list) else [req_ids]
    with self.lock:
        self.waiting_abort_req_id_set.update(pending)
|
||||
|
||||
def cache_output_tokens(self, request):
|
||||
if (
|
||||
self.config.cache_config.enable_prefix_caching
|
||||
@@ -720,6 +762,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
preempted_reqs: list[Request] = []
|
||||
error_reqs: list[tuple[str, str]] = []
|
||||
token_budget = self.config.scheduler_config.max_num_batched_tokens
|
||||
need_abort_requests = [] # users trigger abortion
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
req_index = 0
|
||||
@@ -739,6 +782,13 @@ class ResourceManagerV1(ResourceManager):
|
||||
continue
|
||||
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
|
||||
request.num_computed_tokens = request.num_total_tokens - 1
|
||||
|
||||
if request.request_id in self.waiting_abort_req_id_set:
|
||||
self._trigger_abort(request.request_id, scheduled_reqs)
|
||||
req_index += 1
|
||||
need_abort_requests.append(request)
|
||||
continue
|
||||
|
||||
if (
|
||||
self.allocated_slots(request) - request.num_total_tokens
|
||||
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
|
||||
@@ -876,6 +926,10 @@ class ResourceManagerV1(ResourceManager):
|
||||
)
|
||||
req_index += 1
|
||||
|
||||
# remove requests to be aborted from running list
|
||||
for request in need_abort_requests:
|
||||
self.running.remove(request)
|
||||
|
||||
# Second, schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
skip_requests: list[Request] = []
|
||||
|
||||
@@ -284,23 +284,13 @@ class TokenProcessor:
|
||||
task_id = task.request_id
|
||||
token_ids = stream_data.tokens # numpy.array
|
||||
if token_ids is not None and token_ids[-1] < 0:
|
||||
if task_id in self.resource_manager.abort_req_ids_set:
|
||||
if (
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID
|
||||
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||
self._recycle_resources(task_id, i, task)
|
||||
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||
abort_res = RequestOutput(
|
||||
request_id=task_id,
|
||||
finished=True,
|
||||
error_code=499,
|
||||
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||
)
|
||||
batch_result.append(abort_res)
|
||||
continue
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if (
|
||||
task_id in self.resource_manager.to_be_aborted_req_id_set
|
||||
and token_ids[-1] == PREEMPTED_TOKEN_ID
|
||||
):
|
||||
llm_logger.info(f"start to recycle abort request_id {task_id}")
|
||||
self.resource_manager.recycle_abort_task(task_id)
|
||||
if (
|
||||
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
||||
and token_ids[-1] == PREEMPTED_TOKEN_ID
|
||||
@@ -308,6 +298,13 @@ class TokenProcessor:
|
||||
llm_logger.info(f"sync preemption for request_id {task_id} done.")
|
||||
self.resource_manager.reschedule_preempt_task(task_id)
|
||||
continue
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
# In D instance, if preempted, error has been reported and resource recycled, tokens generated async not need to be handled
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if task_id in self.resource_manager.to_be_aborted_req_id_set:
|
||||
continue
|
||||
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
|
||||
continue
|
||||
|
||||
current_time = time.time()
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
@@ -760,19 +757,8 @@ class TokenProcessor:
|
||||
if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, means preemption has happened in worker
|
||||
llm_logger.info(f"sync preemption for request_id {task_id} done.")
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if task_id in self.resource_manager.abort_req_ids_set:
|
||||
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||
self._recycle_resources(task_id, i, task)
|
||||
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||
abort_res = RequestOutput(
|
||||
request_id=task_id,
|
||||
finished=True,
|
||||
error_code=499,
|
||||
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||
)
|
||||
batch_result.append(abort_res)
|
||||
continue
|
||||
if task_id in self.resource_manager.to_be_aborted_req_id_set:
|
||||
self.resource_manager.recycle_abort_task(task_id)
|
||||
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
|
||||
self.resource_manager.reschedule_preempt_task(task_id)
|
||||
continue
|
||||
@@ -801,23 +787,13 @@ class TokenProcessor:
|
||||
if recovery_stop:
|
||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||
if not recovery_stop and token_id < 0:
|
||||
if task_id in self.resource_manager.abort_req_ids_set:
|
||||
if (
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID
|
||||
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||
self._recycle_resources(task_id, i, task)
|
||||
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||
abort_res = RequestOutput(
|
||||
request_id=task_id,
|
||||
finished=True,
|
||||
error_code=499,
|
||||
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||
)
|
||||
batch_result.append(abort_res)
|
||||
continue
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if (
|
||||
task_id in self.resource_manager.to_be_aborted_req_id_set
|
||||
and token_id == PREEMPTED_TOKEN_ID
|
||||
):
|
||||
self.resource_manager.recycle_abort_task(task_id)
|
||||
llm_logger.info(f"sync abortion for request_id {task_id} done.")
|
||||
if (
|
||||
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
||||
and token_id == PREEMPTED_TOKEN_ID
|
||||
@@ -825,6 +801,13 @@ class TokenProcessor:
|
||||
llm_logger.info(f"sync preemption for request_id {task_id} done.")
|
||||
self.resource_manager.reschedule_preempt_task(task_id)
|
||||
continue
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
# In D instance, if preempted, error has been reported and resource recycled, tokens generated async not need to be handled
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
|
||||
continue
|
||||
if task_id in self.resource_manager.to_be_aborted_req_id_set:
|
||||
continue
|
||||
|
||||
if self.scheduler_metrics_logger and self._is_decode_stage(task):
|
||||
self.scheduler_metrics_logger.on_decode_tokens(len(token_ids))
|
||||
|
||||
@@ -277,6 +277,14 @@ class SplitwiseConnector:
|
||||
"request_id": tasks[i].request_id,
|
||||
"error_msg": tasks[i].get("error_msg"),
|
||||
}
|
||||
if (
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
and tasks[i].request_id in self.resource_manager.waiting_abort_req_id_set
|
||||
):
|
||||
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
|
||||
if addr not in cache_info:
|
||||
cache_info[addr] = []
|
||||
cache_info[addr].append(info)
|
||||
else:
|
||||
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
|
||||
info = {
|
||||
|
||||
@@ -891,7 +891,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0
|
||||
continue
|
||||
else: # preempted task
|
||||
logger.info(f"Handle preempted request {request} at idx {idx}")
|
||||
if request.task_type.value == RequestType.PREEMPTED.value:
|
||||
logger.info(f"Handle preempted request {request} at idx {idx}")
|
||||
elif request.task_type.value == RequestType.ABORT.value:
|
||||
logger.info(f"Handle abort request {request} at idx {idx}")
|
||||
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1
|
||||
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["stop_flags"][idx : idx + 1] = True
|
||||
|
||||
@@ -1628,9 +1628,13 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
class DummyRM:
|
||||
def __init__(self):
|
||||
self.abort_req_ids_set = set()
|
||||
self.waiting_abort_req_id_set = set()
|
||||
self.real_bsz = 1
|
||||
self.requests = {"rid": Mock()}
|
||||
|
||||
def add_abort_req_ids(self, req_id):
|
||||
self.waiting_abort_req_id_set.add(req_id)
|
||||
|
||||
def _prepare_preempt_task(self, req):
|
||||
return Request(request_id="rid", prompt_token_ids=[1], prompt_token_ids_len=1)
|
||||
|
||||
@@ -1649,7 +1653,8 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
):
|
||||
eng._insert_zmq_task_to_scheduler()
|
||||
|
||||
eng.engine_worker_queue.put_tasks.assert_called_once()
|
||||
# Verify abort request was handled correctly - added to waiting_abort_req_id_set
|
||||
self.assertIn("rid", eng.resource_manager.waiting_abort_req_id_set)
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_insert_zmq_task_to_scheduler_paused_sends_error(self):
|
||||
|
||||
@@ -53,6 +53,7 @@ class MockConfig:
|
||||
|
||||
class SchedulerConfig:
|
||||
name = "default"
|
||||
splitwise_role = "decode"
|
||||
|
||||
class CacheConfig:
|
||||
enable_prefix_caching = False
|
||||
@@ -104,15 +105,15 @@ class MockResourceManager:
|
||||
self.stop_flags = [False]
|
||||
self.tasks_list = [MockTask()]
|
||||
self.to_be_rescheduled_request_id_set = set()
|
||||
self.to_be_aborted_req_id_set = set()
|
||||
self.abort_req_ids_set = set()
|
||||
self.req_dict = {}
|
||||
self.recycle_abort_task = MagicMock(side_effect=lambda rid: self.to_be_aborted_req_id_set.discard(rid))
|
||||
self.reschedule_preempt_task = MagicMock()
|
||||
|
||||
def info(self):
    """Return a canned, human-readable summary for this mock."""
    summary = "Mock resource manager info"
    return summary
|
||||
|
||||
def reschedule_preempt_task(self, task_id):
    """No-op stand-in for the real rescheduling hook."""
    return None
|
||||
|
||||
|
||||
class MockCachedGeneratedTokens:
|
||||
def __init__(self):
|
||||
@@ -323,7 +324,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
||||
task_id = "test_aborted_request"
|
||||
task = processor.resource_manager.tasks_list[0]
|
||||
task.request_id = task_id
|
||||
processor.resource_manager.abort_req_ids_set = {task_id}
|
||||
processor.resource_manager.to_be_aborted_req_id_set = {task_id}
|
||||
|
||||
# Add the task to req_dict to prevent _recycle_aborted_task from processing it early
|
||||
# batch_id should be < batch_size - 1 to avoid early recycling
|
||||
@@ -339,8 +340,8 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
||||
# Set up output tokens with negative token
|
||||
# batch = 2 (so batch_id=0 is < batch_size-1=1)
|
||||
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2))
|
||||
# Set negative token for first task (batch_id=0)
|
||||
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
|
||||
# Set negative token (PREEMPTED_TOKEN_ID) for first task (batch_id=0)
|
||||
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-9))
|
||||
# Set positive token for second task (batch_id=1)
|
||||
processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100))
|
||||
|
||||
@@ -356,15 +357,15 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
||||
# Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
with (
|
||||
patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
|
||||
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
|
||||
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1),
|
||||
):
|
||||
# Call the method
|
||||
processor._process_batch_output()
|
||||
print(mock_logger)
|
||||
|
||||
# Verify the recycling logic was triggered
|
||||
mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
|
||||
processor._recycle_resources.assert_called_once_with(task_id, 0, task)
|
||||
self.assertNotIn(task_id, processor.resource_manager.abort_req_ids_set)
|
||||
processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id)
|
||||
self.assertNotIn(task_id, processor.resource_manager.to_be_aborted_req_id_set)
|
||||
|
||||
def test_process_batch_output_non_aborted_task_negative_token(self):
|
||||
"""Test non-aborted task receiving negative token does not trigger recycling"""
|
||||
|
||||
@@ -158,11 +158,14 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
|
||||
# Set up task as aborted
|
||||
task_id = "test_aborted_request"
|
||||
self.task_mock.request_id = task_id
|
||||
self.processor.resource_manager.abort_req_ids_set = {task_id}
|
||||
self.processor.resource_manager.to_be_aborted_req_id_set = {task_id}
|
||||
self.processor.resource_manager.recycle_abort_task = MagicMock(
|
||||
side_effect=lambda rid: self.processor.resource_manager.to_be_aborted_req_id_set.discard(rid)
|
||||
)
|
||||
|
||||
# Create stream data with negative token
|
||||
# Create stream data with negative token (PREEMPTED_TOKEN_ID = -9)
|
||||
stream_data = MagicMock()
|
||||
stream_data.tokens = np.array([1, 2, -1]) # Last token is negative
|
||||
stream_data.tokens = np.array([1, 2, -9]) # Last token is PREEMPTED_TOKEN_ID
|
||||
stream_data.batch_id = 0
|
||||
|
||||
# Mock _recycle_resources to track if it's called
|
||||
@@ -171,19 +174,16 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
|
||||
# Mock the llm_logger module and envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
with (
|
||||
patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
|
||||
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
|
||||
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1),
|
||||
):
|
||||
# Call the method
|
||||
result = self.processor._process_batch_output_use_zmq([stream_data])
|
||||
|
||||
# Verify the recycling logic was triggered
|
||||
mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
|
||||
self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock)
|
||||
self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set)
|
||||
self.assertEqual(len(result), 1) # Should return abort result
|
||||
self.assertEqual(result[0].finished, True)
|
||||
self.assertEqual(result[0].error_code, 499)
|
||||
self.assertIn("aborted", result[0].error_msg.lower())
|
||||
mock_logger.info.assert_any_call(f"start to recycle abort request_id {task_id}")
|
||||
self.processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id)
|
||||
self.assertNotIn(task_id, self.processor.resource_manager.to_be_aborted_req_id_set)
|
||||
self.assertEqual(len(result), 0) # Aborted task is skipped (continue)
|
||||
|
||||
def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self):
|
||||
"""Test non-aborted task receiving negative token does not trigger recycling"""
|
||||
|
||||
@@ -74,6 +74,7 @@ class _DummyResourceManager:
|
||||
self.req_dict = {}
|
||||
self.requests = {}
|
||||
self.to_be_rescheduled_request_id_set = set()
|
||||
self.to_be_aborted_req_id_set = set()
|
||||
self.abort_req_ids_set = set()
|
||||
self.recycled = []
|
||||
self.cached_tasks = []
|
||||
@@ -85,6 +86,10 @@ class _DummyResourceManager:
|
||||
def reschedule_preempt_task(self, request_id):
    """Record that *request_id* was rescheduled after preemption."""
    entry = f"reschedule-{request_id}"
    self.recycled.append(entry)
|
||||
|
||||
def recycle_abort_task(self, request_id):
    """Drop *request_id* from the pending-abort set and log the recycle."""
    marker = f"recycle-abort-{request_id}"
    self.to_be_aborted_req_id_set.discard(request_id)
    self.recycled.append(marker)
|
||||
|
||||
def finish_requests_async(self, request_id):
    """Record that *request_id* was finished asynchronously."""
    tag = f"finish-{request_id}"
    self.recycled.append(tag)
|
||||
|
||||
|
||||
@@ -195,6 +195,9 @@ def test_check_decode_allocated_handles_finished_and_error_states():
|
||||
def test_send_cache_info_to_prefill_groups_by_addr_and_skips_error():
|
||||
connector = _build_connector()
|
||||
connector._send_message = Mock()
|
||||
# Add mock resource_manager with waiting_abort_req_id_set
|
||||
connector.resource_manager = Mock()
|
||||
connector.resource_manager.waiting_abort_req_id_set = set()
|
||||
|
||||
tasks = [
|
||||
DummyTask(
|
||||
|
||||
Reference in New Issue
Block a user