[Optimization]Unified data processing for online and offline (#6891)

* remove process_request

* fix chat

* fix unit test

* remove process response

* fix unit test

* fix offline decode

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix sampling_params

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
luukunn
2026-03-19 21:56:09 +08:00
committed by GitHub
parent c3d8db85c4
commit f4a79d4c00
19 changed files with 160 additions and 619 deletions
+20 -40
View File
@@ -261,38 +261,25 @@ class TestErnie4_5Processor(unittest.TestCase):
self.assertEqual(len(stop_seqs2), 2)
self.assertEqual(len(stop_lens2), 2)
def test_process_request_chat_template_kwargs(self):
"""Test chat_template_kwargs application inside process_request."""
def test_process_request_dict_with_chat_template_kwargs(self):
"""Test chat_template_kwargs application inside process_request_dict."""
proc = self._make_processor()
class ReqObj(dict):
"""Mock request object supporting attributes, set(), and to_dict()."""
request = {
"messages": [{"role": "user", "content": "hello"}],
"temperature": 0.5,
"top_p": 0.5,
"chat_template_kwargs": {"extra": "VALUE"},
}
def set(self, k, v):
self[k] = v
processed = proc.process_request_dict(request, max_model_len=20)
def __getattr__(self, item):
return self.get(item, None)
def to_dict(self):
return dict(self)
request = ReqObj(
{
"messages": [{"role": "user", "content": "hello"}],
"temperature": 0.5,
"top_p": 0.5,
}
)
processed = proc.process_request(request, max_model_len=20, chat_template_kwargs={"extra": "VALUE"})
self.assertEqual(processed.eos_token_ids, [proc.tokenizer.eos_token_id])
self.assertEqual(processed["eos_token_ids"], [proc.tokenizer.eos_token_id])
expected_ids = proc.tokenizer.convert_tokens_to_ids(proc.tokenizer.tokenize("hello"))
self.assertIsNotNone(processed.prompt_token_ids)
self.assertEqual(processed.prompt_token_ids, expected_ids)
self.assertIsNotNone(processed["prompt_token_ids"])
self.assertEqual(processed["prompt_token_ids"], expected_ids)
self.assertIn("max_tokens", processed)
self.assertEqual(processed["max_tokens"], max(1, 20 - len(expected_ids)))
@@ -322,21 +309,14 @@ class TestErnie4_5Processor(unittest.TestCase):
def test_process_response_with_tool_parser(self):
"""Verify tool_call extraction in process_response."""
proc = self._make_processor(tool=True)
class RespObj:
"""Mock response carrying token_ids and index for testing."""
def __init__(self):
self.request_id = "reqx"
self.outputs = MagicMock()
self.outputs.token_ids = [9, proc.tokenizer.eos_token_id]
self.outputs.index = 0
resp = RespObj()
result = proc.process_response(resp)
self.assertTrue(hasattr(result.outputs, "tool_calls"))
self.assertEqual(result.outputs.tool_calls[0]["name"], "fake_tool")
resp = {
"request_id": "reqx",
"outputs": {"token_ids": [9, proc.tokenizer.eos_token_id], "index": 0},
"finished": True,
}
result = proc.process_response_dict(resp, False)
assert "tool_calls" in result["outputs"]
self.assertEqual(result["outputs"]["tool_calls"][0]["name"], "fake_tool")
def test_process_response_dict_normal_with_tool(self):
"""Verify tool_call extraction in normal (non-streaming) response mode."""
-27
View File
@@ -264,33 +264,6 @@ class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
processor._check_mm_limits(mm_data)
self.assertIn("Too many image items", str(context.exception))
def test_process_request(self):
"""Test process_request method"""
from fastdeploy.engine.request import Request
# Mock the process_request_dict method
self.processor.process_request_dict = MagicMock()
# Create a mock Request object
mock_request = MagicMock(spec=Request)
mock_request.to_dict.return_value = {"messages": [{"role": "user", "content": "Hello"}]}
# Mock Request.from_dict to return a mock request
with patch.object(Request, "from_dict") as mock_from_dict:
mock_result_request = MagicMock(spec=Request)
mock_from_dict.return_value = mock_result_request
self.processor.process_request(mock_request, max_model_len=100, chat_template_kwargs={"key": "value"})
# Verify to_dict was called
mock_request.to_dict.assert_called_once()
# Verify process_request_dict was called
self.processor.process_request_dict.assert_called_once()
# Verify from_dict was called
mock_from_dict.assert_called_once()
def test_get_pad_id(self):
"""Test get_pad_id method"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
@@ -965,40 +965,6 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
with self.assertRaises(ValueError):
self.processor._check_mm_limits(item_exceeded)
def test_process_request_wrapper(self):
"""测试 process_request 封装方法"""
# 1. 模拟输入 Request 对象
request_obj = MagicMock()
request_dict = {
"prompt": "test prompt",
"multimodal_data": {"image": ["image1"]},
"metadata": {"generated_token_ids": []},
"request_id": "test-request",
}
request_obj.to_dict.return_value = request_dict
# 2. patch 'Request'
patch_target = "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.Request"
with patch(patch_target) as MockRequestCls:
# 3. 模拟 Request.from_dict 返回一个 mock 对象
final_mock_request = MagicMock()
MockRequestCls.from_dict.return_value = final_mock_request
# 4. Call function
result_request = self.processor.process_request(request_obj, max_model_len=512)
# 5. 检查 *传递给* Request.from_dict 的字典
self.assertTrue(MockRequestCls.from_dict.called)
# 获取传递给 from_dict 的第一个位置参数
processed_task_dict = MockRequestCls.from_dict.call_args[0][0]
# 这个断言现在应该能通过了
self.assertEqual(processed_task_dict["prompt_token_ids"], [1, 2, 3])
# 6. 检查返回的是否是最终的 Request 对象
self.assertIs(result_request, final_mock_request)
def test_parse_processor_kwargs_invalid_type(self):
"""测试 _parse_processor_kwargs 传入非字典类型"""
invalid_input = ["video_max_frames", 10]
+20 -22
View File
@@ -21,7 +21,6 @@ from unittest.mock import MagicMock, patch
import numpy as np
from PIL import Image
from fastdeploy.engine.request import Request
from fastdeploy.input.qwen3_vl_processor import Qwen3VLProcessor
from fastdeploy.input.qwen3_vl_processor.process import sample_frames
@@ -127,9 +126,9 @@ class TestQwen3VLProcessor(unittest.TestCase):
self.patcher_parse_image.stop()
self.patcher_parse_video.stop()
def test_process_request(self):
def test_process_request_dict_with_multimodal(self):
"""
Test processing of Request object with multimodal input
Test processing of request dict with multimodal input
Validates:
1. Token ID lengths match position_ids and token_type_ids shapes
@@ -151,17 +150,16 @@ class TestQwen3VLProcessor(unittest.TestCase):
],
}
request = Request.from_dict(message)
result = self.processor.process_request(request, 1024 * 100)
result = self.processor.process_request_dict(message, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
result["multimodal_inputs"]["images"].shape[0],
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
def test_process_request_dict(self):
@@ -224,17 +222,16 @@ class TestQwen3VLProcessor(unittest.TestCase):
},
}
request = Request.from_dict(prompt)
result = self.processor.process_request(request, 1024 * 100)
result = self.processor.process_request_dict(prompt, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
result["multimodal_inputs"]["images"].shape[0],
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
def test_message_and_prompt(self):
@@ -276,14 +273,15 @@ class TestQwen3VLProcessor(unittest.TestCase):
"video": [{"video": b"123"}],
},
}
request2 = Request.from_dict(prompt)
result2 = self.processor.process_request(request2, 1024 * 100)
result2 = self.processor.process_request_dict(prompt, 1024 * 100)
# Verify both processing methods produce identical results
self.assertEqual(result["prompt_token_ids"], result2.prompt_token_ids)
self.assertTrue(np.equal(result["multimodal_inputs"]["grid_thw"], result2.multimodal_inputs["grid_thw"]).all())
self.assertEqual(result["prompt_token_ids"], result2["prompt_token_ids"])
self.assertTrue(
np.equal(result["multimodal_inputs"]["position_ids"], result2.multimodal_inputs["position_ids"]).all()
np.equal(result["multimodal_inputs"]["grid_thw"], result2["multimodal_inputs"]["grid_thw"]).all()
)
self.assertTrue(
np.equal(result["multimodal_inputs"]["position_ids"], result2["multimodal_inputs"]["position_ids"]).all()
)
def test_apply_chat_template(self):
+20 -22
View File
@@ -20,7 +20,6 @@ from unittest.mock import MagicMock, patch
import numpy as np
from PIL import Image
from fastdeploy.engine.request import Request
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
from fastdeploy.input.qwen_vl_processor.process_video import sample_frames
@@ -129,9 +128,9 @@ class TestQwenVLProcessor(unittest.TestCase):
self.patcher_parse_image.stop()
self.patcher_parse_video.stop()
def test_process_request(self):
def test_process_request_dict_with_multimodal(self):
"""
Test processing of Request object with multimodal input
Test processing of request dict with multimodal input
Validates:
1. Token ID lengths match position_ids and token_type_ids shapes
@@ -153,17 +152,16 @@ class TestQwenVLProcessor(unittest.TestCase):
],
}
request = Request.from_dict(message)
result = self.processor.process_request(request, 1024 * 100)
result = self.processor.process_request_dict(message, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
result["multimodal_inputs"]["images"].shape[0],
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
def test_process_request_dict(self):
@@ -246,17 +244,16 @@ class TestQwenVLProcessor(unittest.TestCase):
},
}
request = Request.from_dict(prompt)
result = self.processor.process_request(request, 1024 * 100)
result = self.processor.process_request_dict(prompt, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
result["multimodal_inputs"]["images"].shape[0],
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
def test_message_and_prompt(self):
@@ -298,14 +295,15 @@ class TestQwenVLProcessor(unittest.TestCase):
"video": [{"video": b"123"}],
},
}
request2 = Request.from_dict(prompt)
result2 = self.processor.process_request(request2, 1024 * 100)
result2 = self.processor.process_request_dict(prompt, 1024 * 100)
# Verify both processing methods produce identical results
self.assertEqual(result["prompt_token_ids"], result2.prompt_token_ids)
self.assertTrue(np.equal(result["multimodal_inputs"]["grid_thw"], result2.multimodal_inputs["grid_thw"]).all())
self.assertEqual(result["prompt_token_ids"], result2["prompt_token_ids"])
self.assertTrue(
np.equal(result["multimodal_inputs"]["position_ids"], result2.multimodal_inputs["position_ids"]).all()
np.equal(result["multimodal_inputs"]["grid_thw"], result2["multimodal_inputs"]["grid_thw"]).all()
)
self.assertTrue(
np.equal(result["multimodal_inputs"]["position_ids"], result2["multimodal_inputs"]["position_ids"]).all()
)
def test_apply_chat_template(self):
+28 -37
View File
@@ -347,19 +347,9 @@ class DataProcessorTestCase(unittest.TestCase):
def _load_tokenizer(self):
return DummyTokenizer()
def process_request(self, request, **kwargs):
return super().process_request(request, **kwargs)
def process_response(self, response_dict):
return super().process_response(response_dict)
processor = MinimalProcessor()
defaults = processor._apply_default_parameters({})
self.assertAlmostEqual(defaults["top_p"], 0.5)
with self.assertRaises(NotImplementedError):
processor.process_request({}, max_model_len=None)
with self.assertRaises(NotImplementedError):
processor.process_response({})
with self.assertRaises(NotImplementedError):
processor.text2ids("text")
with self.assertRaises(NotImplementedError):
@@ -392,28 +382,28 @@ class DataProcessorTestCase(unittest.TestCase):
self.assertTrue(processed["enable_thinking"])
self.assertEqual(processed["prompt_tokens"], "system prompt hello")
def test_process_request_object_handles_sequences(self):
request = DummyRequest(
prompt=[1, 2, 3, 4, 5, 6],
stop=["stop"],
bad_words=["zz"],
temperature=0,
top_p=0,
)
processed = self.processor.process_request(request, max_model_len=5)
def test_process_request_dict_handles_sequences(self):
request = {
"prompt": [1, 2, 3, 4, 5, 6],
"stop": ["stop"],
"bad_words": ["zz"],
"temperature": 0,
"top_p": 0,
}
processed = self.processor.process_request_dict(request, max_model_len=5)
self.assertEqual(processed.prompt_token_ids, [1, 2, 3, 4])
self.assertEqual(processed.sampling_params.max_tokens, 1)
self.assertEqual(processed.sampling_params.stop_token_ids, [[4]])
self.assertEqual(set(processed.sampling_params.bad_words_token_ids), {2, 3})
self.assertEqual(processed.sampling_params.temperature, 1)
self.assertEqual(processed.sampling_params.top_k, 1)
self.assertAlmostEqual(processed.sampling_params.top_p, 1e-5)
self.assertEqual(processed["prompt_token_ids"], [1, 2, 3, 4])
self.assertEqual(processed["max_tokens"], 1)
self.assertEqual(processed["stop_token_ids"], [[4]])
self.assertEqual(set(processed["bad_words_token_ids"]), {2, 3})
self.assertEqual(processed["temperature"], 1)
self.assertEqual(processed["top_k"], 1)
self.assertAlmostEqual(processed["top_p"], 1e-5)
def test_process_request_requires_prompt_or_messages(self):
request = DummyRequest(prompt=None, messages=None, prompt_token_ids=None)
with self.assertRaisesRegex(ValueError, "should have `input_ids`, `text` or `messages`"):
self.processor.process_request(request, max_model_len=5)
def test_process_request_dict_requires_prompt_or_messages(self):
request = {"prompt": None, "messages": None, "prompt_token_ids": None}
with self.assertRaisesRegex(ValueError, "Request must contain"):
self.processor.process_request_dict(request, max_model_len=5)
def test_process_request_dict_rejects_bad_kwargs(self):
request = {
@@ -458,14 +448,15 @@ class DataProcessorTestCase(unittest.TestCase):
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
response = SimpleNamespace(
request_id="resp",
outputs=SimpleNamespace(token_ids=[1, processor.tokenizer.eos_token_id]),
)
response = {
"request_id": "resp",
"finished": True,
"outputs": {"token_ids": [1, processor.tokenizer.eos_token_id]},
}
processed = processor.process_response(response)
self.assertEqual(processed.outputs.reasoning_content, "think")
self.assertEqual(processed.outputs.tool_calls, ["tool"])
processed = processor.process_response_dict(response, stream=False)
self.assertEqual(processed["outputs"]["reasoning_content"], "think")
self.assertEqual(processed["outputs"]["tool_calls"], ["tool"])
def test_process_response_streaming_clears_state(self):
processor = self.processor