[Feature] Support stopping the inference for the corresponding request in the online service after a disconnection request. (#5320)

* request disconnect

* request disconnect

* fix bug

* fix bug--amend

---------

Co-authored-by: root <root@yq01-sys-rpm26xc1knu.yq01.baidu.com>
This commit is contained in:
qwes5s5
2026-01-16 11:46:13 +08:00
committed by GitHub
parent 8f035101ad
commit b2a2e11551
25 changed files with 1339 additions and 63 deletions
+157 -1
View File
@@ -17,7 +17,7 @@
import random
import time
import unittest
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock, patch
import paddle
@@ -34,6 +34,19 @@ class MockConfig:
class SpeculativeConfig:
    """Minimal mock of the speculative-decoding config consumed by the token processor tests."""

    # No speculative method by default ("ngram", "mtp", ... are set per-test as needed).
    method = None
    # One speculative token / one model step keeps the mock decode loop minimal.
    num_speculative_tokens = 1
    num_model_steps = 1
    # Verification window parameters — values mirror the production defaults.
    # NOTE(review): semantics of candidate/verify/ngram bounds come from the real
    # SpeculativeConfig; confirm against fastdeploy's config module.
    max_candidate_len = 5
    verify_window = 2
    max_ngram_size = 5
    min_ngram_size = 2
    model = None
    quantization = None
    num_gpu_block_expand_ratio = 1
    # "main" marks this as the target model (as opposed to a draft/MTP model).
    model_type = "main"
    benchmark_mode = False
    num_extra_cache_layer = 0
    mtp_strategy = "default"
class ModelConfig:
    """Minimal mock of the model config: logprob output disabled unless a test overrides it."""

    enable_logprob = False
@@ -91,6 +104,8 @@ class MockResourceManager:
self.stop_flags = [False]
self.tasks_list = [MockTask()]
self.to_be_rescheduled_request_id_set = set()
self.abort_req_ids_set = set()
self.req_dict = {}
def info(self):
    """Return a fixed placeholder description of this mock resource manager."""
    return "Mock resource manager info"
@@ -241,6 +256,147 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
    """Aborted request receiving a preemption token in speculative-decoding mode.

    In speculative mode the token processor treats ``accept_num == PREEMPTED_TOKEN_ID``
    as a synchronous preemption: it logs and continues, and must NOT recycle the
    aborted request's resources (the request stays in ``abort_req_ids_set``).
    """
    processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True)
    # Mark the first task as aborted by the client.
    task_id = "test_aborted_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = {task_id}
    # Registering the request in req_dict prevents _recycle_aborted_task from
    # reclaiming it before _process_batch_output gets to run (batch_id = 0).
    processor.resource_manager.req_dict[task_id] = 0
    # Spy on resource recycling so we can assert which requests it is called for.
    processor._recycle_resources = MagicMock()
    # Populate the speculative output buffer.
    # Layout (per original comments — confirm against token_processor's reader):
    # row 0: stop flag
    processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
    # row 1: message type (3 = target model output)
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
    # row 2: batch size = 2, so batch_id 0 is not the last slot
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
    # rows 3+: accept_num per task; PREEMPTED_TOKEN_ID (-9) triggers the abort path
    processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(-9))
    processor.output_tokens[4, 0].set_tensor(paddle.to_tensor(1))
    # Second, healthy request occupies the other batch slot.
    task2 = MockTask()
    task2.request_id = "test_request_2"
    processor.resource_manager.tasks_list = [task, task2]
    processor.resource_manager.stop_flags = [False, False]
    # Both requests start with zero generated tokens.
    processor.tokens_counter[task_id] = 0
    processor.tokens_counter[task2.request_id] = 0
    # Capture log calls and force the legacy (non-v1) KV-cache scheduler path.
    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()
    # Speculative mode handles PREEMPTED_TOKEN_ID as a sync preemption:
    # it logs and continues instead of recycling the aborted request.
    mock_logger.info.assert_any_call(f"sync preemption for request_id {task_id} done.")
    # _recycle_resources must never be invoked for the aborted request
    # (it may legitimately fire for other tasks, e.g. on EOS).
    for call in processor._recycle_resources.call_args_list:
        self.assertNotEqual(
            call[0][0], task_id, f"_recycle_resources should not be called for aborted task {task_id}"
        )
    # The request remains pending in the abort set.
    self.assertIn(task_id, processor.resource_manager.abort_req_ids_set)
def test_process_batch_output_aborted_task_negative_token_normal_mode(self):
    """An aborted request that receives a negative token is recycled in normal mode."""
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)

    # The first batch slot holds a request the client has already aborted.
    aborted_id = "test_aborted_request"
    aborted_task = processor.resource_manager.tasks_list[0]
    aborted_task.request_id = aborted_id
    processor.resource_manager.abort_req_ids_set = {aborted_id}
    # Registering the request in req_dict keeps _recycle_aborted_task from
    # reclaiming it before _process_batch_output runs (batch_id = 0).
    processor.resource_manager.req_dict[aborted_id] = 0

    # Fresh token buffer so no stale values leak into this test.
    processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")

    # Spy on resource recycling.
    processor._recycle_resources = MagicMock()

    # Batch size = 2, so batch_id 0 is not the last slot ...
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2))
    # ... the aborted request receives a negative token ...
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
    # ... and the healthy request a normal token.
    processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100))

    # Second, healthy request occupies the other batch slot.
    healthy_task = MockTask()
    healthy_task.request_id = "test_request_2"
    processor.resource_manager.tasks_list = [aborted_task, healthy_task]
    processor.resource_manager.stop_flags = [False, False]
    processor.tokens_counter[aborted_id] = 0
    processor.tokens_counter[healthy_task.request_id] = 0

    # Capture log calls and force the legacy (non-v1) KV-cache scheduler path.
    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()

    # The negative token on an aborted request must trigger recycling:
    # log message, one recycle call, and removal from the abort set.
    mock_logger.info.assert_any_call(f"Aborted task {aborted_id} received negative token. Recycling.")
    processor._recycle_resources.assert_called_once_with(aborted_id, 0, aborted_task)
    self.assertNotIn(aborted_id, processor.resource_manager.abort_req_ids_set)
def test_process_batch_output_non_aborted_task_negative_token(self):
    """A non-aborted request receiving a negative token must NOT be recycled.

    Negative tokens on a healthy request are simply skipped by the token
    processor — no abort log, no resource recycling.
    """
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)
    # The request is healthy: the abort set is empty.
    task_id = "test_normal_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = set()
    # Spy on resource recycling so we can assert it is never invoked.
    processor._recycle_resources = MagicMock()
    # Batch size = 1 ...
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(1))
    # ... and the single request receives a negative token.
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
    # Silence the logger and force the legacy (non-v1) KV-cache scheduler path.
    with (
        patch("fastdeploy.output.token_processor.llm_logger"),
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        processor._process_batch_output()
    # A negative token for a non-aborted request is ignored: the processor
    # continues without logging an abort or recycling resources.
    # (Removed a leftover debug print of the logger mock and its now-unused binding.)
    processor._recycle_resources.assert_not_called()
if __name__ == "__main__":
    # Run the suite directly; buffer=False keeps test stdout/stderr visible.
    unittest.main(verbosity=2, buffer=False)