"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import json
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import paddle

import fastdeploy.envs as envs
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, StreamOptions
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.utils import ErrorCode, ParameterError
from fastdeploy.worker.output import Logprob, LogprobsLists, LogprobsTensors
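
# These tests exercise OpenAIServingChat against a fully mocked engine. Because
# the object created in setUp is a MagicMock, attribute chains such as
# handler.engine_client.data_processor or handler.engine_client.connection_manager
# auto-create child mocks, so no real engine, tokenizer, or ZMQ transport is
# required to drive the serving paths.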
class TestOpenAIServingChat(unittest.IsolatedAsyncioTestCase):
def setUp(self):
"""
Set up the test environment by creating an OpenAIServingChat instance backed by a mock engine.
"""
self.mock_engine = MagicMock()
self.chat_completion_handler = OpenAIServingChat(
self.mock_engine,
models=None,
pid=123,
ips=None,
max_waiting_time=10,
chat_template=None,
)
def test_build_prompt_logprobs_basic(self):
"""Test basic functionality of _build_prompt_logprobs"""
# Create mock data
num_prompt_tokens = 2
num_logprobs = 3
# Create tensors
token_ids = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype=paddle.int64)
logprobs = paddle.to_tensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]], dtype=paddle.float32)
ranks = paddle.to_tensor([1, 2], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
# Mock the data processor
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.side_effect = ["token1", "token2", "token3", "token4", "token5", "token6"]
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs, True)
# Verify result structure (first element is None, then actual results)
self.assertEqual(len(result), num_prompt_tokens + 1)
self.assertIsNone(result[0])
# Check first position (index 1 since index 0 is None)
first_pos_result = result[1]
self.assertEqual(len(first_pos_result), num_logprobs)
# Check token IDs and logprobs for first position
expected_tokens = [1, 2, 3]
expected_logprobs = [float(logprobs[0][i]) for i in range(num_logprobs)]
expected_ranks = [1, 1, 2] # First token uses rank from ranks tensor, then topk ranks start from 1
for i, token_id in enumerate(expected_tokens):
self.assertIn(token_id, first_pos_result)
self.assertIsInstance(first_pos_result[token_id], Logprob)
self.assertEqual(first_pos_result[token_id].logprob, expected_logprobs[i])
self.assertEqual(first_pos_result[token_id].rank, expected_ranks[i])
self.assertEqual(first_pos_result[token_id].decoded_token, f"token{i+1}")
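# For reference, a sketch of the structure asserted above (assuming the
# implementation walks the tensors row by row, as the assertions imply):
#   result == [None,
#              {1: Logprob(-0.1, rank=1, "token1"), 2: ..., 3: ...},
#              {4: Logprob(-0.4, rank=2, "token4"), 5: ..., 6: ...}]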
def test_build_prompt_logprobs_with_all_logprobs(self):
"""Test _build_prompt_logprobs with num_prompt_logprobs=-1 (all logprobs)"""
num_prompt_tokens = 1
num_logprobs = 2
token_ids = paddle.to_tensor([[10, 20]], dtype=paddle.int64)
logprobs = paddle.to_tensor([[-1.0, -2.0]], dtype=paddle.float32)
ranks = paddle.to_tensor([0], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.side_effect = ["hello", "world"]
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, -1, True)
self.assertEqual(len(result), num_prompt_tokens + 1)
self.assertIsNone(result[0])
first_pos_result = result[1]
self.assertEqual(len(first_pos_result), num_logprobs)
# Verify all logprobs are included when num_prompt_logprobs=-1
for token_id in first_pos_result:
self.assertIn(token_id, [10, 20])
def test_build_prompt_logprobs_single_token(self):
"""Test _build_prompt_logprobs with single prompt token"""
num_prompt_tokens = 1
num_logprobs = 1
token_ids = paddle.to_tensor([[100]], dtype=paddle.int64)
logprobs = paddle.to_tensor([[-0.5]], dtype=paddle.float32)
ranks = paddle.to_tensor([1], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.return_value = "single_token"
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs, True)
self.assertEqual(len(result), num_prompt_tokens + 1)
self.assertIsNone(result[0])
first_pos_result = result[1]
self.assertEqual(len(first_pos_result), num_logprobs)
# Check the single token
self.assertIn(100, first_pos_result)
self.assertEqual(first_pos_result[100].logprob, -0.5)
self.assertEqual(first_pos_result[100].rank, 1)
self.assertEqual(first_pos_result[100].decoded_token, "single_token")
def test_build_prompt_logprobs_multiple_positions(self):
"""Test _build_prompt_logprobs with multiple prompt positions"""
num_prompt_tokens = 3
num_logprobs = 2
token_ids = paddle.to_tensor([[1, 2], [3, 4], [5, 6]], dtype=paddle.int64)
logprobs = paddle.to_tensor([[-0.1, -0.2], [-0.3, -0.4], [-0.5, -0.6]], dtype=paddle.float32)
ranks = paddle.to_tensor([1, 2, 3], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.side_effect = ["t1", "t2", "t3", "t4", "t5", "t6"]
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs, True)
self.assertEqual(len(result), num_prompt_tokens + 1)
self.assertIsNone(result[0])
# Check each position (index + 1 since index 0 is None)
for pos in range(num_prompt_tokens):
pos_result = result[pos + 1]
self.assertEqual(len(pos_result), num_logprobs)
# Verify token IDs and their properties
expected_tokens = [int(token_ids[pos][0]), int(token_ids[pos][1])]
expected_ranks = [
int(ranks[pos]),  # sampled token keeps its rank from the ranks tensor
1,  # the remaining top-k candidate gets top-k rank 1
]
for i, token_id in enumerate(expected_tokens):
self.assertIn(token_id, pos_result)
self.assertEqual(pos_result[token_id].logprob, float(logprobs[pos][i]))
self.assertEqual(pos_result[token_id].rank, expected_ranks[i])
self.assertEqual(pos_result[token_id].decoded_token, f"t{pos*2 + i + 1}")
def test_build_prompt_logprobs_empty_tensors(self):
"""Test _build_prompt_logprobs with empty tensors"""
num_prompt_tokens = 0
num_logprobs = 0
token_ids = paddle.to_tensor([], dtype=paddle.int64).reshape([0, 0])
logprobs = paddle.to_tensor([], dtype=paddle.float32).reshape([0, 0])
ranks = paddle.to_tensor([], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs, True)
self.assertEqual(len(result), num_prompt_tokens + 1)
self.assertIsNone(result[0])
def test_build_prompt_logprobs_no_decode(self):
"""Test _build_prompt_logprobs without decoding tokens."""
token_ids = paddle.to_tensor([[7, 8]], dtype=paddle.int64)
logprobs = paddle.to_tensor([[-0.1, -0.2]], dtype=paddle.float32)
ranks = paddle.to_tensor([1], dtype=paddle.int64)
prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks)
result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, 2, False)
self.assertIsNone(result[1][7].decoded_token)
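# The trailing boolean toggles detokenization: with False, the Logprob entries
# keep decoded_token=None and no tokenizer round-trip is needed, which is why
# this test runs without patching process_logprob_response.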
def test_build_logprobs_response_decode_and_error(self):
"""Test _build_logprobs_response decode flag and error handling."""
top_logprobs = LogprobsLists(
logprob_token_ids=[[1, 2]],
logprobs=[[-0.1, -0.2]],
sampled_token_ranks=[1],
)
self.assertIsNone(self.chat_completion_handler._create_chat_logprobs([], True, 1, True))
self.assertIsNone(self.chat_completion_handler._build_logprobs_response(True, top_logprobs, -1, True))
res = self.chat_completion_handler._build_logprobs_response(True, top_logprobs, 0, False)
self.assertEqual(res.content[0].token, "")
self.assertEqual(res.content[0].bytes, [])
with patch.object(
self.chat_completion_handler.engine_client.data_processor,
"process_logprob_response",
side_effect=ValueError,
):
self.assertIsNone(self.chat_completion_handler._build_logprobs_response(True, top_logprobs, 0, True))
with patch.object(
self.chat_completion_handler.engine_client.data_processor,
"process_logprob_response",
side_effect=["\ufffd", "ok"],
):
multi_steps = [
[[1], [2]],
[[-0.1], [-0.2]],
[1, 1],
]
res = self.chat_completion_handler._create_chat_logprobs(multi_steps, True, 0, True)
self.assertEqual(len(res.content), 2)
self.assertTrue(res.content[0].token.startswith("bytes:"))
def test_init_master_ip_list_and_string(self):
"""Test master ip selection for list and string inputs."""
with patch("fastdeploy.entrypoints.openai.serving_chat.get_host_ip", return_value="1.2.3.4"):
handler_list = OpenAIServingChat(
self.mock_engine,
models=None,
pid=123,
ips=["1.2.3.4", "2.2.2.2"],
max_waiting_time=10,
chat_template=None,
)
self.assertEqual(handler_list.master_ip, "1.2.3.4")
self.assertTrue(handler_list.is_master_ip)
handler_str = OpenAIServingChat(
self.mock_engine,
models=None,
pid=123,
ips="5.5.5.5,6.6.6.6",
max_waiting_time=10,
chat_template=None,
)
self.assertEqual(handler_str.master_ip, "5.5.5.5")
self.assertFalse(handler_str.is_master_ip)
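# As the assertions above imply, the first entry of `ips` (whether a list or a
# comma-separated string) is treated as the master address, and is_master_ip
# reflects whether get_host_ip() matches that address.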
async def test_create_chat_completion_master_and_model_errors(self):
"""Test create_chat_completion master check and unsupported model."""
request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False)
self.chat_completion_handler.engine_client.is_master = False
self.chat_completion_handler.is_master_ip = False
resp = await self.chat_completion_handler.create_chat_completion(request)
self.assertIn("Only master node", resp.error.message)
models = MagicMock()
models.is_supported_model.return_value = (False, "bad")
models.model_paths = [SimpleNamespace(name="model_a")]
handler = OpenAIServingChat(
self.mock_engine,
models=models,
pid=123,
ips=None,
max_waiting_time=10,
chat_template=None,
)
handler.engine_client.is_master = True
resp = await handler.create_chat_completion(
ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], model="bad", stream=False)
)
self.assertEqual(resp.error.code, ErrorCode.MODEL_NOT_SUPPORT)
async def test_create_chat_completion_request_id_and_v1_stream(self):
"""Test request_id prefix and v1 data processor path."""
self.chat_completion_handler.engine_client.is_master = True
self.chat_completion_handler.max_waiting_time = -1
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
self.chat_completion_handler.engine_client.semaphore.status = Mock(return_value="ok")
self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock(
side_effect=ParameterError("param", "bad")
)
with patch("fastdeploy.entrypoints.openai.serving_chat.tracing.trace_req_start") as mock_trace:
resp = await self.chat_completion_handler.create_chat_completion(
ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
request_id="abc",
stream=False,
)
)
self.assertEqual(resp.error.param, "param")
self.assertIn("bad", resp.error.message)
self.assertEqual(mock_trace.call_args.kwargs["rid"], "chatcmpl-abc")
self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock(side_effect=RuntimeError("boom"))
with patch("fastdeploy.entrypoints.openai.serving_chat.tracing.trace_req_start"):
resp = await self.chat_completion_handler.create_chat_completion(
ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
request_id="err",
stream=False,
)
)
self.assertIn("generator error", resp.error.message)
async def test_create_chat_completion_full_and_waiting_errors(self):
"""Test full generator error and waiting error handling."""
self.chat_completion_handler.engine_client.is_master = True
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
self.chat_completion_handler.engine_client.semaphore.status = Mock(return_value="ok")
self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock(return_value=[1, 2])
with patch.object(
self.chat_completion_handler,
"chat_completion_full_generator",
AsyncMock(side_effect=RuntimeError("boom")),
):
resp = await self.chat_completion_handler.create_chat_completion(
ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False)
)
self.assertIn("full generator error", resp.error.message)
with patch(
"fastdeploy.entrypoints.openai.serving_chat.tracing.trace_req_start", side_effect=RuntimeError("boom")
):
resp = await self.chat_completion_handler.create_chat_completion(
ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], request_id="rid", stream=False)
)
self.assertEqual(resp.error.code, ErrorCode.TIMEOUT)
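# The second case patches trace_req_start to raise before the request is
# admitted; the handler surfaces this pre-admission failure as a TIMEOUT error
# (an observation from the assertion above, not a claim about intent).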
async def test_create_chat_completion_choice_audio_recover(self):
"""Test _create_chat_completion_choice audio content and recover finish_reason."""
response_processor = MagicMock()
response_processor.enable_multimodal_content.return_value = True
data = {
"request_id": "req_0",
"metrics": {"request_start_time": 1.0},
"error_msg": "Recover by flag",
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"outputs": {
"text": "hi",
"metrics": {"request_start_time": 1.0},
"reasoning_content": "",
"tool_calls": None,
"completion_tokens": "1",
"audio_content": "sound",
"multipart": [{"type": "text", "text": "hi"}],
},
}
choice = await self.chat_completion_handler._create_chat_completion_choice(
data=data,
request=ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
return_token_ids=True,
),
prompt_token_ids=[1, 2],
prompt_tokens="hello",
completion_token_ids=[3],
previous_num_tokens=1,
num_cached_tokens=[0],
num_input_image_tokens=[0],
num_input_video_tokens=[0],
num_image_tokens=[0],
logprob_contents=[[]],
draft_logprob_contents=[[]],
prompt_logprobs_res_list=[[]],
response_processor=response_processor,
max_tokens=2,
speculate_metrics=None,
)
self.assertEqual(choice.finish_reason, "recover_stop")
self.assertEqual(choice.message.audio_content, "sound")
data_length = {**data, "error_msg": None}
choice_length = await self.chat_completion_handler._create_chat_completion_choice(
data=data_length,
request=ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}]),
prompt_token_ids=[1, 2],
prompt_tokens="hello",
completion_token_ids=[3],
previous_num_tokens=2,
num_cached_tokens=[0],
num_input_image_tokens=[0],
num_input_video_tokens=[0],
num_image_tokens=[0],
logprob_contents=[[]],
draft_logprob_contents=[[]],
prompt_logprobs_res_list=[[]],
response_processor=response_processor,
max_tokens=2,
speculate_metrics=None,
)
self.assertEqual(choice_length.finish_reason, "length")
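# finish_reason mapping exercised here: an error_msg containing "Recover" maps
# to "recover_stop", while previous_num_tokens reaching max_tokens maps to
# "length" (both grounded in the two assertions above).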
def test_make_logprob_dict(self):
"""Test the static method _make_logprob_dict"""
logprobs = [-0.1, -0.2, -0.3]
logprob_token_ids = [1, 2, 3]
decoded_tokens = ["token1", "token2", "token3"]
rank = 1
num_logprobs = 3
result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs)
self.assertEqual(len(result), num_logprobs)
# Check first token (sampled token)
self.assertIn(1, result)
self.assertEqual(result[1].logprob, -0.1)
self.assertEqual(result[1].rank, rank) # rank of sampled token
self.assertEqual(result[1].decoded_token, "token1")
# Check other tokens - topk ranks start from 1
expected_ranks = [rank, 1, 2] # First token uses rank, then topk ranks
for i, token_id in enumerate(logprob_token_ids):
self.assertIn(token_id, result)
self.assertEqual(result[token_id].logprob, logprobs[i])
self.assertEqual(result[token_id].rank, expected_ranks[i])
self.assertEqual(result[token_id].decoded_token, decoded_tokens[i])
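# A sketch of the mapping asserted above, assuming Logprob(logprob, rank,
# decoded_token) field order:
#   {1: Logprob(-0.1, 1, "token1"),   # sampled token keeps the given rank
#    2: Logprob(-0.2, 1, "token2"),   # top-k candidates are ranked from 1
#    3: Logprob(-0.3, 2, "token3")}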
def test_make_logprob_dict_with_negative_num_logprobs(self):
"""Test _make_logprob_dict with num_logprobs=-1"""
logprobs = [-0.1, -0.2]
logprob_token_ids = [1, 2]
decoded_tokens = ["token1", "token2"]
rank = 1
num_logprobs = -1
result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs)
# Should include all logprobs when num_logprobs=-1
self.assertEqual(len(result), len(logprobs))
# Expected ranks: first token uses rank, second token uses topk rank 1
expected_ranks = [rank, 1]
for i, token_id in enumerate(logprob_token_ids):
self.assertIn(token_id, result)
self.assertEqual(result[token_id].logprob, logprobs[i])
self.assertEqual(result[token_id].rank, expected_ranks[i])
self.assertEqual(result[token_id].decoded_token, decoded_tokens[i])
def test_make_logprob_dict_partial_logprobs(self):
"""Test _make_logprob_dict with fewer logprobs than available"""
logprobs = [-0.1, -0.2, -0.3, -0.4]
logprob_token_ids = [1, 2, 3, 4]
decoded_tokens = ["token1", "token2", "token3", "token4"]
rank = 2
num_logprobs = 2
result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs)
self.assertEqual(len(result), 3)
# Check sampled token (first token)
self.assertIn(1, result)
self.assertEqual(result[1].logprob, -0.1)
self.assertEqual(result[1].rank, rank)
self.assertEqual(result[1].decoded_token, "token1")
# Check top-k token (second token)
self.assertIn(2, result)
self.assertEqual(result[2].logprob, -0.2)
self.assertEqual(result[2].rank, 1) # topk rank starts from 1
self.assertEqual(result[2].decoded_token, "token2")
async def test_chat_completion_stream_generator_with_prompt_logprobs(self):
"""Test chat_completion_stream_generator with prompt_logprobs enabled"""
# Create mock request with prompt_logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=3, logprobs=False, stream=True
)
request_id = "test_request_123"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with prompt_logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": LogprobsTensors(
logprob_token_ids=paddle.to_tensor([[1, 2, 3, 4]], dtype=paddle.int64),
logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3, -0.4]], dtype=paddle.float32),
selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64),
),
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.side_effect = ["Hello", "world", "test", "token"]
# Execute the generator
results = []
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=3
):
results.append(chunk)
# Verify that prompt_logprobs are included in the response
self.assertGreater(len(results), 0)
# Check that the first chunk contains prompt_logprobs
first_chunk_data = json.loads(results[0].replace("data: ", "").strip())
self.assertIn("choices", first_chunk_data)
self.assertEqual(len(first_chunk_data["choices"]), 1)
choice = first_chunk_data["choices"][0]
self.assertIn("prompt_logprobs", choice)
self.assertIsNotNone(choice["prompt_logprobs"])
# Verify prompt_logprobs structure
prompt_logprobs = choice["prompt_logprobs"]
self.assertIsInstance(prompt_logprobs, list)
self.assertGreater(len(prompt_logprobs), 0)
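# Streaming responses are server-sent events: each chunk is a "data: {json}"
# line and the stream terminates with "data: [DONE]" (checked in the tests
# below). A chunk can be decoded with, e.g.:
#   payload = json.loads(chunk.removeprefix("data: ").strip())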
async def test_chat_completion_stream_generator_with_logprobs(self):
"""Test chat_completion_stream_generator with logprobs enabled"""
# Create mock request with logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
prompt_logprobs=None,
logprobs=True,
top_logprobs=2,
stream=True,
)
request_id = "test_request_456"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": [
[[5, 6]], # logprob_token_ids
[[-0.1, -0.2]], # logprobs
[1], # sampled_token_ranks
],
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
# Mock the data processor for logprob response
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.return_value = "Hi"
# Execute the generator
results = []
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
results.append(chunk)
# Verify that logprobs are included in the response
self.assertGreater(len(results), 0)
# Find chunks that contain logprobs
logprobs_chunks = []
for result in results:
if "logprobs" in result:
logprobs_chunks.append(result)
self.assertGreater(len(logprobs_chunks), 0)
# Check logprobs structure in response
for chunk in logprobs_chunks:
chunk_data = json.loads(chunk.replace("data: ", "").strip())
if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
choice = chunk_data["choices"][0]
if "logprobs" in choice:
self.assertIsNotNone(choice["logprobs"])
async def test_chat_completion_stream_generator_with_both_logprobs(self):
"""Test chat_completion_stream_generator with both prompt_logprobs and logprobs enabled"""
# Create mock request with both logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
prompt_logprobs=2,
logprobs=True,
top_logprobs=2,
stream=True,
include_draft_logprobs=True,
return_token_ids=True,
collect_metrics=True,
stream_options=StreamOptions(include_usage=True, continuous_usage_stats=True),
)
request_id = "test_request_789"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with both logprobs data
base_outputs = {
"token_ids": [5],
"text": "Hi",
"top_logprobs": [
[[5, 6]], # logprob_token_ids
[[-0.1, -0.2]], # logprobs
[1], # sampled_token_ranks
],
"draft_top_logprobs": [
[[5, 6]],
[[-0.2, -0.3]],
[1],
],
"multipart": [{"type": "text", "text": "Hi"}],
"reasoning_content": "",
"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
"audio_content": "aud",
"num_image_tokens": 2,
"completion_tokens": "1",
"skipped": True,
}
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
"engine_recv_latest_token_time": 1234567891,
},
"prompt_logprobs": LogprobsTensors(
logprob_token_ids=paddle.to_tensor([[1, 2, 3]], dtype=paddle.int64),
logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3]], dtype=paddle.float32),
selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64),
),
"outputs": base_outputs,
"finished": False,
"num_cached_tokens": 0,
"num_input_image_tokens": 1,
"num_input_video_tokens": 0,
}
mock_response_final = {
**mock_response,
"outputs": {**base_outputs, "skipped": False},
"finished": True,
"error_msg": "Recover by flag",
"trace_carrier": "trace",
}
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = True
async def mock_async_generator():
yield mock_response
yield mock_response_final
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.return_value = "Hi"
# Execute the generator
results = []
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
results.append(chunk)
# Verify that both types of logprobs are included
self.assertGreater(len(results), 0)
# Check for prompt_logprobs
first_chunk_data = json.loads(results[0].replace("data: ", "").strip())
self.assertIn("choices", first_chunk_data)
choice = first_chunk_data["choices"][0]
self.assertIn("prompt_logprobs", choice)
self.assertIsNotNone(choice["prompt_logprobs"])
# Check for logprobs in subsequent chunks
logprobs_found = False
for result in results:
# Skip [DONE] message
if result.strip() == "data: [DONE]":
continue
chunk_data = json.loads(result.replace("data: ", "").strip())
if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
choice = chunk_data["choices"][0]
if "logprobs" in choice and choice["logprobs"] is not None:
logprobs_found = True
break
self.assertTrue(logprobs_found, "logprobs should be found in response chunks")
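# On the fixture above: draft_top_logprobs mirrors the top_logprobs layout and
# is only requested via include_draft_logprobs=True; presumably it carries the
# draft model's candidates under speculative decoding (an assumption from the
# field name, not asserted here).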
async def test_chat_completion_stream_generator_without_logprobs(self):
"""Test chat_completion_stream_generator without logprobs enabled"""
# Create mock request without logprobs
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
prompt_logprobs=None,
logprobs=False,
stream=True,
return_token_ids=True,
)
request_id = "test_request_no_logprobs"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response without logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
"skipped": False,
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
# Execute the generator
results = []
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
results.append(chunk)
# Verify that logprobs are not included in the response
self.assertGreater(len(results), 0)
for result in results:
# Skip [DONE] message
if result.strip() == "data: [DONE]":
continue
chunk_data = json.loads(result.replace("data: ", "").strip())
if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
choice = chunk_data["choices"][0]
# prompt_logprobs should be None when not requested
self.assertIsNone(choice.get("prompt_logprobs"))
# logprobs should be None when not requested
self.assertIsNone(choice.get("logprobs"))
async def test_chat_completion_full_generator_with_prompt_logprobs(self):
"""Test chat_completion_full_generator with prompt_logprobs enabled"""
# Create mock request with prompt_logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=3, logprobs=False, stream=False
)
request_id = "test_request_full_123"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with prompt_logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": LogprobsTensors(
logprob_token_ids=paddle.to_tensor([[1, 2, 3, 4]], dtype=paddle.int64),
logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3, -0.4]], dtype=paddle.float32),
selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64),
),
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.side_effect = ["Hello", "world", "test", "token"]
# Execute the generator
result = await self.chat_completion_handler.chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
)
# Verify that prompt_logprobs are included in the response
self.assertIsNotNone(result)
self.assertIn("choices", result.model_dump())
self.assertGreater(len(result.choices), 0)
choice = result.choices[0]
self.assertIn("prompt_logprobs", choice.model_dump())
self.assertIsNotNone(choice.prompt_logprobs)
# Verify prompt_logprobs structure
prompt_logprobs = choice.prompt_logprobs
self.assertIsInstance(prompt_logprobs, list)
self.assertGreater(len(prompt_logprobs), 0)
async def test_chat_completion_full_generator_with_logprobs(self):
"""Test chat_completion_full_generator with logprobs enabled"""
# Create mock request with logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
prompt_logprobs=None,
logprobs=True,
top_logprobs=2,
stream=False,
)
request_id = "test_request_full_456"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": [
[[5, 6]], # logprob_token_ids
[[-0.1, -0.2]], # logprobs
[1], # sampled_token_ranks
],
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
# Mock the data processor for logprob response
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.return_value = "Hi"
# Execute the generator
result = await self.chat_completion_handler.chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
)
# Verify that logprobs are included in the response
self.assertIsNotNone(result)
self.assertIn("choices", result.model_dump())
self.assertGreater(len(result.choices), 0)
choice = result.choices[0]
self.assertIn("logprobs", choice.model_dump())
self.assertIsNotNone(choice.logprobs)
async def test_chat_completion_full_generator_with_both_logprobs(self):
"""Test chat_completion_full_generator with both prompt_logprobs and logprobs enabled"""
# Create mock request with both logprobs enabled
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
prompt_logprobs=2,
logprobs=True,
top_logprobs=2,
stream=False,
include_draft_logprobs=True,
)
request_id = "test_request_full_789"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response with both logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
"engine_recv_latest_token_time": 1234567891,
},
"prompt_logprobs": LogprobsTensors(
logprob_token_ids=paddle.to_tensor([[1, 2, 3]], dtype=paddle.int64),
logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3]], dtype=paddle.float32),
selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64),
),
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": [
[[5, 6]], # logprob_token_ids
[[-0.1, -0.2]], # logprobs
[1], # sampled_token_ranks
],
"draft_top_logprobs": [
[[5, 6]],
[[-0.2, -0.3]],
[1],
],
"multipart": [{"type": "text", "text": "Hi"}],
"image_token_num": 2,
},
"trace_carrier": "trace",
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
with patch.object(
self.chat_completion_handler.engine_client.data_processor, "process_logprob_response"
) as mock_decode:
mock_decode.return_value = "Hi"
# Execute the generator
result = await self.chat_completion_handler.chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
)
# Verify that both types of logprobs are included
self.assertIsNotNone(result)
self.assertIn("choices", result.model_dump())
self.assertGreater(len(result.choices), 0)
choice = result.choices[0]
# Check for prompt_logprobs
self.assertIn("prompt_logprobs", choice.model_dump())
self.assertIsNotNone(choice.prompt_logprobs)
# Check for logprobs
self.assertIn("logprobs", choice.model_dump())
self.assertIsNotNone(choice.logprobs)
async def test_chat_completion_full_generator_without_logprobs(self):
"""Test chat_completion_full_generator without logprobs enabled"""
# Create mock request without logprobs
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=None, logprobs=False, stream=False
)
request_id = "test_request_full_no_logprobs"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager and response queue
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Create mock response without logprobs data
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response = RequestOutput.from_dict(mock_response)
mock_response_queue.get.return_value = mock_response
# Mock the connection manager
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
# Execute the generator
result = await self.chat_completion_handler.chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
)
# Verify that logprobs are not included in the response
self.assertIsNotNone(result)
self.assertIn("choices", result.model_dump())
self.assertGreater(len(result.choices), 0)
choice = result.choices[0]
# prompt_logprobs should be None when not requested
self.assertIsNone(choice.prompt_logprobs)
# logprobs should be None when not requested
self.assertIsNone(choice.logprobs)
async def test_create_chat_completion_cancelled_error(self):
"""Test asyncio.CancelledError handling in create_chat_completion method"""
# Create mock request
request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock format_and_add_data to raise CancelledError
self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock(
side_effect=asyncio.CancelledError("Test cancellation during data formatting")
)
# Mock the abort method that should be called when CancelledError occurs
self.chat_completion_handler.engine_client.abort = AsyncMock()
# Execute and verify that CancelledError is handled properly
# The CancelledError should be caught and handled, not re-raised
try:
await self.chat_completion_handler.create_chat_completion(request)
except asyncio.CancelledError:
# This should not happen as CancelledError should be caught and handled
self.fail("CancelledError should be caught and handled, not re-raised")
# Verify abort was called despite the cancellation
self.chat_completion_handler.engine_client.abort.assert_called_once()
async def test_chat_completion_stream_generator_cancelled_error(self):
"""Test asyncio.CancelledError handling in chat_completion_stream_generator method"""
# Create mock request
request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=True)
request_id = "test_cancel_request"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
# Mock the connection manager
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Mock get_connection to return normally
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
# Mock the semaphore
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
# Mock the model weight status check
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
# Mock the response processor to raise CancelledError during processing
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator_with_cancel():
# Simulate some normal response first
yield {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
"enable_parser": False,
"reasoning_content": "",
"tool_calls": None,
"skipped": False,
},
"finished": False,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
# Then raise CancelledError
raise asyncio.CancelledError("Test cancellation during streaming")
mock_response_processor.process_response_chat.return_value = mock_async_generator_with_cancel()
# Mock the cleanup method
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
# Mock the abort method that should be called when CancelledError occurs
self.chat_completion_handler.engine_client.abort = AsyncMock()
with patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
):
# Execute the generator and verify CancelledError handling
# The CancelledError should be caught and handled, not re-raised
chunks = []
try:
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
chunks.append(chunk)
except asyncio.CancelledError:
# This should not happen as CancelledError should be caught and handled
self.fail("CancelledError should be caught and handled, not re-raised")
# Should have received at least one chunk before cancellation
self.assertGreaterEqual(len(chunks), 1)
self.assertIsNotNone(chunks[0])
# Verify cleanup and abort were called despite the cancellation
self.chat_completion_handler.engine_client.connection_manager.cleanup_request.assert_called_once()
self.chat_completion_handler.engine_client.abort.assert_called_once()
async def test_chat_completion_stream_generator_dealer_mode_writes(self):
"""Cover lines 261-264: dealer.write in non-batch mode for stream generator."""
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
stream=True,
)
request_id = "test_dealer_stream"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"token_ids": [5],
"text": "Hi",
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hi"}],
"skipped": False,
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response_queue.get.return_value = mock_response
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
with (
patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor",
return_value=mock_response_processor,
),
patch.object(envs, "ZMQ_SEND_BATCH_DATA", False),
):
results = []
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
results.append(chunk)
# Lines 262-264: dealer.write should be called for each request_id
mock_dealer.write.assert_called_once_with([b"", f"{request_id}_0".encode("utf-8")])
self.assertGreater(len(results), 0)
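# The expected frames above follow the ZMQ DEALER convention of an empty
# delimiter frame followed by the payload; with ZMQ_SEND_BATCH_DATA disabled,
# the handler registers each sub-request id with the dealer individually.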
async def test_chat_completion_stream_generator_cancelled_error_in_wait(self):
"""Cover lines 278, 280: CancelledError during response_queue.get propagates."""
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
stream=True,
)
request_id = "test_cancel_wait"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
# Simulate CancelledError during wait_for(response_queue.get())
mock_response_queue.get.side_effect = asyncio.CancelledError()
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
self.chat_completion_handler.engine_client.abort = AsyncMock()
with patch.object(envs, "ZMQ_SEND_BATCH_DATA", True):
chunks = []
try:
async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
):
chunks.append(chunk)
except asyncio.CancelledError:
pass
# Cleanup should still be called
self.chat_completion_handler.engine_client.connection_manager.cleanup_request.assert_called_once()
async def test_chat_completion_full_generator_dealer_mode_writes(self):
"""Cover lines 558-561: dealer.write in non-batch mode for full generator."""
request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello"}],
stream=False,
)
request_id = "test_dealer_full"
model_name = "test_model"
prompt_token_ids = [1, 2, 3]
prompt_tokens = "Hello world"
mock_dealer = MagicMock()
mock_response_queue = AsyncMock()
mock_response = {
"request_id": f"{request_id}_0",
"error_code": 200,
"metrics": {
"first_token_time": 1234567890,
"inference_start_time": 1234567880,
"arrival_time": 1234567890,
"request_start_time": 1234567870,
},
"prompt_logprobs": None,
"outputs": {
"text": "Hello there",
"metrics": {"request_start_time": 1.0},
"reasoning_content": "",
"tool_calls": None,
"completion_tokens": "2",
"token_ids": [5, 6],
"top_logprobs": None,
"draft_top_logprobs": None,
"multipart": [{"type": "text", "text": "Hello there"}],
},
"finished": True,
"num_cached_tokens": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
mock_response_queue.get.return_value = mock_response
self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
self.chat_completion_handler.engine_client.semaphore = MagicMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
self.chat_completion_handler.engine_client.semaphore.release = MagicMock()
self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()
mock_response_processor = MagicMock()
mock_response_processor.enable_multimodal_content.return_value = False
async def mock_async_generator():
yield mock_response
mock_response_processor.process_response_chat.return_value = mock_async_generator()
with (
patch(
"fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor",
return_value=mock_response_processor,
),
patch.object(envs, "ZMQ_SEND_BATCH_DATA", False),
):
await self.chat_completion_handler.chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
)
# Lines 559-561: dealer.write should be called for each request_id
mock_dealer.write.assert_called_once_with([b"", f"{request_id}_0".encode("utf-8")])
if __name__ == "__main__":
unittest.main()