test_abort (#6743)

This commit is contained in:
qwes5s5
2026-03-17 14:06:40 +08:00
committed by GitHub
parent eab429d05e
commit 3b7507a4c2
11 changed files with 132 additions and 82 deletions
@@ -87,6 +87,15 @@ class ScheduledExtendBlocksTask:
task_type: RequestType = RequestType.EXTEND
@dataclass
class ScheduledAbortTask:
"""Task for allocating new blocks to skip."""
idx: int
request_id: str
task_type: RequestType = RequestType.ABORT
class SignalConsumer:
"""
A class that consumes a signal value up to a specified limit.
@@ -180,6 +189,8 @@ class ResourceManagerV1(ResourceManager):
self.using_extend_tables_req_id = set()
self.reuse_block_num_map = dict()
self.abort_req_ids_set = set()
self.waiting_abort_req_id_set = set()
self.to_be_aborted_req_id_set = set()
# need block nums
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
@@ -246,6 +257,9 @@ class ResourceManagerV1(ResourceManager):
def _prepare_preempt_task(self, request):
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
def _prepare_abort_task(self, request):
return ScheduledAbortTask(idx=request.idx, request_id=request.request_id)
def reschedule_preempt_task(self, request_id, process_func=None):
with self.lock:
llm_logger.debug(f"reschedule {request_id} into waiting queue")
@@ -259,6 +273,27 @@ class ResourceManagerV1(ResourceManager):
self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id)
def recycle_abort_task(self, request_id):
with self.lock:
if request_id in self.to_be_aborted_req_id_set and request_id in self.requests:
request = self.requests[request_id]
self.tasks_list[request.idx] = None # 清空slot
self.stop_flags[request.idx] = True # 设置停止标志
del self.requests[request_id]
del self.req_dict[request_id]
self.to_be_aborted_req_id_set.remove(request_id)
def _trigger_abort(self, request_id, scheduled_reqs):
if request_id in self.requests:
abort_request = self.requests[request_id]
abort_request.status = RequestStatus.PREEMPTED
abort_request.num_computed_tokens = 0
self._free_blocks(abort_request) # 释放KV cache blocks
abort_request.cached_block_num = 0
scheduled_reqs.append(self._prepare_abort_task(abort_request))
self.to_be_aborted_req_id_set.add(request_id)
self.waiting_abort_req_id_set.remove(request_id)
def _info_each_block(self):
"""
print each req block
@@ -693,6 +728,13 @@ class ResourceManagerV1(ResourceManager):
return True
return False
def add_abort_req_ids(self, req_ids):
with self.lock:
if isinstance(req_ids, list):
self.waiting_abort_req_id_set.update(req_ids)
else:
self.waiting_abort_req_id_set.add(req_ids)
def cache_output_tokens(self, request):
if (
self.config.cache_config.enable_prefix_caching
@@ -720,6 +762,7 @@ class ResourceManagerV1(ResourceManager):
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
need_abort_requests = [] # users trigger abortion
# First, schedule the RUNNING requests.
req_index = 0
@@ -739,6 +782,13 @@ class ResourceManagerV1(ResourceManager):
continue
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
request.num_computed_tokens = request.num_total_tokens - 1
if request.request_id in self.waiting_abort_req_id_set:
self._trigger_abort(request.request_id, scheduled_reqs)
req_index += 1
need_abort_requests.append(request)
continue
if (
self.allocated_slots(request) - request.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -876,6 +926,10 @@ class ResourceManagerV1(ResourceManager):
)
req_index += 1
# remove requests to be aborted from running list
for request in need_abort_requests:
self.running.remove(request)
# Second, schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []