"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|

import time
import unittest

from fastdeploy.engine.request import (
    CompletionOutput,
    LogprobsLists,
    RequestMetrics,
    RequestOutput,
)
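
# Overview: the suites below exercise three aspects of RequestOutput:
# constructor defaults and argument normalization (TestRequestOutputInit),
# merging of streamed partial outputs via accumulate() (TestRequestOutputAccumulate),
# and reconstruction from a plain dict via from_dict() (TestRequestOutputFromDict).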


class TestRequestOutputInit(unittest.TestCase):
    """Test case for RequestOutput initialization"""

    def test_init_default_values(self):
        """Test initialization with default values"""
        request_id = "test_request_123"
        request_output = RequestOutput(request_id=request_id)

        self.assertEqual(request_output.request_id, request_id)
        self.assertIsNone(request_output.prompt)
        # prompt_token_ids becomes empty list when None is passed
        self.assertEqual(request_output.prompt_token_ids, [])
        self.assertIsNone(request_output.prompt_logprobs)
        self.assertEqual(request_output.output_type, 3)
        self.assertIsNone(request_output.outputs)
        self.assertFalse(request_output.finished)
        self.assertIsNone(request_output.metrics)
        self.assertEqual(request_output.num_cached_tokens, 0)
        self.assertEqual(request_output.num_input_image_tokens, 0)
        self.assertEqual(request_output.num_input_video_tokens, 0)
        self.assertEqual(request_output.error_code, 200)
        self.assertIsNone(request_output.error_msg)
        self.assertIsNone(request_output.ic_req_data)
        self.assertEqual(request_output.prompt_token_ids_len, 0)
        self.assertIsNone(request_output.accumulate_tool_calls)

    def test_init_with_numpy_array_prompt_token_ids(self):
        """Test initialization with numpy array prompt_token_ids"""
        import numpy as np

        request_id = "test_request_456"
        numpy_array = np.array([1, 2, 3, 4, 5])

        request_output = RequestOutput(request_id=request_id, prompt_token_ids=numpy_array)

        self.assertEqual(request_output.prompt_token_ids, [1, 2, 3, 4, 5])
        self.assertIsInstance(request_output.prompt_token_ids, list)

    def test_init_with_list_prompt_token_ids(self):
        """Test initialization with list prompt_token_ids"""
        request_id = "test_request_789"
        token_list = [10, 20, 30, 40]

        request_output = RequestOutput(request_id=request_id, prompt_token_ids=token_list)

        self.assertEqual(request_output.prompt_token_ids, token_list)

    def test_init_with_outputs_and_tool_calls(self):
        """Test initialization with outputs containing tool calls"""
        request_id = "test_request_tool"

        # Create a CompletionOutput with tool calls
        tool_calls = [{"name": "test_tool", "arguments": {"param": "value"}}]
        outputs = CompletionOutput(index=0, send_idx=0, token_ids=[100, 200, 300], tool_calls=tool_calls)

        request_output = RequestOutput(request_id=request_id, outputs=outputs)

        self.assertEqual(request_output.accumulate_tool_calls, [tool_calls])
        self.assertEqual(request_output.outputs, outputs)

    def test_init_with_outputs_no_tool_calls(self):
        """Test initialization with outputs but no tool calls"""
        request_id = "test_request_no_tool"

        # Create a CompletionOutput without tool calls
        outputs = CompletionOutput(index=0, send_idx=0, token_ids=[100, 200, 300])

        request_output = RequestOutput(request_id=request_id, outputs=outputs)

        self.assertIsNone(request_output.accumulate_tool_calls)
        self.assertEqual(request_output.outputs, outputs)

    def test_init_with_all_parameters(self):
        """Test initialization with all parameters provided"""
        request_id = "test_request_full"
        prompt = "Test prompt"
        prompt_token_ids = [1, 2, 3]
        prompt_token_ids_len = 3

        outputs = CompletionOutput(
            index=0, send_idx=0, token_ids=[100, 200], text="Generated text", reasoning_content="Reasoning content"
        )

        metrics = RequestMetrics()
        metrics.arrival_time = time.time()

        request_output = RequestOutput(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs={"test": "logprobs"},
            output_type=1,
            outputs=outputs,
            finished=True,
            metrics=metrics,
            num_cached_tokens=5,
            num_input_image_tokens=2,
            num_input_video_tokens=1,
            error_code=400,
            error_msg="Test error",
            ic_req_data={"internal": "data"},
            prompt_token_ids_len=prompt_token_ids_len,
        )

        self.assertEqual(request_output.request_id, request_id)
        self.assertEqual(request_output.prompt, prompt)
        self.assertEqual(request_output.prompt_token_ids, prompt_token_ids)
        self.assertEqual(request_output.prompt_logprobs, {"test": "logprobs"})
        self.assertEqual(request_output.output_type, 1)
        self.assertEqual(request_output.outputs, outputs)
        self.assertTrue(request_output.finished)
        self.assertEqual(request_output.metrics, metrics)
        self.assertEqual(request_output.num_cached_tokens, 5)
        self.assertEqual(request_output.num_input_image_tokens, 2)
        self.assertEqual(request_output.num_input_video_tokens, 1)
        self.assertEqual(request_output.error_code, 400)
        self.assertEqual(request_output.error_msg, "Test error")
        self.assertEqual(request_output.ic_req_data, {"internal": "data"})
        self.assertEqual(request_output.prompt_token_ids_len, prompt_token_ids_len)


class TestRequestOutputAccumulate(unittest.TestCase):
    """Test case for RequestOutput accumulate method"""

    def setUp(self):
        """Set up test fixtures"""
        self.request_id = "test_request_accumulate"
        self.base_request = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(
                index=0, send_idx=0, token_ids=[100, 200], text="First ", reasoning_content="Reasoning "
            ),
        )
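
    # Usage sketch for accumulate(): it merges a later streaming chunk into this
    # RequestOutput. The example only uses attributes already exercised in this
    # suite; passing RequestMetrics() on the incoming chunk mirrors the tests
    # below, which do so to avoid attribute access on a None metrics object.
    #
    #   merged = RequestOutput(
    #       request_id="req",
    #       outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1], text="Hello"),
    #   )
    #   delta = RequestOutput(
    #       request_id="req",
    #       outputs=CompletionOutput(index=0, send_idx=1, token_ids=[2], text=" world"),
    #       metrics=RequestMetrics(),
    #   )
    #   merged.accumulate(delta)
    #   # merged.outputs.text == "Hello world", merged.outputs.token_ids == [1, 2, ...]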

    def test_accumulate_basic_text(self):
        """Test basic text accumulation"""
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], text="second"),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.outputs.text, "First second")
        self.assertEqual(self.base_request.outputs.token_ids, [100, 200, 300])
        self.assertEqual(self.base_request.outputs.index, 0)

    def test_accumulate_reasoning_content(self):
        """Test reasoning content accumulation"""
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], reasoning_content="content"),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.outputs.reasoning_content, "Reasoning content")

    def test_accumulate_completion_tokens(self):
        """Test completion tokens accumulation"""
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], completion_tokens=" tokens"),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.outputs.completion_tokens, " tokens")

    def test_accumulate_tool_calls(self):
        """Test tool calls accumulation"""
        tool_calls = [{"name": "tool2", "arguments": {"param": "value2"}}]
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], tool_calls=tool_calls),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.accumulate_tool_calls, [tool_calls])

    def test_accumulate_multiple_tool_calls(self):
        """Test multiple tool calls accumulation"""
        # Add initial tool call through constructor
        initial_tool_calls = [{"name": "tool1", "arguments": {"param": "value1"}}]
        base_request = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[100, 200], tool_calls=initial_tool_calls),
        )

        second_tool_calls = [{"name": "tool2", "arguments": {"param": "value2"}}]
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], tool_calls=second_tool_calls),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        base_request.accumulate(next_output)

        self.assertEqual(base_request.accumulate_tool_calls, [initial_tool_calls, second_tool_calls])

    def test_accumulate_with_metrics(self):
        """Test accumulation including metrics updates"""
        next_output = RequestOutput(
            request_id=self.request_id,
            metrics=RequestMetrics(model_forward_time=1.5, model_execute_time=2.5),
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300], text=" text"),
        )

        # Set up base metrics
        self.base_request.metrics = RequestMetrics()

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.metrics.model_forward_time, 1.5)
        self.assertEqual(self.base_request.metrics.model_execute_time, 2.5)

    def test_accumulate_with_logprobs(self):
        """Test accumulation with logprobs data"""
        # Create LogprobsLists objects - each list corresponds to a different request/position
        initial_top_logprobs = LogprobsLists(
            logprob_token_ids=[[100, 200]],  # First request with 2 token probabilities
            logprobs=[[0.7, 0.6]],  # Corresponding log probabilities
            sampled_token_ranks=[0],  # Rank for the first request
        )

        initial_draft_logprobs = LogprobsLists(
            logprobs=[[0.8, 0.7]],  # Default draft logprobs
            logprob_token_ids=[[150, 250]],  # Default draft token IDs
            sampled_token_ranks=[1],  # Default draft ranks
        )

        # Set up initial logprobs
        self.base_request.outputs.top_logprobs = initial_top_logprobs
        self.base_request.outputs.draft_top_logprobs = initial_draft_logprobs

        # Create next output with new logprobs (representing a new decoding step)
        new_top_logprobs = LogprobsLists(
            logprob_token_ids=[[300, 400, 500]],  # New step with 3 token IDs
            logprobs=[[0.5, 0.4, 0.3]],  # Corresponding log probabilities
            sampled_token_ranks=[1],  # New rank
        )

        new_draft_logprobs = LogprobsLists(
            logprob_token_ids=[[350, 450, 550]],  # New draft token IDs
            logprobs=[[0.6, 0.5, 0.4]],  # New draft log probabilities
            sampled_token_ranks=[2],  # New draft rank
        )

        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(
                index=0,
                send_idx=1,
                token_ids=[600],  # New token
                text=" text",
                top_logprobs=new_top_logprobs,
                draft_top_logprobs=new_draft_logprobs,
            ),
            metrics=RequestMetrics(),
        )

        self.base_request.accumulate(next_output)

        # Verify accumulation adds new rows (requests/positions)
        # After accumulation, we should have 2 rows (initial + new)
        self.assertEqual(len(self.base_request.outputs.top_logprobs.logprob_token_ids), 2)
        self.assertEqual(len(self.base_request.outputs.top_logprobs.logprobs), 2)
        self.assertEqual(len(self.base_request.outputs.top_logprobs.sampled_token_ranks), 2)

        # Check first row remains unchanged
        self.assertEqual(self.base_request.outputs.top_logprobs.logprob_token_ids[0], [100, 200])
        self.assertEqual(self.base_request.outputs.top_logprobs.logprobs[0], [0.7, 0.6])
        self.assertEqual(self.base_request.outputs.top_logprobs.sampled_token_ranks[0], 0)

        # Check second row contains new data
        self.assertEqual(self.base_request.outputs.top_logprobs.logprob_token_ids[1], [300, 400, 500])
        self.assertEqual(self.base_request.outputs.top_logprobs.logprobs[1], [0.5, 0.4, 0.3])
        self.assertEqual(self.base_request.outputs.top_logprobs.sampled_token_ranks[1], 1)

        # Same for draft logprobs
        self.assertEqual(len(self.base_request.outputs.draft_top_logprobs.logprob_token_ids), 2)
        self.assertEqual(len(self.base_request.outputs.draft_top_logprobs.logprobs), 2)
        self.assertEqual(len(self.base_request.outputs.draft_top_logprobs.sampled_token_ranks), 2)

        self.assertEqual(self.base_request.outputs.draft_top_logprobs.logprob_token_ids[0], [150, 250])
        self.assertEqual(self.base_request.outputs.draft_top_logprobs.logprobs[0], [0.8, 0.7])
        self.assertEqual(self.base_request.outputs.draft_top_logprobs.sampled_token_ranks[0], 1)

        self.assertEqual(self.base_request.outputs.draft_top_logprobs.logprob_token_ids[1], [350, 450, 550])
        self.assertEqual(self.base_request.outputs.draft_top_logprobs.logprobs[1], [0.6, 0.5, 0.4])
        self.assertEqual(self.base_request.outputs.draft_top_logprobs.sampled_token_ranks[1], 2)
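
    # Note on the logprobs layout assumed above: LogprobsLists accumulation is
    # row-wise, so each decoding step appends one row to logprob_token_ids,
    # logprobs and sampled_token_ranks. Expected state after the accumulate()
    # call in test_accumulate_with_logprobs (values taken from that test only):
    #
    #   top_logprobs.logprob_token_ids   == [[100, 200], [300, 400, 500]]
    #   top_logprobs.logprobs            == [[0.7, 0.6], [0.5, 0.4, 0.3]]
    #   top_logprobs.sampled_token_ranks == [0, 1]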

    def test_accumulate_null_text_handling(self):
        """Test accumulate with null text handling"""
        base_request = RequestOutput(
            request_id=self.request_id, outputs=CompletionOutput(index=0, send_idx=0, token_ids=[100])  # text is None
        )

        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[200], text="new text"),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        base_request.accumulate(next_output)

        self.assertEqual(base_request.outputs.text, "new text")

    def test_accumulate_finished_flag(self):
        """Test that finished flag is OR-ed correctly"""
        next_output = RequestOutput(
            request_id=self.request_id,
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300]),
            finished=True,
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertTrue(self.base_request.finished)

    def test_accumulate_prompt_updates(self):
        """Test that prompt and prompt_token_ids are updated from next_output"""
        next_output = RequestOutput(
            request_id=self.request_id,
            prompt="Updated prompt",
            prompt_token_ids=[999],
            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[300]),
            metrics=RequestMetrics(),  # Add metrics to avoid None attribute access
        )

        self.base_request.accumulate(next_output)

        self.assertEqual(self.base_request.prompt, "Updated prompt")
        self.assertEqual(self.base_request.prompt_token_ids, [999])


class TestRequestOutputFromDict(unittest.TestCase):
    """Test case for RequestOutput from_dict method"""

    def test_from_dict_with_outputs_and_metrics(self):
        """Test from_dict with outputs and metrics dictionaries"""
        test_dict = {
            "request_id": "test_dict_123",
            "prompt": "Dict prompt",
            "prompt_token_ids": [1, 2, 3],
            "finished": True,
            "outputs": {"index": 0, "send_idx": 0, "token_ids": [100, 200], "text": "Dict text"},
            "metrics": {"arrival_time": 1000.0, "model_forward_time": 1.5},
        }

        request_output = RequestOutput.from_dict(test_dict)

        self.assertEqual(request_output.request_id, "test_dict_123")
        self.assertEqual(request_output.prompt, "Dict prompt")
        self.assertEqual(request_output.prompt_token_ids, [1, 2, 3])
        self.assertTrue(request_output.finished)
        self.assertIsInstance(request_output.outputs, CompletionOutput)
        self.assertEqual(request_output.outputs.text, "Dict text")
        self.assertIsInstance(request_output.metrics, RequestMetrics)
        self.assertEqual(request_output.metrics.arrival_time, 1000.0)
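
    # As asserted above, from_dict() promotes the nested "outputs" and "metrics"
    # dictionaries into CompletionOutput and RequestMetrics instances. A minimal
    # payload sketch, using only keys that appear in this suite:
    #
    #   RequestOutput.from_dict({
    #       "request_id": "req",
    #       "finished": False,
    #       "outputs": {"index": 0, "send_idx": 0, "token_ids": [1], "text": "Hi"},
    #   })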

    def test_from_dict_without_outputs_and_metrics(self):
        """Test from_dict without outputs and metrics in dictionary"""
        test_dict = {"request_id": "test_dict_456", "finished": False}

        request_output = RequestOutput.from_dict(test_dict)

        self.assertEqual(request_output.request_id, "test_dict_456")
        self.assertFalse(request_output.finished)
        self.assertIsNone(request_output.outputs)
        self.assertIsNone(request_output.metrics)


if __name__ == "__main__":
    unittest.main()