diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 0ad2f79643..e985841374 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -786,7 +786,6 @@ class EngineService: max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens, batch=num_prefill_batch, ) - tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set] for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) @@ -851,7 +850,6 @@ class EngineService: max_num_batched_tokens=max_num_batched_tokens, batch=num_prefill_batch, ) - tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set] for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) @@ -1178,19 +1176,8 @@ class EngineService: if status_value is not None and status_value == RequestStatus.ABORT.value: req_id = data["request_id"] self.llm_logger.info(f"Receive abort request, req_id: {req_id}") - self.resource_manager.abort_req_ids_set.add(req_id) if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if req_id in self.resource_manager.requests: - req = self.resource_manager.requests[req_id] - task = self.resource_manager._prepare_preempt_task(req) - self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz)) - self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}") - else: - self.scheduler._recycle(req_id) - self.llm_logger.info( - f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler" - ) - self.resource_manager.abort_req_ids_set.remove(req_id) + self.resource_manager.add_abort_req_ids(req_id) continue err_msg = None try: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 1e2a53ed20..1ecb39a4a4 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -61,6 +61,7 @@ class RequestType(Enum): DECODE = 1 PREEMPTED = 2 EXTEND = 3 + ABORT = 4 @dataclass diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 9960d4afaf..b0425d779d 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -87,6 +87,15 @@ class ScheduledExtendBlocksTask: task_type: RequestType = RequestType.EXTEND +@dataclass +class ScheduledAbortTask: + """Task for allocating new blocks to skip.""" + + idx: int + request_id: str + task_type: RequestType = RequestType.ABORT + + class SignalConsumer: """ A class that consumes a signal value up to a specified limit. @@ -180,6 +189,8 @@ class ResourceManagerV1(ResourceManager): self.using_extend_tables_req_id = set() self.reuse_block_num_map = dict() self.abort_req_ids_set = set() + self.waiting_abort_req_id_set = set() + self.to_be_aborted_req_id_set = set() # need block nums need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32) @@ -246,6 +257,9 @@ class ResourceManagerV1(ResourceManager): def _prepare_preempt_task(self, request): return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id) + def _prepare_abort_task(self, request): + return ScheduledAbortTask(idx=request.idx, request_id=request.request_id) + def reschedule_preempt_task(self, request_id, process_func=None): with self.lock: llm_logger.debug(f"reschedule {request_id} into waiting queue") @@ -259,6 +273,27 @@ class ResourceManagerV1(ResourceManager): self.waiting.appendleft(request) self.to_be_rescheduled_request_id_set.remove(request_id) + def recycle_abort_task(self, request_id): + with self.lock: + if request_id in self.to_be_aborted_req_id_set and request_id in self.requests: + request = self.requests[request_id] + self.tasks_list[request.idx] = None # 清空slot + self.stop_flags[request.idx] = True # 设置停止标志 + del self.requests[request_id] + del self.req_dict[request_id] + self.to_be_aborted_req_id_set.remove(request_id) + + def _trigger_abort(self, request_id, scheduled_reqs): + if request_id in self.requests: + abort_request = self.requests[request_id] + abort_request.status = RequestStatus.PREEMPTED + abort_request.num_computed_tokens = 0 + self._free_blocks(abort_request) # 释放KV cache blocks + abort_request.cached_block_num = 0 + scheduled_reqs.append(self._prepare_abort_task(abort_request)) + self.to_be_aborted_req_id_set.add(request_id) + self.waiting_abort_req_id_set.remove(request_id) + def _info_each_block(self): """ print each req block @@ -693,6 +728,13 @@ class ResourceManagerV1(ResourceManager): return True return False + def add_abort_req_ids(self, req_ids): + with self.lock: + if isinstance(req_ids, list): + self.waiting_abort_req_id_set.update(req_ids) + else: + self.waiting_abort_req_id_set.add(req_ids) + def cache_output_tokens(self, request): if ( self.config.cache_config.enable_prefix_caching @@ -720,6 +762,7 @@ class ResourceManagerV1(ResourceManager): preempted_reqs: list[Request] = [] error_reqs: list[tuple[str, str]] = [] token_budget = self.config.scheduler_config.max_num_batched_tokens + need_abort_requests = [] # users trigger abortion # First, schedule the RUNNING requests. req_index = 0 @@ -739,6 +782,13 @@ class ResourceManagerV1(ResourceManager): continue if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens request.num_computed_tokens = request.num_total_tokens - 1 + + if request.request_id in self.waiting_abort_req_id_set: + self._trigger_abort(request.request_id, scheduled_reqs) + req_index += 1 + need_abort_requests.append(request) + continue + if ( self.allocated_slots(request) - request.num_total_tokens <= self.config.cache_config.prealloc_dec_block_slot_num_threshold @@ -876,6 +926,10 @@ class ResourceManagerV1(ResourceManager): ) req_index += 1 + # remove requests to be aborted from running list + for request in need_abort_requests: + self.running.remove(request) + # Second, schedule the WAITING requests. if not preempted_reqs: skip_requests: list[Request] = [] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 662b1cfc8a..1ab0b48f35 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -284,23 +284,13 @@ class TokenProcessor: task_id = task.request_id token_ids = stream_data.tokens # numpy.array if token_ids is not None and token_ids[-1] < 0: - if task_id in self.resource_manager.abort_req_ids_set: - if ( - envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID - ) or not envs.ENABLE_V1_KVCACHE_SCHEDULER: - llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.") - self.resource_manager.abort_req_ids_set.remove(task_id) - self._recycle_resources(task_id, i, task) - llm_logger.info(f"{task_id} received negative token. Recycle end.") - abort_res = RequestOutput( - request_id=task_id, - finished=True, - error_code=499, - error_msg=f"Your request with request_id:{task_id} is aborted.", - ) - batch_result.append(abort_res) - continue if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if ( + task_id in self.resource_manager.to_be_aborted_req_id_set + and token_ids[-1] == PREEMPTED_TOKEN_ID + ): + llm_logger.info(f"start to recycle abort request_id {task_id}") + self.resource_manager.recycle_abort_task(task_id) if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set and token_ids[-1] == PREEMPTED_TOKEN_ID @@ -308,6 +298,13 @@ class TokenProcessor: llm_logger.info(f"sync preemption for request_id {task_id} done.") self.resource_manager.reschedule_preempt_task(task_id) continue + if self.cfg.scheduler_config.splitwise_role == "decode": + # In D instance, if preempted, error has been reported and resource recycled, tokens generated async not need to be handled + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.to_be_aborted_req_id_set: + continue + if task_id in self.resource_manager.to_be_rescheduled_request_id_set: + continue current_time = time.time() if self.tokens_counter[task_id] == 0: @@ -760,19 +757,8 @@ class TokenProcessor: if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, means preemption has happened in worker llm_logger.info(f"sync preemption for request_id {task_id} done.") if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.abort_req_ids_set: - llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.") - self.resource_manager.abort_req_ids_set.remove(task_id) - self._recycle_resources(task_id, i, task) - llm_logger.info(f"{task_id} received negative token. Recycle end.") - abort_res = RequestOutput( - request_id=task_id, - finished=True, - error_code=499, - error_msg=f"Your request with request_id:{task_id} is aborted.", - ) - batch_result.append(abort_res) - continue + if task_id in self.resource_manager.to_be_aborted_req_id_set: + self.resource_manager.recycle_abort_task(task_id) if task_id in self.resource_manager.to_be_rescheduled_request_id_set: self.resource_manager.reschedule_preempt_task(task_id) continue @@ -801,23 +787,13 @@ class TokenProcessor: if recovery_stop: llm_logger.info(f"recovery stop signal found at task {task_id}") if not recovery_stop and token_id < 0: - if task_id in self.resource_manager.abort_req_ids_set: - if ( - envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID - ) or not envs.ENABLE_V1_KVCACHE_SCHEDULER: - llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.") - self.resource_manager.abort_req_ids_set.remove(task_id) - self._recycle_resources(task_id, i, task) - llm_logger.info(f"{task_id} received negative token. Recycle end.") - abort_res = RequestOutput( - request_id=task_id, - finished=True, - error_code=499, - error_msg=f"Your request with request_id:{task_id} is aborted.", - ) - batch_result.append(abort_res) - continue if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if ( + task_id in self.resource_manager.to_be_aborted_req_id_set + and token_id == PREEMPTED_TOKEN_ID + ): + self.resource_manager.recycle_abort_task(task_id) + llm_logger.info(f"sync abortion for request_id {task_id} done.") if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set and token_id == PREEMPTED_TOKEN_ID @@ -825,6 +801,13 @@ class TokenProcessor: llm_logger.info(f"sync preemption for request_id {task_id} done.") self.resource_manager.reschedule_preempt_task(task_id) continue + if self.cfg.scheduler_config.splitwise_role == "decode": + # In D instance, if preempted, error has been reported and resource recycled, tokens generated async not need to be handled + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.to_be_rescheduled_request_id_set: + continue + if task_id in self.resource_manager.to_be_aborted_req_id_set: + continue if self.scheduler_metrics_logger and self._is_decode_stage(task): self.scheduler_metrics_logger.on_decode_tokens(len(token_ids)) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 3854970a96..7200c99ed9 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -277,6 +277,14 @@ class SplitwiseConnector: "request_id": tasks[i].request_id, "error_msg": tasks[i].get("error_msg"), } + if ( + envs.ENABLE_V1_KVCACHE_SCHEDULER + and tasks[i].request_id in self.resource_manager.waiting_abort_req_id_set + ): + addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" + if addr not in cache_info: + cache_info[addr] = [] + cache_info[addr].append(info) else: addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" info = { diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 17c58ce34a..24f72f5e26 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -891,7 +891,10 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0 continue else: # preempted task - logger.info(f"Handle preempted request {request} at idx {idx}") + if request.task_type.value == RequestType.PREEMPTED.value: + logger.info(f"Handle preempted request {request} at idx {idx}") + elif request.task_type.value == RequestType.ABORT.value: + logger.info(f"Handle abort request {request} at idx {idx}") self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 self.share_inputs["block_tables"][idx : idx + 1, :] = -1 self.share_inputs["stop_flags"][idx : idx + 1] = True diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 5638c8a569..996e148112 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1628,9 +1628,13 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): class DummyRM: def __init__(self): self.abort_req_ids_set = set() + self.waiting_abort_req_id_set = set() self.real_bsz = 1 self.requests = {"rid": Mock()} + def add_abort_req_ids(self, req_id): + self.waiting_abort_req_id_set.add(req_id) + def _prepare_preempt_task(self, req): return Request(request_id="rid", prompt_token_ids=[1], prompt_token_ids_len=1) @@ -1649,7 +1653,8 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): ): eng._insert_zmq_task_to_scheduler() - eng.engine_worker_queue.put_tasks.assert_called_once() + # Verify abort request was handled correctly - added to waiting_abort_req_id_set + self.assertIn("rid", eng.resource_manager.waiting_abort_req_id_set) self._detach_finalizer(eng) def test_insert_zmq_task_to_scheduler_paused_sends_error(self): diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 6859259387..46282cd386 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -53,6 +53,7 @@ class MockConfig: class SchedulerConfig: name = "default" + splitwise_role = "decode" class CacheConfig: enable_prefix_caching = False @@ -104,15 +105,15 @@ class MockResourceManager: self.stop_flags = [False] self.tasks_list = [MockTask()] self.to_be_rescheduled_request_id_set = set() + self.to_be_aborted_req_id_set = set() self.abort_req_ids_set = set() self.req_dict = {} + self.recycle_abort_task = MagicMock(side_effect=lambda rid: self.to_be_aborted_req_id_set.discard(rid)) + self.reschedule_preempt_task = MagicMock() def info(self): return "Mock resource manager info" - def reschedule_preempt_task(self, task_id): - pass - class MockCachedGeneratedTokens: def __init__(self): @@ -323,7 +324,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): task_id = "test_aborted_request" task = processor.resource_manager.tasks_list[0] task.request_id = task_id - processor.resource_manager.abort_req_ids_set = {task_id} + processor.resource_manager.to_be_aborted_req_id_set = {task_id} # Add the task to req_dict to prevent _recycle_aborted_task from processing it early # batch_id should be < batch_size - 1 to avoid early recycling @@ -339,8 +340,8 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): # Set up output tokens with negative token # batch = 2 (so batch_id=0 is < batch_size-1=1) processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2)) - # Set negative token for first task (batch_id=0) - processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1)) + # Set negative token (PREEMPTED_TOKEN_ID) for first task (batch_id=0) + processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-9)) # Set positive token for second task (batch_id=1) processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100)) @@ -356,15 +357,15 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER with ( patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, - patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1), ): # Call the method processor._process_batch_output() + print(mock_logger) # Verify the recycling logic was triggered - mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.") - processor._recycle_resources.assert_called_once_with(task_id, 0, task) - self.assertNotIn(task_id, processor.resource_manager.abort_req_ids_set) + processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id) + self.assertNotIn(task_id, processor.resource_manager.to_be_aborted_req_id_set) def test_process_batch_output_non_aborted_task_negative_token(self): """Test non-aborted task receiving negative token does not trigger recycling""" diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index eb9ea9fc11..07826e6f0e 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -158,11 +158,14 @@ class TestTokenProcessorLogprobs(unittest.TestCase): # Set up task as aborted task_id = "test_aborted_request" self.task_mock.request_id = task_id - self.processor.resource_manager.abort_req_ids_set = {task_id} + self.processor.resource_manager.to_be_aborted_req_id_set = {task_id} + self.processor.resource_manager.recycle_abort_task = MagicMock( + side_effect=lambda rid: self.processor.resource_manager.to_be_aborted_req_id_set.discard(rid) + ) - # Create stream data with negative token + # Create stream data with negative token (PREEMPTED_TOKEN_ID = -9) stream_data = MagicMock() - stream_data.tokens = np.array([1, 2, -1]) # Last token is negative + stream_data.tokens = np.array([1, 2, -9]) # Last token is PREEMPTED_TOKEN_ID stream_data.batch_id = 0 # Mock _recycle_resources to track if it's called @@ -171,19 +174,16 @@ class TestTokenProcessorLogprobs(unittest.TestCase): # Mock the llm_logger module and envs.ENABLE_V1_KVCACHE_SCHEDULER with ( patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, - patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1), ): # Call the method result = self.processor._process_batch_output_use_zmq([stream_data]) # Verify the recycling logic was triggered - mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.") - self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock) - self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set) - self.assertEqual(len(result), 1) # Should return abort result - self.assertEqual(result[0].finished, True) - self.assertEqual(result[0].error_code, 499) - self.assertIn("aborted", result[0].error_msg.lower()) + mock_logger.info.assert_any_call(f"start to recycle abort request_id {task_id}") + self.processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id) + self.assertNotIn(task_id, self.processor.resource_manager.to_be_aborted_req_id_set) + self.assertEqual(len(result), 0) # Aborted task is skipped (continue) def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self): """Test non-aborted task receiving negative token does not trigger recycling""" diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index e8ff821a26..c0609094a2 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -74,6 +74,7 @@ class _DummyResourceManager: self.req_dict = {} self.requests = {} self.to_be_rescheduled_request_id_set = set() + self.to_be_aborted_req_id_set = set() self.abort_req_ids_set = set() self.recycled = [] self.cached_tasks = [] @@ -85,6 +86,10 @@ class _DummyResourceManager: def reschedule_preempt_task(self, request_id): self.recycled.append(f"reschedule-{request_id}") + def recycle_abort_task(self, request_id): + self.recycled.append(f"recycle-abort-{request_id}") + self.to_be_aborted_req_id_set.discard(request_id) + def finish_requests_async(self, request_id): self.recycled.append(f"finish-{request_id}") diff --git a/tests/splitwise/test_splitwise_connector.py b/tests/splitwise/test_splitwise_connector.py index 712bebdeb1..610cfd9246 100644 --- a/tests/splitwise/test_splitwise_connector.py +++ b/tests/splitwise/test_splitwise_connector.py @@ -195,6 +195,9 @@ def test_check_decode_allocated_handles_finished_and_error_states(): def test_send_cache_info_to_prefill_groups_by_addr_and_skips_error(): connector = _build_connector() connector._send_message = Mock() + # Add mock resource_manager with waiting_abort_req_id_set + connector.resource_manager = Mock() + connector.resource_manager.waiting_abort_req_id_set = set() tasks = [ DummyTask(