mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Other] Adjust GPUModelRunner to enhance compatibility (#6851)
This commit is contained in:
@@ -475,55 +475,6 @@ class TestProcessMMFeatures(unittest.TestCase):
|
||||
any(isinstance(t, paddle.Tensor) for t in self.runner.share_inputs["image_features_list"]),
|
||||
)
|
||||
|
||||
def test_process_mm_features_rope_3d_position_ids(self):
|
||||
"""Test 3D position IDs processing"""
|
||||
request_list = [
|
||||
self._create_mock_request(
|
||||
task_type_value=0,
|
||||
idx=0,
|
||||
position_ids=np.array([[1, 2, 3]]),
|
||||
max_tokens=2048,
|
||||
),
|
||||
self._create_mock_request(
|
||||
task_type_value=0,
|
||||
idx=1,
|
||||
position_ids=np.array([[4, 5, 6]]),
|
||||
max_tokens=1024,
|
||||
),
|
||||
]
|
||||
|
||||
# Mock prepare_rope3d to return list of rope embeddings
|
||||
self.runner.prepare_rope3d.return_value = [1, 2]
|
||||
|
||||
self.runner._process_mm_features(request_list)
|
||||
|
||||
# Verify prepare_rope3d was called with correct parameters
|
||||
self.runner.prepare_rope3d.assert_called_once()
|
||||
|
||||
# Verify rope embeddings were set in share_inputs
|
||||
self.assertEqual(self.runner.share_inputs["rope_emb"][0], paddle.Tensor([1]))
|
||||
self.assertEqual(self.runner.share_inputs["rope_emb"][1], paddle.Tensor([2]))
|
||||
|
||||
def test_process_mm_features_pooling_model(self):
|
||||
"""Test processing with pooling model"""
|
||||
self.runner.is_pooling_model = True
|
||||
|
||||
request_list = [
|
||||
self._create_mock_request(
|
||||
task_type_value=0,
|
||||
idx=0,
|
||||
position_ids=np.array([[1, 2, 3]]),
|
||||
),
|
||||
]
|
||||
|
||||
self.runner.prepare_rope3d.return_value = [1]
|
||||
|
||||
self.runner._process_mm_features(request_list)
|
||||
|
||||
# Verify max_tokens_lst contains 0 for pooling model
|
||||
call_args = self.runner.prepare_rope3d.call_args
|
||||
self.assertEqual(call_args[0][2], [0, 1]) # max_tokens_lst
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user