mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support stopping the inference for the corresponding request in the online service after a disconnection request. (#5320)
* request disconnect * request disconnect * fix bug * fix bug--amend --------- Co-authored-by: root <root@yq01-sys-rpm26xc1knu.yq01.baidu.com>
This commit is contained in:
@@ -444,33 +444,35 @@ class PrefixCacheManager:
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def allocate_gpu_blocks(self, num_blocks):
|
def allocate_gpu_blocks(self, num_blocks, req_id=None):
|
||||||
"""
|
"""
|
||||||
allocate gpu blocks.
|
allocate gpu blocks.
|
||||||
"""
|
"""
|
||||||
assert num_blocks <= len(
|
assert num_blocks <= len(
|
||||||
self.gpu_free_block_list
|
self.gpu_free_block_list
|
||||||
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
|
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
|
||||||
|
logger.debug(f"{req_id} start allocate...")
|
||||||
allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
|
allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
|
||||||
logger.info(
|
logger.info(
|
||||||
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
|
f"req_id:{req_id} allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
|
||||||
)
|
)
|
||||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||||
return allocated_block_ids
|
return allocated_block_ids
|
||||||
|
|
||||||
def recycle_gpu_blocks(self, gpu_block_ids):
|
def recycle_gpu_blocks(self, gpu_block_ids, req_id=None):
|
||||||
"""
|
"""
|
||||||
recycle gpu blocks.
|
recycle gpu blocks.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
|
f"req_id:{req_id} recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
|
||||||
)
|
)
|
||||||
if isinstance(gpu_block_ids, list):
|
if isinstance(gpu_block_ids, list):
|
||||||
for gpu_block_id in gpu_block_ids:
|
for gpu_block_id in gpu_block_ids:
|
||||||
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
|
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
|
||||||
else:
|
else:
|
||||||
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
|
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
|
||||||
|
logger.debug(f"req_id:{req_id} recycle blocks end")
|
||||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||||
|
|
||||||
@@ -978,7 +980,7 @@ class PrefixCacheManager:
|
|||||||
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
|
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
|
||||||
|
|
||||||
if leaf_node == self.radix_tree_root:
|
if leaf_node == self.radix_tree_root:
|
||||||
self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id])
|
self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id], req_id)
|
||||||
del self.unfilled_req_block_map[req_id]
|
del self.unfilled_req_block_map[req_id]
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ import zmq
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import fastdeploy.metrics.trace as tracing
|
import fastdeploy.metrics.trace as tracing
|
||||||
from fastdeploy.engine.request import Request, RequestOutput, RequestType
|
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
||||||
from fastdeploy.engine.resource_manager import ResourceManager
|
from fastdeploy.engine.resource_manager import ResourceManager
|
||||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||||
from fastdeploy.eplb.utils import init_eplb_signals
|
from fastdeploy.eplb.utils import init_eplb_signals
|
||||||
@@ -721,6 +721,7 @@ class EngineService:
|
|||||||
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
|
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
|
||||||
batch=num_prefill_batch,
|
batch=num_prefill_batch,
|
||||||
)
|
)
|
||||||
|
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.metrics.engine_get_req_time = time.time()
|
task.metrics.engine_get_req_time = time.time()
|
||||||
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
||||||
@@ -787,6 +788,7 @@ class EngineService:
|
|||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
batch=num_prefill_batch,
|
batch=num_prefill_batch,
|
||||||
)
|
)
|
||||||
|
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.metrics.engine_get_req_time = time.time()
|
task.metrics.engine_get_req_time = time.time()
|
||||||
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
|
||||||
@@ -1074,6 +1076,24 @@ class EngineService:
|
|||||||
request, insert_task = None, []
|
request, insert_task = None, []
|
||||||
results: List[Tuple[str, Optional[str]]] = list()
|
results: List[Tuple[str, Optional[str]]] = list()
|
||||||
if data:
|
if data:
|
||||||
|
status_value = data.get("status", None)
|
||||||
|
if status_value is not None and status_value == RequestStatus.ABORT.value:
|
||||||
|
req_id = data["request_id"]
|
||||||
|
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
|
||||||
|
self.resource_manager.abort_req_ids_set.add(req_id)
|
||||||
|
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
|
if req_id in self.resource_manager.requests:
|
||||||
|
req = self.resource_manager.requests[req_id]
|
||||||
|
task = self.resource_manager._prepare_preempt_task(req)
|
||||||
|
self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz))
|
||||||
|
self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}")
|
||||||
|
else:
|
||||||
|
self.scheduler._recycle(req_id)
|
||||||
|
self.llm_logger.info(
|
||||||
|
f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler"
|
||||||
|
)
|
||||||
|
self.resource_manager.abort_req_ids_set.remove(req_id)
|
||||||
|
continue
|
||||||
err_msg = None
|
err_msg = None
|
||||||
try:
|
try:
|
||||||
request = Request.from_dict(data)
|
request = Request.from_dict(data)
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class RequestStatus(Enum):
|
|||||||
RUNNING = 1
|
RUNNING = 1
|
||||||
PREEMPTED = 2
|
PREEMPTED = 2
|
||||||
FINISHED = 3
|
FINISHED = 3
|
||||||
|
ABORT = 4
|
||||||
|
|
||||||
|
|
||||||
class RequestType(Enum):
|
class RequestType(Enum):
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class ResourceManager:
|
|||||||
self.req_dict = dict()
|
self.req_dict = dict()
|
||||||
# current batch status of the engine
|
# current batch status of the engine
|
||||||
self.real_bsz = 0
|
self.real_bsz = 0
|
||||||
|
self.abort_req_ids_set = set()
|
||||||
llm_logger.info(f"{self.info()}")
|
llm_logger.info(f"{self.info()}")
|
||||||
main_process_metrics.max_batch_size.set(max_num_seqs)
|
main_process_metrics.max_batch_size.set(max_num_seqs)
|
||||||
|
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
|
|
||||||
self.using_extend_tables_req_id = set()
|
self.using_extend_tables_req_id = set()
|
||||||
self.reuse_block_num_map = dict()
|
self.reuse_block_num_map = dict()
|
||||||
|
self.abort_req_ids_set = set()
|
||||||
|
|
||||||
# need block nums
|
# need block nums
|
||||||
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
|
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
|
||||||
@@ -243,6 +244,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
request = self.requests[request_id]
|
request = self.requests[request_id]
|
||||||
if process_func is not None:
|
if process_func is not None:
|
||||||
process_func(request)
|
process_func(request)
|
||||||
|
llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
|
||||||
self.waiting.appendleft(request)
|
self.waiting.appendleft(request)
|
||||||
self.to_be_rescheduled_request_id_set.remove(request_id)
|
self.to_be_rescheduled_request_id_set.remove(request_id)
|
||||||
|
|
||||||
@@ -640,7 +642,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
|
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
|
||||||
)
|
)
|
||||||
request.block_tables.extend(
|
request.block_tables.extend(
|
||||||
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
|
self.cache_manager.allocate_gpu_blocks(
|
||||||
|
self.config.cache_config.enc_dec_block_num, request.request_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# Prepare decoding task
|
# Prepare decoding task
|
||||||
scheduled_reqs.append(self._prepare_decode_task(request))
|
scheduled_reqs.append(self._prepare_decode_task(request))
|
||||||
@@ -653,7 +657,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
break
|
break
|
||||||
# Allocation for next decoding blocks
|
# Allocation for next decoding blocks
|
||||||
request.block_tables.extend(
|
request.block_tables.extend(
|
||||||
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
|
self.cache_manager.allocate_gpu_blocks(
|
||||||
|
self.config.cache_config.enc_dec_block_num, request.request_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# Prepare decoding task
|
# Prepare decoding task
|
||||||
scheduled_reqs.append(self._prepare_decode_task(request))
|
scheduled_reqs.append(self._prepare_decode_task(request))
|
||||||
@@ -668,7 +674,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
def _allocate_decode_and_extend():
|
def _allocate_decode_and_extend():
|
||||||
allocate_block_num = self.need_block_num_map[request.request_id].consume()
|
allocate_block_num = self.need_block_num_map[request.request_id].consume()
|
||||||
# Prepare decoding task
|
# Prepare decoding task
|
||||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num))
|
request.block_tables.extend(
|
||||||
|
self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
|
||||||
|
)
|
||||||
scheduled_reqs.append(self._prepare_decode_task(request))
|
scheduled_reqs.append(self._prepare_decode_task(request))
|
||||||
|
|
||||||
# Prepare extend task
|
# Prepare extend task
|
||||||
@@ -682,7 +690,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
|
|
||||||
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
|
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
|
||||||
request.extend_block_tables.extend(
|
request.extend_block_tables.extend(
|
||||||
self.cache_manager.allocate_gpu_blocks(allocate_block_num)
|
self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
|
||||||
)
|
)
|
||||||
scheduled_reqs.append(
|
scheduled_reqs.append(
|
||||||
ScheduledExtendBlocksTask(
|
ScheduledExtendBlocksTask(
|
||||||
@@ -733,14 +741,18 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
num_new_block = self.get_new_block_nums(request, num_new_tokens)
|
num_new_block = self.get_new_block_nums(request, num_new_tokens)
|
||||||
# Allocate blocks to prefill
|
# Allocate blocks to prefill
|
||||||
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
|
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
|
||||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
|
request.block_tables.extend(
|
||||||
|
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
|
||||||
|
)
|
||||||
# Prepare prefill task
|
# Prepare prefill task
|
||||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||||
else: # Not enough blocks to allocate, trigger preemption
|
else: # Not enough blocks to allocate, trigger preemption
|
||||||
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
|
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
|
||||||
if not can_schedule:
|
if not can_schedule:
|
||||||
break
|
break
|
||||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
|
request.block_tables.extend(
|
||||||
|
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
|
||||||
|
)
|
||||||
# Prepare prefill task
|
# Prepare prefill task
|
||||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
@@ -806,7 +818,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
# Allocate blocks to prefill
|
# Allocate blocks to prefill
|
||||||
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
|
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
|
||||||
if not request.get("skip_allocate", False):
|
if not request.get("skip_allocate", False):
|
||||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
|
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
|
||||||
|
num_new_block, request.request_id
|
||||||
|
)
|
||||||
request.block_tables.extend(extra_gpu_block_ids)
|
request.block_tables.extend(extra_gpu_block_ids)
|
||||||
self.waiting.popleft()
|
self.waiting.popleft()
|
||||||
self.running.append(request)
|
self.running.append(request)
|
||||||
@@ -824,6 +838,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
self.tasks_list[allocated_position] = request
|
self.tasks_list[allocated_position] = request
|
||||||
self.stop_flags[allocated_position] = False
|
self.stop_flags[allocated_position] = False
|
||||||
self.req_dict[request.request_id] = allocated_position
|
self.req_dict[request.request_id] = allocated_position
|
||||||
|
llm_logger.debug(f"req_id:{request.request_id} allocate pos end")
|
||||||
else:
|
else:
|
||||||
if self.config.cache_config.enable_prefix_caching:
|
if self.config.cache_config.enable_prefix_caching:
|
||||||
self._free_blocks(request)
|
self._free_blocks(request)
|
||||||
@@ -856,7 +871,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
# Allocate blocks to prefill
|
# Allocate blocks to prefill
|
||||||
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
|
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
|
||||||
if not request.get("skip_allocate", False):
|
if not request.get("skip_allocate", False):
|
||||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
|
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
|
||||||
|
num_new_block, request.request_id
|
||||||
|
)
|
||||||
request.block_tables.extend(extra_gpu_block_ids)
|
request.block_tables.extend(extra_gpu_block_ids)
|
||||||
self.waiting.popleft()
|
self.waiting.popleft()
|
||||||
self.running.append(request)
|
self.running.append(request)
|
||||||
@@ -873,7 +890,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
self._free_blocks(request)
|
self._free_blocks(request)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
llm_logger.error("Unknown request status type")
|
llm_logger.info(f"Unknown request status type:{request.status}, req_id:{request.request_id}")
|
||||||
|
|
||||||
for req in skip_requests:
|
for req in skip_requests:
|
||||||
# move waiting request to end of the deque
|
# move waiting request to end of the deque
|
||||||
@@ -1069,6 +1086,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
def add_request(self, request: Request) -> None:
|
def add_request(self, request: Request) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.apply_async_preprocess(request)
|
self.apply_async_preprocess(request)
|
||||||
|
llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
|
||||||
self.waiting.append(request)
|
self.waiting.append(request)
|
||||||
self.requests[request.request_id] = request
|
self.requests[request.request_id] = request
|
||||||
|
|
||||||
@@ -1120,7 +1138,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
|
|
||||||
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
|
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
|
||||||
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
|
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
|
||||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)
|
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
|
||||||
|
need_extra_prefill_blocks, request.request_id
|
||||||
|
)
|
||||||
request.block_tables.extend(extra_gpu_block_ids)
|
request.block_tables.extend(extra_gpu_block_ids)
|
||||||
allocated_position = self.get_available_position()
|
allocated_position = self.get_available_position()
|
||||||
request.idx = allocated_position
|
request.idx = allocated_position
|
||||||
@@ -1135,7 +1155,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
||||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
|
request.block_tables.extend(
|
||||||
|
self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id)
|
||||||
|
)
|
||||||
request.num_computed_tokens = 0
|
request.num_computed_tokens = 0
|
||||||
allocated_position = self.get_available_position()
|
allocated_position = self.get_available_position()
|
||||||
request.idx = allocated_position
|
request.idx = allocated_position
|
||||||
@@ -1169,7 +1191,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)
|
request.block_tables = self.cache_manager.allocate_gpu_blocks(
|
||||||
|
need_prealloc_prefill_blocks, request.request_id
|
||||||
|
)
|
||||||
request.num_computed_tokens = request.need_prefill_tokens
|
request.num_computed_tokens = request.need_prefill_tokens
|
||||||
request.disaggregate_info["block_tables"] = request.block_tables
|
request.disaggregate_info["block_tables"] = request.block_tables
|
||||||
allocated_position = self.get_available_position()
|
allocated_position = self.get_available_position()
|
||||||
@@ -1221,16 +1245,18 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
def _free_blocks(self, request: Request):
|
def _free_blocks(self, request: Request):
|
||||||
if self.config.cache_config.enable_prefix_caching:
|
if self.config.cache_config.enable_prefix_caching:
|
||||||
self.cache_manager.release_block_ids(request)
|
self.cache_manager.release_block_ids(request)
|
||||||
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.num_cached_blocks :])
|
self.cache_manager.recycle_gpu_blocks(
|
||||||
|
request.block_tables[request.num_cached_blocks :], request.request_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.cache_manager.recycle_gpu_blocks(request.block_tables)
|
self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id)
|
||||||
request.block_tables = []
|
request.block_tables = []
|
||||||
|
|
||||||
if request.request_id in self.using_extend_tables_req_id:
|
if request.request_id in self.using_extend_tables_req_id:
|
||||||
reuse_block_num = self.reuse_block_num_map[request.request_id]
|
reuse_block_num = self.reuse_block_num_map[request.request_id]
|
||||||
|
|
||||||
self.using_extend_tables_req_id.remove(request.request_id)
|
self.using_extend_tables_req_id.remove(request.request_id)
|
||||||
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:])
|
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:], request.request_id)
|
||||||
llm_logger.info(
|
llm_logger.info(
|
||||||
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
|
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
|
||||||
)
|
)
|
||||||
@@ -1280,9 +1306,9 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
for req in need_postprocess_reqs:
|
for req in need_postprocess_reqs:
|
||||||
try:
|
try:
|
||||||
self._free_blocks(req)
|
self._free_blocks(req)
|
||||||
|
llm_logger.debug(f"req_id:{req.request_id} free pos:{req.idx}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
llm_logger.warning(f"release block failed {req.request_id}: {e}")
|
llm_logger.warning(f"release block failed {req.request_id}: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
|
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
@@ -28,6 +29,7 @@ from filelock import FileLock
|
|||||||
import fastdeploy.metrics.trace as tracing
|
import fastdeploy.metrics.trace as tracing
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.engine.request import RequestStatus
|
||||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||||
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||||
@@ -844,3 +846,27 @@ class EngineClient:
|
|||||||
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
|
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
|
||||||
status_code = HTTPStatus.OK
|
status_code = HTTPStatus.OK
|
||||||
return content, status_code
|
return content, status_code
|
||||||
|
|
||||||
|
async def abort(self, request_id, n=1) -> None:
    """
    Send abort messages for the request(s) that share ``request_id``'s prefix.

    Does nothing unless ``envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE``
    is truthy. A trailing ``_<number>`` suffix is stripped from ``request_id``
    to obtain the prefix; if there is no such suffix a warning is logged and
    the whole id is used as the prefix. One message is then sent through
    ``self._send_task`` for each of ``<prefix>_0`` .. ``<prefix>_<n-1>``,
    carrying ``RequestStatus.ABORT.value`` as its ``status``.

    Args:
        request_id: Request id string, normally ending in ``_<number>``.
        n: Number of sibling ids to abort; non-positive values log a warning
            and send nothing.
    """
    # Feature switch: when disabled, the call is a silent no-op.
    if not envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE:
        return

    api_server_logger.info(f"abort request_id:{request_id}")

    if n <= 0:
        api_server_logger.warning("Abort function called with non-positive n: %d. No requests aborted.", n)
        return

    # Drop the trailing "_<number>" so every sibling id can be rebuilt from the prefix.
    suffix = re.search(r"_\d+$", request_id)
    if suffix is None:
        api_server_logger.warning(
            "request_id format error: %s does not end with _<number>. Using it as prefix.", request_id
        )
        base = request_id
    else:
        base = request_id[: suffix.start()]

    targets = [f"{base}_{i}" for i in range(n)]
    for target in targets:
        payload = {
            "request_id": target,
            "status": RequestStatus.ABORT.value,
        }
        self._send_task(payload)

    api_server_logger.info("Aborted request(s) %s.", ",".join(targets))
|
||||||
|
|||||||
@@ -59,7 +59,11 @@ from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbeddi
|
|||||||
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
|
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
|
||||||
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
|
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
|
||||||
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
|
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
|
from fastdeploy.entrypoints.openai.utils import (
|
||||||
|
UVICORN_CONFIG,
|
||||||
|
make_arg_parser,
|
||||||
|
with_cancellation,
|
||||||
|
)
|
||||||
from fastdeploy.entrypoints.openai.v1.serving_chat import (
|
from fastdeploy.entrypoints.openai.v1.serving_chat import (
|
||||||
OpenAIServingChat as OpenAIServingChatV1,
|
OpenAIServingChat as OpenAIServingChatV1,
|
||||||
)
|
)
|
||||||
@@ -410,6 +414,7 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
|
@with_cancellation
|
||||||
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
|
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
|
||||||
"""
|
"""
|
||||||
Create a chat completion for the provided prompt and parameters.
|
Create a chat completion for the provided prompt and parameters.
|
||||||
@@ -446,6 +451,7 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
|
@with_cancellation
|
||||||
async def create_completion(request: CompletionRequest, req: Request):
|
async def create_completion(request: CompletionRequest, req: Request):
|
||||||
"""
|
"""
|
||||||
Create a completion for the provided prompt and parameters.
|
Create a completion for the provided prompt and parameters.
|
||||||
|
|||||||
@@ -180,6 +180,13 @@ class OpenAIServingChat:
|
|||||||
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
|
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
|
||||||
api_server_logger.error(error_msg)
|
api_server_logger.error(error_msg)
|
||||||
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
|
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n)
|
||||||
|
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
|
||||||
|
api_server_logger.error(error_msg)
|
||||||
|
return ErrorResponse(
|
||||||
|
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
|
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
|
||||||
@@ -267,7 +274,9 @@ class OpenAIServingChat:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
current_waiting_time += 10
|
current_waiting_time += 10
|
||||||
if current_waiting_time == 300:
|
if current_waiting_time == 300:
|
||||||
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT)
|
status, msg = self.engine_client.check_health(
|
||||||
|
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
|
||||||
|
)
|
||||||
if not status:
|
if not status:
|
||||||
if choices:
|
if choices:
|
||||||
chunk.choices = choices
|
chunk.choices = choices
|
||||||
@@ -506,6 +515,10 @@ class OpenAIServingChat:
|
|||||||
)
|
)
|
||||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||||
|
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n)
|
||||||
|
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
|
||||||
|
api_server_logger.error(error_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_data = self._create_streaming_error_response(
|
error_data = self._create_streaming_error_response(
|
||||||
f"request[{request_id}] generate stream error: {str(e)}, {str(traceback.format_exc())}"
|
f"request[{request_id}] generate stream error: {str(e)}, {str(traceback.format_exc())}"
|
||||||
@@ -577,7 +590,9 @@ class OpenAIServingChat:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
current_waiting_time += 10
|
current_waiting_time += 10
|
||||||
if current_waiting_time == 300:
|
if current_waiting_time == 300:
|
||||||
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT)
|
status, msg = self.engine_client.check_health(
|
||||||
|
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
|
||||||
|
)
|
||||||
if not status:
|
if not status:
|
||||||
raise ValueError(f"Engine is not healthy: {msg}")
|
raise ValueError(f"Engine is not healthy: {msg}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -230,7 +230,13 @@ class OpenAIServingCompletion:
|
|||||||
)
|
)
|
||||||
api_server_logger.error(error_msg)
|
api_server_logger.error(error_msg)
|
||||||
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
|
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
await self.engine_client.abort(f"{request_id}_0", num_choices)
|
||||||
|
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
|
||||||
|
api_server_logger.error(error_msg)
|
||||||
|
return ErrorResponse(
|
||||||
|
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
|
error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
|
||||||
api_server_logger.error(error_msg)
|
api_server_logger.error(error_msg)
|
||||||
@@ -285,7 +291,9 @@ class OpenAIServingCompletion:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
current_waiting_time += 10
|
current_waiting_time += 10
|
||||||
if current_waiting_time == 300:
|
if current_waiting_time == 300:
|
||||||
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT)
|
status, msg = self.engine_client.check_health(
|
||||||
|
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
|
||||||
|
)
|
||||||
if not status:
|
if not status:
|
||||||
raise ValueError(f"Engine is not healthy: {msg}")
|
raise ValueError(f"Engine is not healthy: {msg}")
|
||||||
else:
|
else:
|
||||||
@@ -455,7 +463,9 @@ class OpenAIServingCompletion:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
current_waiting_time += 10
|
current_waiting_time += 10
|
||||||
if current_waiting_time == 300:
|
if current_waiting_time == 300:
|
||||||
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT)
|
status, msg = self.engine_client.check_health(
|
||||||
|
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
|
||||||
|
)
|
||||||
if not status:
|
if not status:
|
||||||
raise ValueError(f"Engine is not healthy: {msg}")
|
raise ValueError(f"Engine is not healthy: {msg}")
|
||||||
else:
|
else:
|
||||||
@@ -634,6 +644,10 @@ class OpenAIServingCompletion:
|
|||||||
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||||
api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
|
api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
|
||||||
|
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
await self.engine_client.abort(f"{request_id}_0", num_choices)
|
||||||
|
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
|
||||||
|
api_server_logger.error(error_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
|
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
|
||||||
yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n"
|
yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n"
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import heapq
|
import heapq
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@@ -22,6 +23,7 @@ from multiprocessing.reduction import ForkingPickler
|
|||||||
|
|
||||||
import aiozmq
|
import aiozmq
|
||||||
import zmq
|
import zmq
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
from fastdeploy.engine.args_utils import EngineArgs
|
from fastdeploy.engine.args_utils import EngineArgs
|
||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
@@ -253,3 +255,55 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
async def listen_for_disconnect(request: Request) -> None:
    """Return once a message of type ``http.disconnect`` is received.

    Messages of any other type read from ``request.receive()`` are discarded.
    """
    message = await request.receive()
    # Keep draining the receive channel until the disconnect message shows up.
    while message["type"] != "http.disconnect":
        message = await request.receive()
|
||||||
|
|
||||||
|
|
||||||
|
def with_cancellation(handler_func):
    """Decorator that allows a route handler to be cancelled by client
    disconnections.

    This does _not_ use request.is_disconnected, which does not work with
    middleware. Instead this follows the pattern from
    starlette.StreamingResponse, which simultaneously awaits on two tasks: one
    to wait for an http disconnect message, and the other to do the work that
    we want done. When the first task finishes, the other is cancelled.

    A core assumption of this method is that the body of the request has
    already been read. This is a safe assumption to make for fastapi handlers
    that have already parsed the body of the request into a pydantic model for
    us. This decorator is unsafe to use elsewhere, as it will consume and throw
    away all incoming messages for the request while it looks for a disconnect
    message.

    In the case where a `StreamingResponse` is returned by the handler, this
    wrapper will stop listening for disconnects and instead the response object
    will start listening for disconnects. The response object will only listen
    correctly when the ASGI protocol version used by Uvicorn is lower than 2.4
    (2.4 itself excluded).

    Args:
        handler_func: The async route handler to wrap. The request object must
            be passed either as the second positional argument or as the
            ``req`` keyword argument.

    Returns:
        An async wrapper that returns the handler's result, or ``None`` when
        the disconnect listener finishes before the handler does.
    """

    # functools.wraps is required for this wrapper to appear to fastapi as a
    # normal route handler, with the correct request type hinting.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The request is either the second positional arg or the `req` keyword arg.
        request = args[1] if len(args) > 1 else kwargs["req"]

        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        tasks = (handler_task, cancellation_task)

        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # asyncio.wait() never cancels the tasks it waits on, so do it here.
            # Running this in `finally` also covers the case where the wrapper
            # itself is cancelled while waiting; otherwise both tasks would be
            # left running in the background.
            for task in tasks:
                if not task.done():
                    task.cancel()

        if handler_task in done:
            # Propagates the handler's return value or re-raises its exception.
            return handler_task.result()
        return None

    return wrapper
|
||||||
|
|||||||
@@ -158,6 +158,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# "Enable FP8 calibration on HPU"
|
# "Enable FP8 calibration on HPU"
|
||||||
"FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"),
|
"FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"),
|
||||||
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
|
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
|
||||||
|
"FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int(
|
||||||
|
os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1")
|
||||||
|
),
|
||||||
# Whether to collect user information
|
# Whether to collect user information
|
||||||
"DO_NOT_TRACK": lambda: (os.getenv("DO_NOT_TRACK", "0")) == "1",
|
"DO_NOT_TRACK": lambda: (os.getenv("DO_NOT_TRACK", "0")) == "1",
|
||||||
# Usage stats server url
|
# Usage stats server url
|
||||||
|
|||||||
@@ -248,9 +248,24 @@ class TokenProcessor:
|
|||||||
|
|
||||||
task: Request = self.resource_manager.tasks_list[i]
|
task: Request = self.resource_manager.tasks_list[i]
|
||||||
task_id = task.request_id
|
task_id = task.request_id
|
||||||
|
|
||||||
token_ids = stream_data.tokens # numpy.array
|
token_ids = stream_data.tokens # numpy.array
|
||||||
if token_ids is not None and token_ids[-1] <= 0:
|
if token_ids is not None and token_ids[-1] <= 0:
|
||||||
|
if task_id in self.resource_manager.abort_req_ids_set:
|
||||||
|
if (
|
||||||
|
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID
|
||||||
|
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
|
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||||
|
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||||
|
self._recycle_resources(task_id, i, task)
|
||||||
|
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||||
|
abort_res = RequestOutput(
|
||||||
|
request_id=task_id,
|
||||||
|
finished=True,
|
||||||
|
error_code=499,
|
||||||
|
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||||
|
)
|
||||||
|
batch_result.append(abort_res)
|
||||||
|
continue
|
||||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
if (
|
if (
|
||||||
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
||||||
@@ -711,6 +726,19 @@ class TokenProcessor:
|
|||||||
if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, meas preemption has happend in worker
|
if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, meas preemption has happend in worker
|
||||||
llm_logger.info(f"sync preemption for request_id {task_id} done.")
|
llm_logger.info(f"sync preemption for request_id {task_id} done.")
|
||||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
|
if task_id in self.resource_manager.abort_req_ids_set:
|
||||||
|
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||||
|
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||||
|
self._recycle_resources(task_id, i, task)
|
||||||
|
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||||
|
abort_res = RequestOutput(
|
||||||
|
request_id=task_id,
|
||||||
|
finished=True,
|
||||||
|
error_code=499,
|
||||||
|
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||||
|
)
|
||||||
|
batch_result.append(abort_res)
|
||||||
|
continue
|
||||||
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
|
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
|
||||||
self.resource_manager.reschedule_preempt_task(task_id)
|
self.resource_manager.reschedule_preempt_task(task_id)
|
||||||
continue
|
continue
|
||||||
@@ -737,6 +765,22 @@ class TokenProcessor:
|
|||||||
if recovery_stop:
|
if recovery_stop:
|
||||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||||
if not recovery_stop and token_id < 0:
|
if not recovery_stop and token_id < 0:
|
||||||
|
if task_id in self.resource_manager.abort_req_ids_set:
|
||||||
|
if (
|
||||||
|
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID
|
||||||
|
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
|
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
|
||||||
|
self.resource_manager.abort_req_ids_set.remove(task_id)
|
||||||
|
self._recycle_resources(task_id, i, task)
|
||||||
|
llm_logger.info(f"{task_id} received negative token. Recycle end.")
|
||||||
|
abort_res = RequestOutput(
|
||||||
|
request_id=task_id,
|
||||||
|
finished=True,
|
||||||
|
error_code=499,
|
||||||
|
error_msg=f"Your request with request_id:{task_id} is aborted.",
|
||||||
|
)
|
||||||
|
batch_result.append(abort_res)
|
||||||
|
continue
|
||||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
if (
|
if (
|
||||||
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
task_id in self.resource_manager.to_be_rescheduled_request_id_set
|
||||||
@@ -760,6 +804,7 @@ class TokenProcessor:
|
|||||||
task.metrics.record_recv_first_token()
|
task.metrics.record_recv_first_token()
|
||||||
task.metrics.cal_cost_time()
|
task.metrics.cal_cost_time()
|
||||||
metrics = copy.copy(task.metrics)
|
metrics = copy.copy(task.metrics)
|
||||||
|
llm_logger.info(f"task:{task.request_id} start recode first token")
|
||||||
self._record_first_token_metrics(task, current_time)
|
self._record_first_token_metrics(task, current_time)
|
||||||
|
|
||||||
tracing.trace_report_span(
|
tracing.trace_report_span(
|
||||||
@@ -845,7 +890,6 @@ class TokenProcessor:
|
|||||||
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
|
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
|
||||||
result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
|
result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
|
||||||
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
|
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
|
||||||
|
|
||||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||||
result.finished = True
|
result.finished = True
|
||||||
trace_carrier = tracing.trace_get_proc_propagate_context(rid=rid)
|
trace_carrier = tracing.trace_get_proc_propagate_context(rid=rid)
|
||||||
@@ -872,7 +916,9 @@ class TokenProcessor:
|
|||||||
self._compute_speculative_status(result)
|
self._compute_speculative_status(result)
|
||||||
if not is_prefill:
|
if not is_prefill:
|
||||||
self._record_completion_metrics(task, current_time)
|
self._record_completion_metrics(task, current_time)
|
||||||
|
llm_logger.info(f"task {task_id} received eos token. Recycling.")
|
||||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||||
|
llm_logger.info(f"eos token {task_id} Recycle end.")
|
||||||
break
|
break
|
||||||
|
|
||||||
llm_logger.debug(f"get response from infer: {result}")
|
llm_logger.debug(f"get response from infer: {result}")
|
||||||
|
|||||||
@@ -220,6 +220,7 @@ class ErrorCode(str, Enum):
|
|||||||
CONNECTION_ERROR = "connection_error"
|
CONNECTION_ERROR = "connection_error"
|
||||||
MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
|
MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
|
||||||
INTERNAL_ERROR = "internal_error"
|
INTERNAL_ERROR = "internal_error"
|
||||||
|
CLIENT_ABORTED = "client_aborted"
|
||||||
|
|
||||||
|
|
||||||
class ColoredFormatter(logging.Formatter):
|
class ColoredFormatter(logging.Formatter):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ aiozmq
|
|||||||
openai
|
openai
|
||||||
tqdm
|
tqdm
|
||||||
pynvml
|
pynvml
|
||||||
uvicorn==0.29.0
|
uvicorn>=0.38.0
|
||||||
fastapi
|
fastapi
|
||||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||||
redis
|
redis
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ aiozmq
|
|||||||
openai>=1.93.0
|
openai>=1.93.0
|
||||||
tqdm
|
tqdm
|
||||||
pynvml
|
pynvml
|
||||||
uvicorn==0.29.0
|
uvicorn>=0.38.0
|
||||||
fastapi
|
fastapi
|
||||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||||
redis
|
redis
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ aiozmq
|
|||||||
openai>=1.93.0
|
openai>=1.93.0
|
||||||
tqdm
|
tqdm
|
||||||
pynvml
|
pynvml
|
||||||
uvicorn==0.29.0
|
uvicorn>=0.38.0
|
||||||
fastapi
|
fastapi
|
||||||
# if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton.
|
# if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton.
|
||||||
paddleformers==0.3.2
|
paddleformers==0.3.2
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -48,7 +47,7 @@ def setup_and_run_server():
|
|||||||
- Tears down server after all tests finish
|
- Tears down server after all tests finish
|
||||||
"""
|
"""
|
||||||
print("Pre-test port cleanup...")
|
print("Pre-test port cleanup...")
|
||||||
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
|
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8333))
|
||||||
clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT])
|
clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT])
|
||||||
|
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
@@ -173,9 +172,12 @@ def parse_prometheus_to_dict(metrics_text: str):
|
|||||||
value = float(line.split("}")[1].strip())
|
value = float(line.split("}")[1].strip())
|
||||||
|
|
||||||
# 解析 labels
|
# 解析 labels
|
||||||
# 用正则取出所有 key 和 value(去掉外层引号)
|
labels = {}
|
||||||
pairs = re.findall(r'(\w+)="([^"]*)"', labels_str)
|
for kv in labels_str.split(","):
|
||||||
labels = {k: v for k, v in pairs}
|
if "=" not in kv:
|
||||||
|
continue
|
||||||
|
k, v = kv.split("=")
|
||||||
|
labels[k] = v.strip('"')
|
||||||
|
|
||||||
# 存储
|
# 存储
|
||||||
if metric_name not in result:
|
if metric_name not in result:
|
||||||
@@ -212,7 +214,6 @@ def test_metrics_with_clear_and_reset():
|
|||||||
"""
|
"""
|
||||||
Test the metrics monitoring endpoint.
|
Test the metrics monitoring endpoint.
|
||||||
"""
|
"""
|
||||||
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
|
|
||||||
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||||
|
|
||||||
async_concurrency(n=10)
|
async_concurrency(n=10)
|
||||||
@@ -229,23 +230,13 @@ def test_metrics_with_clear_and_reset():
|
|||||||
running = metrics["fastdeploy:num_requests_running"]
|
running = metrics["fastdeploy:num_requests_running"]
|
||||||
waiting = metrics["fastdeploy:num_requests_waiting"]
|
waiting = metrics["fastdeploy:num_requests_waiting"]
|
||||||
|
|
||||||
print("ASSERT clear_load_weight后非0 running:", running, "waiting:", waiting)
|
print(
|
||||||
assert running != 0 or waiting != 0, "Expected running/waiting to be non-zero"
|
"ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
|
||||||
|
running,
|
||||||
# ===== reset_scheduler =====
|
"waiting:",
|
||||||
reset_url = f"http://0.0.0.0:{FD_CONTROLLER_PORT}/controller/reset_scheduler"
|
waiting,
|
||||||
print("Calling reset_scheduler...")
|
)
|
||||||
r = requests.post(reset_url, json={"reset": True}, timeout=30)
|
assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
|
||||||
assert r.status_code == 200, f"reset_scheduler failed: {r.status_code}"
|
|
||||||
|
|
||||||
metrics = get_metrics_dict(metrics_url)
|
|
||||||
running = metrics["fastdeploy:num_requests_running"]
|
|
||||||
waiting = metrics["fastdeploy:num_requests_waiting"]
|
|
||||||
|
|
||||||
print("ASSERT reset_scheduler后为0 running:", running, "waiting:", waiting)
|
|
||||||
# Temporarily disable this assertion. The running/waiting states are not strictly
|
|
||||||
# guaranteed to reach zero in the current workflow, so we skip this check for now.
|
|
||||||
# assert running == 0 and waiting == 0, "Expected running/waiting to be zero"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ def _reload_api_server(args):
|
|||||||
fake_envs_mod.TRACES_EXPORTER = "console"
|
fake_envs_mod.TRACES_EXPORTER = "console"
|
||||||
fake_envs_mod.EXPORTER_OTLP_ENDPOINT = ""
|
fake_envs_mod.EXPORTER_OTLP_ENDPOINT = ""
|
||||||
fake_envs_mod.EXPORTER_OTLP_HEADERS = ""
|
fake_envs_mod.EXPORTER_OTLP_HEADERS = ""
|
||||||
|
fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024
|
||||||
fake_envs_mod.environment_variables = _FakeEnvVars()
|
fake_envs_mod.environment_variables = _FakeEnvVars()
|
||||||
|
|
||||||
# Save original sys.argv and replace with minimal valid args to avoid parse errors
|
# Save original sys.argv and replace with minimal valid args to avoid parse errors
|
||||||
@@ -397,26 +398,36 @@ async def test_chat_and_completion_routes():
|
|||||||
chat_handler = MagicMock()
|
chat_handler = MagicMock()
|
||||||
chat_handler.create_chat_completion = AsyncMock(return_value=error_resp)
|
chat_handler.create_chat_completion = AsyncMock(return_value=error_resp)
|
||||||
api_server.app.state.chat_handler = chat_handler
|
api_server.app.state.chat_handler = chat_handler
|
||||||
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500
|
|
||||||
|
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
|
||||||
|
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500
|
||||||
|
|
||||||
success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo())
|
success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo())
|
||||||
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp)
|
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp)
|
||||||
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200
|
|
||||||
|
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
|
||||||
|
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200
|
||||||
|
|
||||||
async def stream_gen():
|
async def stream_gen():
|
||||||
yield "data"
|
yield "data"
|
||||||
|
|
||||||
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen())
|
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen())
|
||||||
assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse)
|
|
||||||
|
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
|
||||||
|
assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse)
|
||||||
|
|
||||||
# Completion handler
|
# Completion handler
|
||||||
completion_handler = MagicMock()
|
completion_handler = MagicMock()
|
||||||
completion_handler.create_completion = AsyncMock(return_value=error_resp)
|
completion_handler.create_completion = AsyncMock(return_value=error_resp)
|
||||||
api_server.app.state.completion_handler = completion_handler
|
api_server.app.state.completion_handler = completion_handler
|
||||||
assert (await api_server.create_completion(body, fake_req)).status_code == 500
|
|
||||||
|
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
|
||||||
|
assert (await api_server.create_completion(body, fake_req)).status_code == 500
|
||||||
|
|
||||||
api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp)
|
api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp)
|
||||||
assert (await api_server.create_completion(body, fake_req)).status_code == 200
|
|
||||||
|
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
|
||||||
|
assert (await api_server.create_completion(body, fake_req)).status_code == 200
|
||||||
|
|
||||||
# HTTPException handling
|
# HTTPException handling
|
||||||
class RaiseHTTP:
|
class RaiseHTTP:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
@@ -1079,6 +1080,127 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
|
|||||||
# logprobs should be None when not requested
|
# logprobs should be None when not requested
|
||||||
self.assertIsNone(choice.logprobs)
|
self.assertIsNone(choice.logprobs)
|
||||||
|
|
||||||
|
async def test_create_chat_completion_cancelled_error(self):
    """create_chat_completion must absorb asyncio.CancelledError and abort the request."""
    engine_client = self.chat_completion_handler.engine_client

    # Non-streaming chat request used to drive the handler.
    request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False)

    # Semaphore stub: acquiring always succeeds, releasing is a no-op.
    semaphore = MagicMock()
    semaphore.acquire = AsyncMock(return_value=True)
    semaphore.release = MagicMock()
    engine_client.semaphore = semaphore

    # Stub the model weight status check to return False.
    engine_client.check_model_weight_status = Mock(return_value=False)

    # Simulate the cancellation while the request data is being formatted.
    cancellation = asyncio.CancelledError("Test cancellation during data formatting")
    engine_client.format_and_add_data = AsyncMock(side_effect=cancellation)

    # abort() is what the handler is expected to call on cancellation.
    engine_client.abort = AsyncMock()

    # The CancelledError has to be handled inside the handler, never re-raised.
    try:
        await self.chat_completion_handler.create_chat_completion(request)
    except asyncio.CancelledError:
        self.fail("CancelledError should be caught and handled, not re-raised")

    # The request must have been aborted exactly once.
    engine_client.abort.assert_called_once()
|
||||||
|
|
||||||
|
async def test_chat_completion_stream_generator_cancelled_error(self):
    """chat_completion_stream_generator must absorb asyncio.CancelledError, then clean up and abort."""
    engine_client = self.chat_completion_handler.engine_client
    connection_manager = engine_client.connection_manager

    # Streaming chat request plus the positional arguments of the generator.
    request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=True)
    request_id = "test_cancel_request"
    model_name = "test_model"
    prompt_token_ids = [1, 2, 3]
    prompt_tokens = "Hello world"

    # get_connection resolves normally to a (dealer, response_queue) pair.
    connection_manager.get_connection = AsyncMock(return_value=(MagicMock(), AsyncMock()))

    # Semaphore stub: acquiring always succeeds, releasing is a no-op.
    semaphore = MagicMock()
    semaphore.acquire = AsyncMock(return_value=True)
    semaphore.release = MagicMock()
    engine_client.semaphore = semaphore

    # Stub the model weight status check to return False.
    engine_client.check_model_weight_status = Mock(return_value=False)

    # One ordinary, unfinished response is emitted before the cancellation hits.
    first_response = {
        "request_id": f"{request_id}_0",
        "error_code": 200,
        "metrics": {
            "first_token_time": 1234567890,
            "inference_start_time": 1234567880,
            "arrival_time": 1234567890,
            "request_start_time": 1234567870,
        },
        "prompt_logprobs": None,
        "outputs": {
            "token_ids": [5],
            "text": "Hi",
            "top_logprobs": None,
            "draft_top_logprobs": None,
            "multipart": [{"type": "text", "text": "Hi"}],
        },
        "finished": False,
        "num_cached_tokens": 0,
        "num_input_image_tokens": 0,
        "num_input_video_tokens": 0,
    }

    async def mock_async_generator_with_cancel():
        yield first_response
        # Cancellation arrives in the middle of the stream.
        raise asyncio.CancelledError("Test cancellation during streaming")

    # Response processor whose chat stream is the cancelling generator above.
    mock_response_processor = MagicMock()
    mock_response_processor.enable_multimodal_content.return_value = False
    mock_response_processor.process_response_chat.return_value = mock_async_generator_with_cancel()

    # Both calls are expected to happen once the cancellation is handled.
    connection_manager.cleanup_request = AsyncMock()
    engine_client.abort = AsyncMock()

    with patch(
        "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
    ):
        # The CancelledError has to be handled inside the generator, never re-raised.
        try:
            chunks = [
                chunk
                async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
                    request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
                )
            ]
        except asyncio.CancelledError:
            self.fail("CancelledError should be caught and handled, not re-raised")

    # At least one chunk must have been produced before the cancellation.
    self.assertGreaterEqual(len(chunks), 1)
    self.assertIsNotNone(chunks[0])

    # Cleanup and abort must each have run exactly once despite the cancellation.
    connection_manager.cleanup_request.assert_called_once()
    engine_client.abort.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
@@ -520,7 +521,7 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
|
|||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
self.assertTrue(len(results) > 0)
|
self.assertGreater(len(results), 0)
|
||||||
# Check that the first response contains prompt_logprobs
|
# Check that the first response contains prompt_logprobs
|
||||||
self.assertIn("prompt_logprobs", results[0])
|
self.assertIn("prompt_logprobs", results[0])
|
||||||
|
|
||||||
@@ -1216,6 +1217,118 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertIsNone(result.choices[0].prompt_logprobs)
|
self.assertIsNone(result.choices[0].prompt_logprobs)
|
||||||
self.assertIsNone(result.choices[0].logprobs)
|
self.assertIsNone(result.choices[0].logprobs)
|
||||||
|
|
||||||
|
async def test_create_completion_cancelled_error(self):
|
||||||
|
"""Test create_completion with asyncio.CancelledError"""
|
||||||
|
# Mock the engine client and its dependencies
|
||||||
|
mock_engine_client = Mock()
|
||||||
|
mock_engine_client.is_master = True
|
||||||
|
mock_engine_client.semaphore = Mock()
|
||||||
|
mock_engine_client.semaphore.acquire = AsyncMock()
|
||||||
|
mock_engine_client.semaphore.release = Mock()
|
||||||
|
mock_engine_client.format_and_add_data = AsyncMock()
|
||||||
|
mock_engine_client.format_and_add_data.return_value = [1, 2, 3]
|
||||||
|
mock_engine_client.abort = AsyncMock()
|
||||||
|
|
||||||
|
# Create serving completion instance
|
||||||
|
serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360)
|
||||||
|
|
||||||
|
# Create mock request
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.prompt = "Hello, world!"
|
||||||
|
mock_request.prompt_token_ids = None
|
||||||
|
mock_request.stream = False
|
||||||
|
mock_request.n = 1
|
||||||
|
mock_request.request_id = None
|
||||||
|
mock_request.user = None
|
||||||
|
|
||||||
|
# Mock to_dict_for_infer method to return a proper dict
|
||||||
|
def mock_to_dict_for_infer(request_id_idx, prompt):
|
||||||
|
return {"prompt": prompt, "request_id": request_id_idx, "prompt_tokens": 10, "max_tokens": 100}
|
||||||
|
|
||||||
|
mock_request.to_dict_for_infer = mock_to_dict_for_infer
|
||||||
|
|
||||||
|
# Mock format_and_add_data to raise CancelledError
|
||||||
|
mock_engine_client.format_and_add_data.side_effect = asyncio.CancelledError("Test cancellation")
|
||||||
|
|
||||||
|
# Call method and expect CancelledError to be handled
|
||||||
|
result = await serving_completion.create_completion(mock_request)
|
||||||
|
|
||||||
|
# Verify that error was handled properly
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertTrue(hasattr(result, "error"))
|
||||||
|
self.assertEqual(result.error.code, "client_aborted")
|
||||||
|
self.assertIn("client disconnected", result.error.message)
|
||||||
|
|
||||||
|
# Verify that abort was called
|
||||||
|
mock_engine_client.abort.assert_called_once()
|
||||||
|
|
||||||
|
async def test_completion_stream_generator_cancelled_error(self):
|
||||||
|
"""Test completion_stream_generator with asyncio.CancelledError"""
|
||||||
|
# Mock the engine client and its dependencies
|
||||||
|
mock_engine_client = Mock()
|
||||||
|
mock_engine_client.semaphore = Mock()
|
||||||
|
mock_engine_client.semaphore.acquire = AsyncMock()
|
||||||
|
mock_engine_client.semaphore.release = Mock()
|
||||||
|
mock_engine_client.connection_manager = AsyncMock()
|
||||||
|
mock_engine_client.data_processor = Mock()
|
||||||
|
mock_engine_client.ori_vocab_size = 1000
|
||||||
|
mock_engine_client.check_model_weight_status.return_value = False
|
||||||
|
mock_engine_client.check_health.return_value = (True, "Healthy")
|
||||||
|
mock_engine_client.abort = AsyncMock()
|
||||||
|
|
||||||
|
# Mock the data_processor methods
|
||||||
|
mock_engine_client.data_processor.process_response_dict = Mock()
|
||||||
|
|
||||||
|
# Mock connection manager get_connection method
|
||||||
|
mock_dealer = Mock()
|
||||||
|
mock_dealer.write = Mock()
|
||||||
|
mock_response_queue = AsyncMock()
|
||||||
|
|
||||||
|
# Make response_queue.get raise CancelledError
|
||||||
|
mock_response_queue.get.side_effect = asyncio.CancelledError("Test cancellation")
|
||||||
|
mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)
|
||||||
|
|
||||||
|
# Create serving completion instance
|
||||||
|
serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360)
|
||||||
|
|
||||||
|
# Create mock request
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.prompt_logprobs = None
|
||||||
|
mock_request.logprobs = None
|
||||||
|
mock_request.include_draft_logprobs = False
|
||||||
|
mock_request.return_token_ids = True
|
||||||
|
mock_request.include_stop_str_in_output = False
|
||||||
|
mock_request.max_streaming_response_tokens = 1
|
||||||
|
mock_request.max_tokens = None
|
||||||
|
mock_request.stream_options = Mock()
|
||||||
|
mock_request.stream_options.include_usage = False
|
||||||
|
mock_request.n = 1
|
||||||
|
mock_request.echo = False
|
||||||
|
|
||||||
|
# Call the method and collect results
|
||||||
|
result_generator = serving_completion.completion_stream_generator(
|
||||||
|
request=mock_request,
|
||||||
|
num_choices=1,
|
||||||
|
request_id="test_request",
|
||||||
|
created_time=1234567890,
|
||||||
|
model_name="test_model",
|
||||||
|
prompt_batched_token_ids=[[1, 2, 3]],
|
||||||
|
prompt_tokens_list=["hello", "world"],
|
||||||
|
max_tokens_list=[100],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect results - should handle CancelledError gracefully
|
||||||
|
results = []
|
||||||
|
async for result in result_generator:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Verify that abort was called
|
||||||
|
mock_engine_client.abort.assert_called_once_with("test_request_0", 1)
|
||||||
|
|
||||||
|
# Verify that generator ends gracefully (should have [DONE] message)
|
||||||
|
self.assertTrue(len(results) > 0)
|
||||||
|
self.assertTrue(any("[DONE]" in result for result in results))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -0,0 +1,356 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from fastdeploy.entrypoints.openai.utils import with_cancellation
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithCancellation(unittest.TestCase):
|
||||||
|
"""Test cases for with_cancellation decorator"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures"""
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up test fixtures"""
|
||||||
|
self.loop.close()
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_normal_execution(self, mock_listen_disconnect):
|
||||||
|
"""Test that handler executes normally when no disconnect occurs"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create a mock that returns a coroutine that never completes
|
||||||
|
async def never_disconnect(request):
|
||||||
|
await asyncio.Future() # This will never complete
|
||||||
|
|
||||||
|
mock_listen_disconnect.side_effect = never_disconnect
|
||||||
|
|
||||||
|
# Setup handler function
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(self, raw_request):
|
||||||
|
await asyncio.sleep(0.01) # Simulate some work
|
||||||
|
return "test_result"
|
||||||
|
|
||||||
|
# Run the test - need to pass self as first arg
|
||||||
|
result = self.loop.run_until_complete(test_handler(None, mock_request))
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
self.assertEqual(result, "test_result")
|
||||||
|
mock_listen_disconnect.assert_called_once_with(mock_request)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_client_disconnect(self, mock_listen_disconnect):
|
||||||
|
"""Test that handler is cancelled when client disconnects"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create a future that will complete (disconnect) after a short delay
|
||||||
|
disconnect_future = asyncio.Future()
|
||||||
|
self.loop.call_later(0.05, disconnect_future.set_result, None)
|
||||||
|
mock_listen_disconnect.return_value = disconnect_future
|
||||||
|
|
||||||
|
# Setup handler function that takes longer than disconnect
|
||||||
|
handler_called = False
|
||||||
|
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(self, raw_request):
|
||||||
|
nonlocal handler_called
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.1) # Simulate work
|
||||||
|
handler_called = True
|
||||||
|
return "should_not_reach_here"
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Handler should be cancelled
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Run the test - need to pass self as first arg
|
||||||
|
result = self.loop.run_until_complete(test_handler(None, mock_request))
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
self.assertIsNone(result) # Should return None when cancelled
|
||||||
|
self.assertFalse(handler_called) # Handler should not complete
|
||||||
|
mock_listen_disconnect.assert_called_once_with(mock_request)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_handler_with_args_kwargs(self, mock_listen_disconnect):
|
||||||
|
"""Test that decorator properly handles handler with args and kwargs"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create a mock that returns a coroutine that never completes
|
||||||
|
async def never_disconnect(request):
|
||||||
|
await asyncio.Future() # This will never complete
|
||||||
|
|
||||||
|
mock_listen_disconnect.side_effect = never_disconnect
|
||||||
|
|
||||||
|
# Setup handler function with multiple arguments
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(arg1, arg2, raw_request, kwarg1=None, kwarg2=None):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1, "kwarg2": kwarg2, "request": raw_request}
|
||||||
|
|
||||||
|
# Run the test with both positional and keyword arguments
|
||||||
|
result = self.loop.run_until_complete(
|
||||||
|
test_handler("value1", "value2", mock_request, kwarg1="kwvalue1", kwarg2="kwvalue2")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
expected = {
|
||||||
|
"arg1": "value1",
|
||||||
|
"arg2": "value2",
|
||||||
|
"kwarg1": "kwvalue1",
|
||||||
|
"kwarg2": "kwvalue2",
|
||||||
|
"request": mock_request,
|
||||||
|
}
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_handler_returns_streaming_response(self, mock_listen_disconnect):
|
||||||
|
"""Test that decorator handles StreamingResponse correctly"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create a mock that returns a coroutine that never completes
|
||||||
|
async def never_disconnect(request):
|
||||||
|
await asyncio.Future() # This will never complete
|
||||||
|
|
||||||
|
mock_listen_disconnect.side_effect = never_disconnect
|
||||||
|
|
||||||
|
# Setup handler that returns StreamingResponse
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(self, raw_request):
|
||||||
|
async def generate():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
|
||||||
|
return StreamingResponse(generate())
|
||||||
|
|
||||||
|
# Run the test - need to pass self as first arg
|
||||||
|
result = self.loop.run_until_complete(test_handler(None, mock_request))
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
self.assertIsInstance(result, StreamingResponse)
|
||||||
|
mock_listen_disconnect.assert_called_once_with(mock_request)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_handler_exception_propagation(self, mock_listen_disconnect):
|
||||||
|
"""Test that exceptions from handler are properly propagated"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create a mock that returns a coroutine that never completes
|
||||||
|
async def never_disconnect(request):
|
||||||
|
await asyncio.Future() # This will never complete
|
||||||
|
|
||||||
|
mock_listen_disconnect.side_effect = never_disconnect
|
||||||
|
|
||||||
|
# Setup handler that raises an exception
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(self, raw_request):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
raise ValueError("Test exception")
|
||||||
|
|
||||||
|
# Run the test and expect exception - need to pass self as first arg
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
self.loop.run_until_complete(test_handler(None, mock_request))
|
||||||
|
|
||||||
|
self.assertEqual(str(context.exception), "Test exception")
|
||||||
|
mock_listen_disconnect.assert_called_once_with(mock_request)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
|
||||||
|
def test_concurrent_cancellation_and_completion(self, mock_listen_disconnect):
|
||||||
|
"""Test edge case where cancellation and completion happen simultaneously"""
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
|
||||||
|
# Create futures that complete at roughly the same time
|
||||||
|
disconnect_future = asyncio.Future()
|
||||||
|
handler_future = asyncio.Future()
|
||||||
|
|
||||||
|
# Set both futures to complete almost simultaneously
|
||||||
|
self.loop.call_later(0.05, lambda: disconnect_future.set_result(None))
|
||||||
|
self.loop.call_later(0.05, lambda: handler_future.set_result("completed"))
|
||||||
|
|
||||||
|
mock_listen_disconnect.return_value = disconnect_future
|
||||||
|
|
||||||
|
@with_cancellation
|
||||||
|
async def test_handler(self, raw_request):
|
||||||
|
return await handler_future
|
||||||
|
|
||||||
|
# Run the test - need to pass self as first arg
|
||||||
|
result = self.loop.run_until_complete(test_handler(None, mock_request))
|
||||||
|
|
||||||
|
# The result depends on which task completes first
|
||||||
|
# This test ensures the decorator handles this edge case gracefully
|
||||||
|
self.assertIn(result, [None, "completed"])
|
||||||
|
|
||||||
|
def test_wrapper_preserves_function_metadata(self):
|
||||||
|
"""Test that the wrapper preserves the original function's metadata"""
|
||||||
|
|
||||||
|
def original_handler(raw_request):
|
||||||
|
"""Original handler docstring"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Apply decorator
|
||||||
|
decorated_handler = with_cancellation(original_handler)
|
||||||
|
|
||||||
|
# Verify metadata is preserved
|
||||||
|
self.assertEqual(decorated_handler.__name__, "original_handler")
|
||||||
|
self.assertEqual(decorated_handler.__doc__, "Original handler docstring")
|
||||||
|
self.assertTrue(hasattr(decorated_handler, "__wrapped__"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestListenForDisconnect(unittest.TestCase):
|
||||||
|
"""Test cases for listen_for_disconnect function"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures"""
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up test fixtures"""
|
||||||
|
self.loop.close()
|
||||||
|
|
||||||
|
def test_listen_for_disconnect_normal_flow(self):
|
||||||
|
"""Test that listen_for_disconnect waits for disconnect message"""
|
||||||
|
from fastdeploy.entrypoints.openai.utils import listen_for_disconnect
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock()
|
||||||
|
receive_call_count = 0
|
||||||
|
|
||||||
|
# Create mock messages - normal messages followed by disconnect
|
||||||
|
messages = [
|
||||||
|
{"type": "http.request", "body": b"some data"},
|
||||||
|
{"type": "http.request", "body": b"more data"},
|
||||||
|
{"type": "http.disconnect"}, # This should break the loop
|
||||||
|
]
|
||||||
|
|
||||||
|
# Setup receive method to return messages sequentially
|
||||||
|
receive_iter = iter(messages)
|
||||||
|
|
||||||
|
async def mock_receive():
|
||||||
|
nonlocal receive_call_count
|
||||||
|
receive_call_count += 1
|
||||||
|
try:
|
||||||
|
return next(receive_iter)
|
||||||
|
except StopIteration:
|
||||||
|
# After all messages are consumed, block forever instead of returning another message
|
||||||
|
await asyncio.Future() # Never completes
|
||||||
|
|
||||||
|
mock_request.receive = mock_receive
|
||||||
|
|
||||||
|
# Run the function in the event loop
|
||||||
|
self.loop.run_until_complete(listen_for_disconnect(mock_request))
|
||||||
|
|
||||||
|
# Verify that receive was called multiple times
|
||||||
|
self.assertGreaterEqual(receive_call_count, 3)
|
||||||
|
|
||||||
|
def test_listen_for_disconnect_immediate_disconnect(self):
|
||||||
|
"""Test that listen_for_disconnect returns immediately on disconnect"""
|
||||||
|
from fastdeploy.entrypoints.openai.utils import listen_for_disconnect
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock()
|
||||||
|
receive_called = False
|
||||||
|
|
||||||
|
# Setup receive to return disconnect immediately (as a coroutine)
|
||||||
|
async def mock_receive():
|
||||||
|
nonlocal receive_called
|
||||||
|
receive_called = True
|
||||||
|
return {"type": "http.disconnect"}
|
||||||
|
|
||||||
|
mock_request.receive = mock_receive
|
||||||
|
|
||||||
|
# Run the function in the event loop
|
||||||
|
self.loop.run_until_complete(listen_for_disconnect(mock_request))
|
||||||
|
|
||||||
|
# Verify that receive was called at least once
|
||||||
|
self.assertTrue(receive_called)
|
||||||
|
|
||||||
|
def test_listen_for_disconnect_timeout(self):
|
||||||
|
"""Test that listen_for_disconnect can be cancelled with timeout"""
|
||||||
|
from fastdeploy.entrypoints.openai.utils import listen_for_disconnect
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock()
|
||||||
|
|
||||||
|
# Setup receive to never return disconnect
|
||||||
|
async def mock_receive():
|
||||||
|
await asyncio.Future() # Never completes
|
||||||
|
|
||||||
|
mock_request.receive = mock_receive
|
||||||
|
|
||||||
|
# Run with timeout to test cancellation
|
||||||
|
with self.assertRaises(asyncio.TimeoutError):
|
||||||
|
self.loop.run_until_complete(asyncio.wait_for(listen_for_disconnect(mock_request), timeout=0.01))
|
||||||
|
|
||||||
|
def test_listen_for_disconnect_various_message_types(self):
|
||||||
|
"""Test that listen_for_disconnect ignores non-disconnect messages"""
|
||||||
|
from fastdeploy.entrypoints.openai.utils import listen_for_disconnect
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = MagicMock()
|
||||||
|
receive_call_count = 0
|
||||||
|
|
||||||
|
# Create various non-disconnect messages
|
||||||
|
messages = [
|
||||||
|
{"type": "http.request", "body": b"data"},
|
||||||
|
{"type": "http.response", "status": 200},
|
||||||
|
{"type": "websocket.connect"},
|
||||||
|
{"type": "http.request", "body": b"more data"},
|
||||||
|
{"type": "http.disconnect"}, # Final disconnect
|
||||||
|
]
|
||||||
|
|
||||||
|
# Setup receive method
|
||||||
|
receive_iter = iter(messages)
|
||||||
|
|
||||||
|
async def mock_receive():
|
||||||
|
nonlocal receive_call_count
|
||||||
|
receive_call_count += 1
|
||||||
|
try:
|
||||||
|
return next(receive_iter)
|
||||||
|
except StopIteration:
|
||||||
|
await asyncio.Future()
|
||||||
|
|
||||||
|
mock_request.receive = mock_receive
|
||||||
|
|
||||||
|
# Run the function in the event loop
|
||||||
|
self.loop.run_until_complete(listen_for_disconnect(mock_request))
|
||||||
|
|
||||||
|
# Verify all messages were processed
|
||||||
|
self.assertEqual(receive_call_count, 5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -0,0 +1,257 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from fastdeploy.engine.request import RequestStatus
|
||||||
|
from fastdeploy.entrypoints.engine_client import EngineClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineClientAbort(unittest.TestCase):
|
||||||
|
"""Test cases for EngineClient.abort method"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures"""
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
|
# Create a mock FDConfig
|
||||||
|
self.mock_fd_config = MagicMock()
|
||||||
|
self.mock_fd_config.parallel_config.tensor_parallel_size = 1
|
||||||
|
self.mock_fd_config.model_config.enable_mm = False
|
||||||
|
self.mock_fd_config.model_config.max_model_len = 2048
|
||||||
|
self.mock_fd_config.model_config.enable_logprob = True
|
||||||
|
self.mock_fd_config.cache_config.enable_prefix_caching = False
|
||||||
|
self.mock_fd_config.scheduler_config.splitwise_role = "mixed"
|
||||||
|
self.mock_fd_config.limit_mm_per_prompt = 5
|
||||||
|
self.mock_fd_config.eplb_config.enable_eplb = False
|
||||||
|
self.mock_fd_config.structured_outputs_config.reasoning_parser = None
|
||||||
|
self.mock_fd_config.mm_processor_kwargs = {}
|
||||||
|
self.mock_fd_config.tool_parser = None
|
||||||
|
self.mock_fd_config.cache_config.max_processor_cache = 0
|
||||||
|
|
||||||
|
# Create EngineClient instance
|
||||||
|
with patch("fastdeploy.entrypoints.engine_client.InputPreprocessor"):
|
||||||
|
with patch("fastdeploy.entrypoints.engine_client.IPCSignal"):
|
||||||
|
with patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore"):
|
||||||
|
with patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager"):
|
||||||
|
with patch("fastdeploy.entrypoints.engine_client.FileLock"):
|
||||||
|
self.engine_client = EngineClient(
|
||||||
|
pid=12345, port=8000, fd_config=self.mock_fd_config, workers=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up test fixtures"""
|
||||||
|
self.loop.close()
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_single_request(self, mock_send_task):
|
||||||
|
"""Test aborting a single request"""
|
||||||
|
request_id = "test_request"
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=1))
|
||||||
|
|
||||||
|
# Verify _send_task was called with correct data
|
||||||
|
expected_data = {
|
||||||
|
"request_id": "test_request_0",
|
||||||
|
"status": RequestStatus.ABORT.value,
|
||||||
|
}
|
||||||
|
mock_send_task.assert_called_once_with(expected_data)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_multiple_requests(self, mock_send_task):
|
||||||
|
"""Test aborting multiple requests"""
|
||||||
|
request_id = "test_request"
|
||||||
|
n = 3
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
|
||||||
|
|
||||||
|
# Verify _send_task was called correct number of times
|
||||||
|
self.assertEqual(mock_send_task.call_count, n)
|
||||||
|
|
||||||
|
# Verify each call had correct request_id
|
||||||
|
expected_calls = [
|
||||||
|
({"request_id": "test_request_0", "status": RequestStatus.ABORT.value},),
|
||||||
|
({"request_id": "test_request_1", "status": RequestStatus.ABORT.value},),
|
||||||
|
({"request_id": "test_request_2", "status": RequestStatus.ABORT.value},),
|
||||||
|
]
|
||||||
|
|
||||||
|
actual_calls = [call.args for call in mock_send_task.call_args_list]
|
||||||
|
self.assertEqual(actual_calls, expected_calls)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_with_existing_suffix(self, mock_send_task):
|
||||||
|
"""Test aborting request that already has _number suffix"""
|
||||||
|
request_id = "test_request_123_2"
|
||||||
|
n = 2
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
|
||||||
|
|
||||||
|
# Verify _send_task was called correct number of times
|
||||||
|
self.assertEqual(mock_send_task.call_count, n)
|
||||||
|
|
||||||
|
# Verify each call had correct request_id (should use prefix before existing suffix)
|
||||||
|
expected_calls = [
|
||||||
|
({"request_id": "test_request_123_0", "status": RequestStatus.ABORT.value},),
|
||||||
|
({"request_id": "test_request_123_1", "status": RequestStatus.ABORT.value},),
|
||||||
|
]
|
||||||
|
|
||||||
|
actual_calls = [call.args for call in mock_send_task.call_args_list]
|
||||||
|
self.assertEqual(actual_calls, expected_calls)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_with_no_suffix(self, mock_send_task):
|
||||||
|
"""Test aborting request without _number suffix"""
|
||||||
|
request_id = "test_request_without_suffix"
|
||||||
|
n = 2
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
|
||||||
|
|
||||||
|
# Verify _send_task was called correct number of times
|
||||||
|
self.assertEqual(mock_send_task.call_count, n)
|
||||||
|
|
||||||
|
# Verify each call had correct request_id (should use full request_id as prefix)
|
||||||
|
expected_calls = [
|
||||||
|
({"request_id": "test_request_without_suffix_0", "status": RequestStatus.ABORT.value},),
|
||||||
|
({"request_id": "test_request_without_suffix_1", "status": RequestStatus.ABORT.value},),
|
||||||
|
]
|
||||||
|
|
||||||
|
actual_calls = [call.args for call in mock_send_task.call_args_list]
|
||||||
|
self.assertEqual(actual_calls, expected_calls)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_with_zero_n(self, mock_send_task):
|
||||||
|
"""Test aborting with n=0 should not send any requests"""
|
||||||
|
request_id = "test_request_123"
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=0))
|
||||||
|
|
||||||
|
# Verify _send_task was not called
|
||||||
|
mock_send_task.assert_not_called()
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_with_negative_n(self, mock_send_task):
|
||||||
|
"""Test aborting with negative n should not send any requests"""
|
||||||
|
request_id = "test_request_123"
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=-1))
|
||||||
|
|
||||||
|
# Verify _send_task was not called
|
||||||
|
mock_send_task.assert_not_called()
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", False)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_when_feature_disabled(self, mock_send_task):
|
||||||
|
"""Test abort when FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE is False"""
|
||||||
|
request_id = "test_request_123"
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=1))
|
||||||
|
|
||||||
|
# Verify _send_task was not called
|
||||||
|
mock_send_task.assert_not_called()
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_request_id_regex_parsing(self, mock_send_task):
|
||||||
|
"""Test that request_id regex parsing works correctly for various formats"""
|
||||||
|
test_cases = [
|
||||||
|
("simple_request", "simple_request"),
|
||||||
|
("request_with_underscores", "request_with_underscores"),
|
||||||
|
("request_123", "request"),
|
||||||
|
("request_123_456", "request_123"),
|
||||||
|
("request_0", "request"),
|
||||||
|
("complex_name_123_456_789", "complex_name_123_456"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for input_request_id, expected_prefix in test_cases:
|
||||||
|
with self.subTest(input_request_id=input_request_id):
|
||||||
|
mock_send_task.reset_mock()
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(input_request_id, n=1))
|
||||||
|
|
||||||
|
# Verify _send_task was called with correct prefix
|
||||||
|
expected_data = {
|
||||||
|
"request_id": f"{expected_prefix}_0",
|
||||||
|
"status": RequestStatus.ABORT.value,
|
||||||
|
}
|
||||||
|
mock_send_task.assert_called_once_with(expected_data)
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.api_server_logger")
|
||||||
|
@patch.object(EngineClient, "_send_task")
|
||||||
|
def test_abort_logging(self, mock_send_task, mock_logger):
|
||||||
|
"""Test that abort method logs correctly"""
|
||||||
|
request_id = "test_request"
|
||||||
|
n = 2
|
||||||
|
|
||||||
|
# Run the abort method
|
||||||
|
self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
|
||||||
|
|
||||||
|
# Verify info log was called twice
|
||||||
|
self.assertEqual(mock_logger.info.call_count, 2)
|
||||||
|
|
||||||
|
# Verify the first log message (abort start)
|
||||||
|
first_call = mock_logger.info.call_args_list[0]
|
||||||
|
self.assertEqual(first_call[0][0], "abort request_id:test_request")
|
||||||
|
|
||||||
|
# Verify the second log message (abort completion with request IDs)
|
||||||
|
second_call = mock_logger.info.call_args_list[1]
|
||||||
|
expected_log_message = "Aborted request(s) %s."
|
||||||
|
self.assertEqual(second_call[0][0], expected_log_message)
|
||||||
|
self.assertEqual(second_call[0][1], "test_request_0,test_request_1")
|
||||||
|
|
||||||
|
@patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
@patch("fastdeploy.entrypoints.engine_client.api_server_logger")
@patch.object(EngineClient, "_send_task")
def test_abort_warning_logging_for_invalid_format(self, mock_send_task, mock_logger):
    """Check that ``abort`` logs a warning for a request_id lacking a ``_<number>`` suffix.

    Only one call is made. An earlier warm-up call whose log records were
    discarded before any assertion added no coverage and has been removed.
    """
    # This id does not end with "_<number>", which is the shape the test
    # expects abort() to report as a format error.
    request_id_no_suffix = "just_a_string"

    self.loop.run_until_complete(self.engine_client.abort(request_id_no_suffix, n=1))

    # The format problem must have been reported at warning level.
    mock_logger.warning.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import Mock
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
@@ -34,6 +34,19 @@ class MockConfig:
|
|||||||
|
|
||||||
class SpeculativeConfig:
|
class SpeculativeConfig:
|
||||||
method = None
|
method = None
|
||||||
|
num_speculative_tokens = 1
|
||||||
|
num_model_steps = 1
|
||||||
|
max_candidate_len = 5
|
||||||
|
verify_window = 2
|
||||||
|
max_ngram_size = 5
|
||||||
|
min_ngram_size = 2
|
||||||
|
model = None
|
||||||
|
quantization = None
|
||||||
|
num_gpu_block_expand_ratio = 1
|
||||||
|
model_type = "main"
|
||||||
|
benchmark_mode = False
|
||||||
|
num_extra_cache_layer = 0
|
||||||
|
mtp_strategy = "default"
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
enable_logprob = False
|
enable_logprob = False
|
||||||
@@ -91,6 +104,8 @@ class MockResourceManager:
|
|||||||
self.stop_flags = [False]
|
self.stop_flags = [False]
|
||||||
self.tasks_list = [MockTask()]
|
self.tasks_list = [MockTask()]
|
||||||
self.to_be_rescheduled_request_id_set = set()
|
self.to_be_rescheduled_request_id_set = set()
|
||||||
|
self.abort_req_ids_set = set()
|
||||||
|
self.req_dict = {}
|
||||||
|
|
||||||
def info(self):
|
def info(self):
|
||||||
return "Mock resource manager info"
|
return "Mock resource manager info"
|
||||||
@@ -241,6 +256,147 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
|
|||||||
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
|
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
|
||||||
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
|
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
|
||||||
|
|
||||||
|
def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
    """Aborted task flagged as preempted is not recycled in speculative decoding mode.

    The first task is marked aborted and its accept_num slot carries the
    preemption marker (-9). The test expects the "sync preemption ... done."
    log line, no ``_recycle_resources`` call for that task, and the id to
    remain in ``abort_req_ids_set``.
    """
    processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True)
    manager = processor.resource_manager

    # Mark the pre-existing task as aborted and register it under batch_id 0.
    task_id = "test_aborted_request"
    aborted_task = manager.tasks_list[0]
    aborted_task.request_id = task_id
    manager.abort_req_ids_set = {task_id}
    manager.req_dict[task_id] = 0

    # Replace the recycler with a mock so its calls can be inspected.
    processor._recycle_resources = MagicMock()

    # Leading rows of output_tokens, written in order.
    header_rows = (
        (0, 2),  # stop_flag
        (1, 3),  # mtype target = 3
        (2, 2),  # batch = 2, so batch_id 0 is below batch_size - 1
        (3, -9),  # accept_num of the first task = PREEMPTED_TOKEN_ID
        (4, 1),  # presumably accept_num of the second task - confirm against the layout
    )
    for row, value in header_rows:
        processor.output_tokens[row, 0].set_tensor(paddle.to_tensor(value))

    # A second, non-aborted task fills the batch of two.
    other_task = MockTask()
    other_task.request_id = "test_request_2"
    manager.tasks_list = [aborted_task, other_task]
    manager.stop_flags = [False, False]
    processor.tokens_counter[task_id] = 0
    processor.tokens_counter[other_task.request_id] = 0

    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()

        # The preemption marker is expected to be handled by the
        # "sync preemption" branch rather than by abort recycling.
        mock_logger.info.assert_any_call(f"sync preemption for request_id {task_id} done.")

        # Recycling may happen for other tasks, but never for the aborted one.
        for recycle_call in processor._recycle_resources.call_args_list:
            self.assertNotEqual(
                recycle_call.args[0], task_id, f"_recycle_resources should not be called for aborted task {task_id}"
            )

        # The abort marker must still be pending.
        self.assertIn(task_id, manager.abort_req_ids_set)
|
||||||
|
|
||||||
|
def test_process_batch_output_aborted_task_negative_token_normal_mode(self):
    """Aborted task receiving a negative token is recycled in normal (non-speculative) mode.

    A batch of two is used so the aborted task sits at batch_id 0, below
    batch_size - 1. The test expects the "received negative token" log line,
    exactly one ``_recycle_resources(task_id, 0, task)`` call, and the id to
    be removed from ``abort_req_ids_set``.
    """
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)
    manager = processor.resource_manager

    # Mark the pre-existing task as aborted and register it under batch_id 0.
    task_id = "test_aborted_request"
    aborted_task = manager.tasks_list[0]
    aborted_task.request_id = task_id
    manager.abort_req_ids_set = {task_id}
    manager.req_dict[task_id] = 0

    # Fresh token buffer, every slot pre-filled with 2.
    processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")

    # Replace the recycler with a mock so its calls can be inspected.
    processor._recycle_resources = MagicMock()

    token_rows = (
        (1, 2),  # batch = 2
        (2, -1),  # negative token for the first task (batch_id 0)
        (3, 100),  # positive token for the second task (batch_id 1)
    )
    for row, value in token_rows:
        processor.output_tokens[row, 0].set_tensor(paddle.to_tensor(value))

    # A second, non-aborted task fills the batch of two.
    other_task = MockTask()
    other_task.request_id = "test_request_2"
    manager.tasks_list = [aborted_task, other_task]
    manager.stop_flags = [False, False]
    processor.tokens_counter[task_id] = 0
    processor.tokens_counter[other_task.request_id] = 0

    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()

        # The abort path must log, recycle once, and clear the abort marker.
        mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
        processor._recycle_resources.assert_called_once_with(task_id, 0, aborted_task)
        self.assertNotIn(task_id, manager.abort_req_ids_set)
|
||||||
|
|
||||||
|
def test_process_batch_output_non_aborted_task_negative_token(self):
    """Non-aborted task receiving a negative token must not be recycled.

    ``abort_req_ids_set`` is empty, so the single task in the batch is not
    aborted. After processing a negative token the test expects
    ``_recycle_resources`` to be untouched.
    """
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)

    # Set up the task as not aborted.
    task_id = "test_normal_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = set()

    # Replace the recycler with a mock so its calls can be inspected.
    processor._recycle_resources = MagicMock()

    # batch = 1
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(1))
    # negative token for the only task
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))

    # The logger is patched only to replace the module logger during the
    # call; nothing is asserted on it, so the mock is not bound to a name.
    with (
        patch("fastdeploy.output.token_processor.llm_logger"),
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()

        # A non-aborted task with a negative token is simply skipped.
        processor._recycle_resources.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2, buffer=False)
|
unittest.main(verbosity=2, buffer=False)
|
||||||
|
|||||||
@@ -153,6 +153,60 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
||||||
|
def test_process_batch_output_use_zmq_aborted_task_negative_token(self):
    """Aborted task receiving a negative token over the ZMQ path is recycled.

    The test expects the "received negative token" log line, exactly one
    ``_recycle_resources(task_id, 0, task)`` call, the id removed from
    ``abort_req_ids_set``, and a single finished result carrying error
    code 499 and an "aborted" message.
    """
    # Mark the fixture task as aborted.
    task_id = "test_aborted_request"
    self.task_mock.request_id = task_id
    self.processor.resource_manager.abort_req_ids_set = {task_id}

    # Stream payload for batch slot 0 whose last token is negative.
    stream_data = MagicMock(tokens=np.array([1, 2, -1]), batch_id=0)

    # Replace the recycler with a mock so its calls can be inspected.
    self.processor._recycle_resources = MagicMock()

    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        result = self.processor._process_batch_output_use_zmq([stream_data])

        # The abort path must log, recycle once, and clear the abort marker.
        mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
        self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock)
        self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set)

        # Exactly one result is returned, describing the abort.
        self.assertEqual(len(result), 1)
        abort_result = result[0]
        self.assertEqual(abort_result.finished, True)
        self.assertEqual(abort_result.error_code, 499)
        self.assertIn("aborted", abort_result.error_msg.lower())
|
||||||
|
|
||||||
|
def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self):
    """Non-aborted task receiving a negative token over the ZMQ path is not recycled."""
    # Fixture task is not aborted: the abort set is empty.
    task_id = "test_normal_request"
    self.task_mock.request_id = task_id
    self.processor.resource_manager.abort_req_ids_set = set()

    # Stream payload for batch slot 0 whose last token is negative.
    stream_data = MagicMock(tokens=np.array([1, 2, -1]), batch_id=0)

    # Replace the recycler with a mock so its calls can be inspected.
    self.processor._recycle_resources = MagicMock()

    self.processor._process_batch_output_use_zmq([stream_data])

    # Neither recycling nor info-level logging is expected.
    self.processor._recycle_resources.assert_not_called()
    # NOTE(review): the aborted-task counterpart observes logging by patching
    # the module-level fastdeploy.output.token_processor.llm_logger, whereas
    # this assertion reads an attribute on the processor instance. Confirm
    # the fixture installs a mock there and that it sees the same logger.
    self.processor.llm_logger.info.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class _DummyResourceManager:
|
|||||||
self.req_dict = {}
|
self.req_dict = {}
|
||||||
self.requests = {}
|
self.requests = {}
|
||||||
self.to_be_rescheduled_request_id_set = set()
|
self.to_be_rescheduled_request_id_set = set()
|
||||||
|
self.abort_req_ids_set = set()
|
||||||
self.recycled = []
|
self.recycled = []
|
||||||
self.cached_tasks = []
|
self.cached_tasks = []
|
||||||
self.cleared = False
|
self.cleared = False
|
||||||
|
|||||||
Reference in New Issue
Block a user