test_abort (#6743)

Author: qwes5s5
Date: 2026-03-17 14:06:40 +08:00
Committed by: GitHub
Parent: eab429d05e
Commit: 3b7507a4c2
11 changed files with 132 additions and 82 deletions
+1 -14
@@ -786,7 +786,6 @@ class EngineService:
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
batch=num_prefill_batch,
)
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
@@ -851,7 +850,6 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
@@ -1178,19 +1176,8 @@ class EngineService:
if status_value is not None and status_value == RequestStatus.ABORT.value:
req_id = data["request_id"]
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
self.resource_manager.abort_req_ids_set.add(req_id)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if req_id in self.resource_manager.requests:
req = self.resource_manager.requests[req_id]
task = self.resource_manager._prepare_preempt_task(req)
self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz))
self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}")
else:
self.scheduler._recycle(req_id)
self.llm_logger.info(
f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler"
)
self.resource_manager.abort_req_ids_set.remove(req_id)
self.resource_manager.add_abort_req_ids(req_id)
continue
err_msg = None
try:
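Taken together, the engine-side hunks replace the old inline abort path (add the id to `abort_req_ids_set`, preempt or recycle immediately, then remove it) with a single deferred call into the resource manager. A minimal, hypothetical sketch of the `add_abort_req_ids` semantics this relies on, mirroring the resource-manager hunk further down (it accepts a single id or a list, under the manager's lock); `_RM` here is illustrative, not the real `ResourceManagerV1`:

```python
import threading

class _RM:
    """Toy stand-in for ResourceManagerV1's abort staging."""

    def __init__(self):
        self.lock = threading.Lock()
        self.waiting_abort_req_id_set = set()

    def add_abort_req_ids(self, req_ids):
        # Mirrors the diff: list -> update, single id -> add, both under the lock.
        with self.lock:
            if isinstance(req_ids, list):
                self.waiting_abort_req_id_set.update(req_ids)
            else:
                self.waiting_abort_req_id_set.add(req_ids)

rm = _RM()
rm.add_abort_req_ids("req-1")             # single id, as EngineService does above
rm.add_abort_req_ids(["req-2", "req-3"])  # batched form is also supported
assert rm.waiting_abort_req_id_set == {"req-1", "req-2", "req-3"}
```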
+1
@@ -61,6 +61,7 @@ class RequestType(Enum):
DECODE = 1
PREEMPTED = 2
EXTEND = 3
ABORT = 4
@dataclass
@@ -87,6 +87,15 @@ class ScheduledExtendBlocksTask:
task_type: RequestType = RequestType.EXTEND
@dataclass
class ScheduledAbortTask:
"""Task for allocating new blocks to skip."""
idx: int
request_id: str
task_type: RequestType = RequestType.ABORT
class SignalConsumer:
"""
A class that consumes a signal value up to a specified limit.
@@ -180,6 +189,8 @@ class ResourceManagerV1(ResourceManager):
self.using_extend_tables_req_id = set()
self.reuse_block_num_map = dict()
self.abort_req_ids_set = set()
self.waiting_abort_req_id_set = set()
self.to_be_aborted_req_id_set = set()
# need block nums
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
@@ -246,6 +257,9 @@ class ResourceManagerV1(ResourceManager):
def _prepare_preempt_task(self, request):
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
def _prepare_abort_task(self, request):
return ScheduledAbortTask(idx=request.idx, request_id=request.request_id)
def reschedule_preempt_task(self, request_id, process_func=None):
with self.lock:
llm_logger.debug(f"reschedule {request_id} into waiting queue")
@@ -259,6 +273,27 @@ class ResourceManagerV1(ResourceManager):
self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id)
def recycle_abort_task(self, request_id):
with self.lock:
if request_id in self.to_be_aborted_req_id_set and request_id in self.requests:
request = self.requests[request_id]
self.tasks_list[request.idx] = None # clear the slot
self.stop_flags[request.idx] = True # set the stop flag
del self.requests[request_id]
del self.req_dict[request_id]
self.to_be_aborted_req_id_set.remove(request_id)
def _trigger_abort(self, request_id, scheduled_reqs):
if request_id in self.requests:
abort_request = self.requests[request_id]
abort_request.status = RequestStatus.PREEMPTED
abort_request.num_computed_tokens = 0
self._free_blocks(abort_request) # free the KV cache blocks
abort_request.cached_block_num = 0
scheduled_reqs.append(self._prepare_abort_task(abort_request))
self.to_be_aborted_req_id_set.add(request_id)
self.waiting_abort_req_id_set.remove(request_id)
def _info_each_block(self):
"""
print each req block
@@ -693,6 +728,13 @@ class ResourceManagerV1(ResourceManager):
return True
return False
def add_abort_req_ids(self, req_ids):
with self.lock:
if isinstance(req_ids, list):
self.waiting_abort_req_id_set.update(req_ids)
else:
self.waiting_abort_req_id_set.add(req_ids)
def cache_output_tokens(self, request):
if (
self.config.cache_config.enable_prefix_caching
@@ -720,6 +762,7 @@ class ResourceManagerV1(ResourceManager):
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
need_abort_requests = [] # user-triggered aborts
# First, schedule the RUNNING requests.
req_index = 0
@@ -739,6 +782,13 @@ class ResourceManagerV1(ResourceManager):
continue
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
request.num_computed_tokens = request.num_total_tokens - 1
if request.request_id in self.waiting_abort_req_id_set:
self._trigger_abort(request.request_id, scheduled_reqs)
req_index += 1
need_abort_requests.append(request)
continue
if (
self.allocated_slots(request) - request.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -876,6 +926,10 @@ class ResourceManagerV1(ResourceManager):
)
req_index += 1
# remove requests to be aborted from running list
for request in need_abort_requests:
self.running.remove(request)
# Second, schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []
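The resource-manager hunks implement a two-set handshake: `add_abort_req_ids` stages an id in `waiting_abort_req_id_set`; the scheduling loop calls `_trigger_abort`, which frees the request's KV blocks, emits a `ScheduledAbortTask`, and moves the id into `to_be_aborted_req_id_set`; finally the token processor calls `recycle_abort_task` once the worker confirms. A condensed, self-contained model of that state machine (set names mirror the diff; block freeing and slot bookkeeping are elided):

```python
import threading

class AbortHandshake:
    """Condensed model of the abort lifecycle in ResourceManagerV1."""

    def __init__(self):
        self.lock = threading.Lock()
        self.waiting_abort_req_id_set = set()  # stage 1: abort requested by user
        self.to_be_aborted_req_id_set = set()  # stage 2: abort task sent to worker

    def add_abort_req_ids(self, req_id):
        with self.lock:
            self.waiting_abort_req_id_set.add(req_id)

    def trigger_abort(self, req_id, scheduled_reqs):
        # Scheduling loop: free resources and hand an abort task to the worker.
        scheduled_reqs.append(("ABORT", req_id))
        self.to_be_aborted_req_id_set.add(req_id)
        self.waiting_abort_req_id_set.remove(req_id)

    def recycle_abort_task(self, req_id):
        # Token processor: worker echoed the sentinel token, drop bookkeeping.
        with self.lock:
            self.to_be_aborted_req_id_set.discard(req_id)

hs = AbortHandshake()
scheduled = []
hs.add_abort_req_ids("req-1")         # stage 1
hs.trigger_abort("req-1", scheduled)  # stage 2
hs.recycle_abort_task("req-1")        # stage 3
assert not hs.waiting_abort_req_id_set and not hs.to_be_aborted_req_id_set
```

The two sets keep the user-facing abort decoupled from the worker round-trip: nothing is freed twice, and `recycle_abort_task` is guarded so it only acts on ids that actually reached stage 2.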
+28 -45
@@ -284,23 +284,13 @@ class TokenProcessor:
task_id = task.request_id
token_ids = stream_data.tokens # numpy.array
if token_ids is not None and token_ids[-1] < 0:
if task_id in self.resource_manager.abort_req_ids_set:
if (
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if (
task_id in self.resource_manager.to_be_aborted_req_id_set
and token_ids[-1] == PREEMPTED_TOKEN_ID
):
llm_logger.info(f"start to recycle abort request_id {task_id}")
self.resource_manager.recycle_abort_task(task_id)
if (
task_id in self.resource_manager.to_be_rescheduled_request_id_set
and token_ids[-1] == PREEMPTED_TOKEN_ID
@@ -308,6 +298,13 @@ class TokenProcessor:
llm_logger.info(f"sync preemption for request_id {task_id} done.")
self.resource_manager.reschedule_preempt_task(task_id)
continue
if self.cfg.scheduler_config.splitwise_role == "decode":
# On a D (decode) instance, a preempted request has already had its error reported and resources recycled, so tokens it generated asynchronously need no handling
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_aborted_req_id_set:
continue
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
continue
current_time = time.time()
if self.tokens_counter[task_id] == 0:
@@ -760,19 +757,8 @@ class TokenProcessor:
if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, means preemption has happened in worker
llm_logger.info(f"sync preemption for request_id {task_id} done.")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.abort_req_ids_set:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if task_id in self.resource_manager.to_be_aborted_req_id_set:
self.resource_manager.recycle_abort_task(task_id)
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id)
continue
@@ -801,23 +787,13 @@ class TokenProcessor:
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0:
if task_id in self.resource_manager.abort_req_ids_set:
if (
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if (
task_id in self.resource_manager.to_be_aborted_req_id_set
and token_id == PREEMPTED_TOKEN_ID
):
self.resource_manager.recycle_abort_task(task_id)
llm_logger.info(f"sync abortion for request_id {task_id} done.")
if (
task_id in self.resource_manager.to_be_rescheduled_request_id_set
and token_id == PREEMPTED_TOKEN_ID
@@ -825,6 +801,13 @@ class TokenProcessor:
llm_logger.info(f"sync preemption for request_id {task_id} done.")
self.resource_manager.reschedule_preempt_task(task_id)
continue
if self.cfg.scheduler_config.splitwise_role == "decode":
# On a D (decode) instance, a preempted request has already had its error reported and resources recycled, so tokens it generated asynchronously need no handling
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
continue
if task_id in self.resource_manager.to_be_aborted_req_id_set:
continue
if self.scheduler_metrics_logger and self._is_decode_stage(task):
self.scheduler_metrics_logger.on_decode_tokens(len(token_ids))
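On the output side, the token processor no longer synthesizes a 499 `RequestOutput` for aborted tasks. When the worker echoes `PREEMPTED_TOKEN_ID` for an id in `to_be_aborted_req_id_set`, the processor calls `recycle_abort_task` and skips the entry, mirroring how preemption is resynchronized. A sketch of that dispatch; the sentinel value -9 is assumed from the updated tests below, and `classify_last_token` is an illustrative helper, not a real method:

```python
from types import SimpleNamespace

PREEMPTED_TOKEN_ID = -9  # assumed; the updated tests patch tokens to this value

def classify_last_token(task_id, last_token, rm):
    """Return None for abort/preempt sentinels (entry skipped), else the token."""
    if last_token == PREEMPTED_TOKEN_ID:
        if task_id in rm.to_be_aborted_req_id_set:
            rm.recycle_abort_task(task_id)       # abort confirmed by worker
            return None                          # no 499 RequestOutput anymore
        if task_id in rm.to_be_rescheduled_request_id_set:
            rm.reschedule_preempt_task(task_id)  # preemption resynchronized
            return None
    return last_token

rm = SimpleNamespace(
    to_be_aborted_req_id_set={"req-1"},
    to_be_rescheduled_request_id_set=set(),
    recycle_abort_task=lambda rid: None,
    reschedule_preempt_task=lambda rid: None,
)
assert classify_last_token("req-1", PREEMPTED_TOKEN_ID, rm) is None
assert classify_last_token("req-2", 42, rm) == 42
```

Callers that previously asserted a 499 result now assert an empty batch, as the test changes further down confirm.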
@@ -277,6 +277,14 @@ class SplitwiseConnector:
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
if (
envs.ENABLE_V1_KVCACHE_SCHEDULER
and tasks[i].request_id in self.resource_manager.waiting_abort_req_id_set
):
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
if addr not in cache_info:
cache_info[addr] = []
cache_info[addr].append(info)
else:
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
info = {
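For disaggregated P/D serving, the connector hunk adds one more branch: when the V1 scheduler is enabled and the finished request is still in `waiting_abort_req_id_set`, its error info is grouped by prefill address and sent back so the prefill side can release its cache as well. The grouping pattern in isolation, with sample values where the hunk's `dsg_info` and `info` would appear:

```python
# Illustrative grouping of per-request info by prefill endpoint.
cache_info: dict[str, list[dict]] = {}
dsg_info = {"prefill_ip": "10.0.0.1", "prefill_connector_port": 8021}  # sample values
info = {"request_id": "req-1", "error_msg": "aborted"}                 # sample payload

addr = f"{dsg_info['prefill_ip']}:{dsg_info['prefill_connector_port']}"
cache_info.setdefault(addr, []).append(info)  # equivalent to the if-not-in check above
assert cache_info == {"10.0.0.1:8021": [info]}
```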
+4 -1
@@ -891,7 +891,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0
continue
else: # preempted task
logger.info(f"Handle preempted request {request} at idx {idx}")
if request.task_type.value == RequestType.PREEMPTED.value:
logger.info(f"Handle preempted request {request} at idx {idx}")
elif request.task_type.value == RequestType.ABORT.value:
logger.info(f"Handle abort request {request} at idx {idx}")
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["stop_flags"][idx : idx + 1] = True
+6 -1
@@ -1628,9 +1628,13 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
class DummyRM:
def __init__(self):
self.abort_req_ids_set = set()
self.waiting_abort_req_id_set = set()
self.real_bsz = 1
self.requests = {"rid": Mock()}
def add_abort_req_ids(self, req_id):
self.waiting_abort_req_id_set.add(req_id)
def _prepare_preempt_task(self, req):
return Request(request_id="rid", prompt_token_ids=[1], prompt_token_ids_len=1)
@@ -1649,7 +1653,8 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
):
eng._insert_zmq_task_to_scheduler()
eng.engine_worker_queue.put_tasks.assert_called_once()
# Verify abort request was handled correctly - added to waiting_abort_req_id_set
self.assertIn("rid", eng.resource_manager.waiting_abort_req_id_set)
self._detach_finalizer(eng)
def test_insert_zmq_task_to_scheduler_paused_sends_error(self):
+11 -10
@@ -53,6 +53,7 @@ class MockConfig:
class SchedulerConfig:
name = "default"
splitwise_role = "decode"
class CacheConfig:
enable_prefix_caching = False
@@ -104,15 +105,15 @@ class MockResourceManager:
self.stop_flags = [False]
self.tasks_list = [MockTask()]
self.to_be_rescheduled_request_id_set = set()
self.to_be_aborted_req_id_set = set()
self.abort_req_ids_set = set()
self.req_dict = {}
self.recycle_abort_task = MagicMock(side_effect=lambda rid: self.to_be_aborted_req_id_set.discard(rid))
self.reschedule_preempt_task = MagicMock()
def info(self):
return "Mock resource manager info"
def reschedule_preempt_task(self, task_id):
pass
class MockCachedGeneratedTokens:
def __init__(self):
@@ -323,7 +324,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
task_id = "test_aborted_request"
task = processor.resource_manager.tasks_list[0]
task.request_id = task_id
processor.resource_manager.abort_req_ids_set = {task_id}
processor.resource_manager.to_be_aborted_req_id_set = {task_id}
# Add the task to req_dict to prevent _recycle_aborted_task from processing it early
# batch_id should be < batch_size - 1 to avoid early recycling
@@ -339,8 +340,8 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
# Set up output tokens with negative token
# batch = 2 (so batch_id=0 is < batch_size-1=1)
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2))
# Set negative token for first task (batch_id=0)
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
# Set negative token (PREEMPTED_TOKEN_ID) for first task (batch_id=0)
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-9))
# Set positive token for second task (batch_id=1)
processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100))
@@ -356,15 +357,15 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
# Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER
with (
patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1),
):
# Call the method
processor._process_batch_output()
print(mock_logger)
# Verify the recycling logic was triggered
mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
processor._recycle_resources.assert_called_once_with(task_id, 0, task)
self.assertNotIn(task_id, processor.resource_manager.abort_req_ids_set)
processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id)
self.assertNotIn(task_id, processor.resource_manager.to_be_aborted_req_id_set)
def test_process_batch_output_non_aborted_task_negative_token(self):
"""Test non-aborted task receiving negative token does not trigger recycling"""
@@ -158,11 +158,14 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
# Set up task as aborted
task_id = "test_aborted_request"
self.task_mock.request_id = task_id
self.processor.resource_manager.abort_req_ids_set = {task_id}
self.processor.resource_manager.to_be_aborted_req_id_set = {task_id}
self.processor.resource_manager.recycle_abort_task = MagicMock(
side_effect=lambda rid: self.processor.resource_manager.to_be_aborted_req_id_set.discard(rid)
)
# Create stream data with negative token
# Create stream data with negative token (PREEMPTED_TOKEN_ID = -9)
stream_data = MagicMock()
stream_data.tokens = np.array([1, 2, -1]) # Last token is negative
stream_data.tokens = np.array([1, 2, -9]) # Last token is PREEMPTED_TOKEN_ID
stream_data.batch_id = 0
# Mock _recycle_resources to track if it's called
@@ -171,19 +174,16 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
# Mock the llm_logger module and envs.ENABLE_V1_KVCACHE_SCHEDULER
with (
patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1),
):
# Call the method
result = self.processor._process_batch_output_use_zmq([stream_data])
# Verify the recycling logic was triggered
mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock)
self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set)
self.assertEqual(len(result), 1) # Should return abort result
self.assertEqual(result[0].finished, True)
self.assertEqual(result[0].error_code, 499)
self.assertIn("aborted", result[0].error_msg.lower())
mock_logger.info.assert_any_call(f"start to recycle abort request_id {task_id}")
self.processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id)
self.assertNotIn(task_id, self.processor.resource_manager.to_be_aborted_req_id_set)
self.assertEqual(len(result), 0) # Aborted task is skipped (continue)
def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self):
"""Test non-aborted task receiving negative token does not trigger recycling"""
+5
@@ -74,6 +74,7 @@ class _DummyResourceManager:
self.req_dict = {}
self.requests = {}
self.to_be_rescheduled_request_id_set = set()
self.to_be_aborted_req_id_set = set()
self.abort_req_ids_set = set()
self.recycled = []
self.cached_tasks = []
@@ -85,6 +86,10 @@ class _DummyResourceManager:
def reschedule_preempt_task(self, request_id):
self.recycled.append(f"reschedule-{request_id}")
def recycle_abort_task(self, request_id):
self.recycled.append(f"recycle-abort-{request_id}")
self.to_be_aborted_req_id_set.discard(request_id)
def finish_requests_async(self, request_id):
self.recycled.append(f"finish-{request_id}")
@@ -195,6 +195,9 @@ def test_check_decode_allocated_handles_finished_and_error_states():
def test_send_cache_info_to_prefill_groups_by_addr_and_skips_error():
connector = _build_connector()
connector._send_message = Mock()
# Add mock resource_manager with waiting_abort_req_id_set
connector.resource_manager = Mock()
connector.resource_manager.waiting_abort_req_id_set = set()
tasks = [
DummyTask(