# Source: FastDeploy/tests/output/test_process_batch_output_use_zmq.py
# Retrieved: 2026-03-17 14:06:40 +08:00
# 213 lines, 8.3 KiB, Python
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.worker.output import LogprobsLists
class TestTokenProcessorLogprobs(unittest.TestCase):
    """Tests for ``TokenProcessor._process_batch_output_use_zmq``.

    Covers logprobs / prompt_logprobs parsing (success and failure paths),
    the per-slot stop flag, and the abort-recycling path triggered by a
    negative (preempted) token id.
    """

    def setUp(self):
        """Wire a TokenProcessor to mocks so no real engine, queue, or
        prometheus metrics backend is needed."""
        self.cfg = MagicMock()
        self.cfg.model_config.enable_logprob = True
        self.cfg.speculative_config.method = None
        self.cfg.parallel_config.local_data_parallel_id = 0
        self.cached_generated_tokens = MagicMock()
        self.engine_worker_queue = MagicMock()
        self.split_connector = MagicMock()
        self.processor = TokenProcessor(
            self.cfg, self.cached_generated_tokens, self.engine_worker_queue, self.split_connector
        )
        # Mock resource manager; slot 0 is active (stop flag False).
        self.processor.resource_manager = MagicMock()
        self.processor.resource_manager.stop_flags = [False]
        # Create a proper task mock with time attributes so metric
        # bookkeeping inside the processor does not blow up.
        self.task_mock = MagicMock()
        self.task_mock.request_id = "test_request"
        self.task_mock.pooling_params = None
        self.task_mock.messages = None
        self.task_mock.disaggregate_info = None
        self.task_mock.eos_token_ids = [2]
        self.task_mock.ic_req_data = {}
        self.task_mock.prompt_token_ids_len = 0
        now = time.time()
        self.task_mock.metrics = RequestMetrics(
            arrival_time=now,
            preprocess_start_time=now - 0.2,
            preprocess_end_time=now - 0.1,
            scheduler_recv_req_time=now + 0.1,
            inference_start_time=now + 0.2,
        )
        self.processor.resource_manager.tasks_list = [self.task_mock]
        # Mock logger attached to the processor instance.
        # NOTE(review): the abort test patches the *module-level* llm_logger
        # instead — presumably that is what the implementation actually uses;
        # confirm which logger _process_batch_output_use_zmq writes to.
        self.processor.llm_logger = MagicMock()
        # Mock metrics to avoid prometheus dependency issues.
        self.processor.main_process_metrics = MagicMock()
        self.processor._recycle_resources = MagicMock()
        # Mock _process_per_token (also touches prometheus) and give it a
        # minimal well-formed RequestOutput to return.
        self.processor._process_per_token = MagicMock()
        self.processor._process_per_token.return_value = RequestOutput(
            request_id="test_request",
            outputs=CompletionOutput(
                index=0,
                send_idx=0,
                token_ids=[],
                draft_token_ids=[],
            ),
            finished=False,
            metrics=MagicMock(),
        )

    def test_process_logprobs_success(self):
        """Test successful logprobs parsing"""
        stream_data = MagicMock()
        logprobs = MagicMock()
        logprobs.tolists.return_value = LogprobsLists(
            logprobs=[[0.5]], logprob_token_ids=[[1]], sampled_token_ranks=[0]
        )
        stream_data.logprobs = logprobs
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.processor.llm_logger.warning.assert_not_called()

    def test_process_logprobs_failure(self):
        """Test failed logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = MagicMock()
        stream_data.logprobs.tolists.side_effect = Exception("Test error")
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        with patch.object(self.processor.llm_logger, "warning"):
            result = self.processor._process_batch_output_use_zmq([stream_data])
            # Parsing failure must not drop the output; logprob is just left unset.
            self.assertEqual(len(result), 1)
            self.assertIsNone(result[0].outputs.logprob)

    def test_process_prompt_logprobs_success(self):
        """Test successful prompt_logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = None
        stream_data.prompt_logprobs = np.array([0.1, 0.2])
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.processor.llm_logger.warning.assert_not_called()

    def test_process_prompt_logprobs_failure(self):
        """Test failed prompt_logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = None
        stream_data.prompt_logprobs = MagicMock()
        stream_data.prompt_logprobs.tolist.side_effect = AttributeError("'NoneType' object has no attribute 'tolist'")
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        with patch.object(self.processor.llm_logger, "warning"):
            result = self.processor._process_batch_output_use_zmq([stream_data])
            # Failure path: output still produced, prompt_logprobs absent/None.
            self.assertEqual(len(result), 1)
            self.assertIsNone(getattr(result[0], "prompt_logprobs", None))

    def test_process_batch_with_stop_flag(self):
        """Test processing when stop flag is True"""
        self.processor.resource_manager.stop_flags = [True]
        stream_data = MagicMock()
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        # A stopped slot produces no outputs.
        self.assertEqual(len(result), 0)

    def test_process_batch_output_use_zmq_aborted_task_negative_token(self):
        """Test aborted task receiving negative token triggers recycling logic"""
        # Set up task as aborted
        task_id = "test_aborted_request"
        self.task_mock.request_id = task_id
        self.processor.resource_manager.to_be_aborted_req_id_set = {task_id}
        # recycle_abort_task removes the id from the to-be-aborted set,
        # mirroring the real resource manager's cleanup.
        self.processor.resource_manager.recycle_abort_task = MagicMock(
            side_effect=lambda rid: self.processor.resource_manager.to_be_aborted_req_id_set.discard(rid)
        )
        # Create stream data with negative token (PREEMPTED_TOKEN_ID = -9)
        stream_data = MagicMock()
        stream_data.tokens = np.array([1, 2, -9])  # Last token is PREEMPTED_TOKEN_ID
        stream_data.batch_id = 0
        # Mock _recycle_resources to track if it's called
        self.processor._recycle_resources = MagicMock()
        # Patch the module-level llm_logger and envs.ENABLE_V1_KVCACHE_SCHEDULER
        with (
            patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
            patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1),
        ):
            result = self.processor._process_batch_output_use_zmq([stream_data])

        # Verify the recycling logic was triggered
        mock_logger.info.assert_any_call(f"start to recycle abort request_id {task_id}")
        self.processor.resource_manager.recycle_abort_task.assert_called_once_with(task_id)
        self.assertNotIn(task_id, self.processor.resource_manager.to_be_aborted_req_id_set)
        self.assertEqual(len(result), 0)  # Aborted task is skipped (continue)

    def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self):
        """Test non-aborted task receiving negative token does not trigger recycling"""
        # Set up task as not aborted: the request id must be absent from
        # to_be_aborted_req_id_set — the same attribute the aborted-task test
        # exercises. (The original set the unrelated name `abort_req_ids_set`,
        # which only passed because an auto-created MagicMock attribute's
        # `in` check defaults to False.)
        task_id = "test_normal_request"
        self.task_mock.request_id = task_id
        self.processor.resource_manager.to_be_aborted_req_id_set = set()  # empty set
        # Create stream data with negative token
        stream_data = MagicMock()
        stream_data.tokens = np.array([1, 2, -1])  # Last token is negative
        stream_data.batch_id = 0
        # Mock _recycle_resources to track if it's called
        self.processor._recycle_resources = MagicMock()

        self.processor._process_batch_output_use_zmq([stream_data])

        # Verify recycling logic was NOT triggered
        self.processor._recycle_resources.assert_not_called()
        self.processor.llm_logger.info.assert_not_called()
# Allow running this test module directly (e.g. `python test_process_batch_output_use_zmq.py`).
if __name__ == "__main__":
    unittest.main()