mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix dp sheduler bug in ep4tp1 when start by using multi_api_server (#6598)
* [BugFix] fix dp sheduler bug in ep4tp1 when start by using multi_api_server * [BugFix] modify request_queue and result_queue of dp scheduler
This commit is contained in:
@@ -16,7 +16,7 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
from multiprocessing import Queue
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
# Mock all external dependencies before importing anything
|
||||
mock_logger = Mock()
|
||||
@@ -577,12 +577,12 @@ class TestDPScheduler(unittest.TestCase):
|
||||
@patch("threading.Thread")
|
||||
def test_put_requests_success(self, mock_thread):
|
||||
"""Test successful put_requests with dp_rank."""
|
||||
# Create request queues - use Mock instead of real Queue to avoid threading issues
|
||||
request_queues = [Mock(), Mock(), Mock()]
|
||||
result_queue = Mock()
|
||||
|
||||
# Start the scheduler - this will create mocked threads
|
||||
self.dp_scheduler.start(0, request_queues, result_queue)
|
||||
self.dp_scheduler.start(0)
|
||||
|
||||
# Create request queues - use Mock instead of real Queue to avoid threading issues
|
||||
self.dp_scheduler.request_queues = Mock()
|
||||
self.dp_scheduler.result_queues = Mock()
|
||||
|
||||
# Create requests with dp_rank
|
||||
mock_request1 = MockRequest("test_req1")
|
||||
@@ -601,18 +601,17 @@ class TestDPScheduler(unittest.TestCase):
|
||||
self.assertEqual(results[1], ("test_req2", None))
|
||||
|
||||
# Verify requests were put to the correct queues
|
||||
request_queues[0].put.assert_called_once_with(mock_request1)
|
||||
request_queues[1].put.assert_called_once_with(mock_request2)
|
||||
expected_calls = [call(mock_request1), call(mock_request2)]
|
||||
self.dp_scheduler.request_queues.put.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
@patch("threading.Thread")
|
||||
def test_start_creates_threads(self, mock_thread):
|
||||
"""Test that start creates and starts threads."""
|
||||
mock_thread.return_value = Mock()
|
||||
|
||||
request_queues = [Queue(), Queue()]
|
||||
result_queue = Queue()
|
||||
|
||||
self.dp_scheduler.start(0, request_queues, result_queue)
|
||||
self.dp_scheduler.start(0)
|
||||
self.dp_scheduler.request_queues = Queue()
|
||||
self.dp_scheduler.result_queues = Queue()
|
||||
|
||||
# Should create 2 threads
|
||||
self.assertEqual(mock_thread.call_count, 2)
|
||||
@@ -640,8 +639,8 @@ class TestDPIntegration(unittest.TestCase):
|
||||
with patch.object(dp_scheduler, "start") as mock_start:
|
||||
# Set up test data directly
|
||||
dp_scheduler.dp_rank = 0
|
||||
dp_scheduler.request_queues = [Mock(), Mock()]
|
||||
dp_scheduler.result_queue = Mock()
|
||||
dp_scheduler.request_queues = Mock()
|
||||
dp_scheduler.result_queues = Mock()
|
||||
dp_scheduler.scheduler_logger = mock_logger
|
||||
dp_scheduler._scheduler.scheduler_logger = mock_logger
|
||||
|
||||
@@ -650,7 +649,7 @@ class TestDPIntegration(unittest.TestCase):
|
||||
mock_request.dp_rank = 0
|
||||
|
||||
# Mock the request_queues to avoid real Queue operations
|
||||
dp_scheduler.request_queues[0].put = Mock()
|
||||
dp_scheduler.request_queues.put = Mock()
|
||||
|
||||
# Test put_requests functionality
|
||||
results = dp_scheduler.put_requests([mock_request])
|
||||
@@ -658,7 +657,7 @@ class TestDPIntegration(unittest.TestCase):
|
||||
self.assertEqual(results[0], ("integration_req", None))
|
||||
|
||||
# Verify the request was put to the correct queue
|
||||
dp_scheduler.request_queues[0].put.assert_called_once_with(mock_request)
|
||||
dp_scheduler.request_queues.put.assert_called_once_with(mock_request)
|
||||
|
||||
# Verify start method was not called (to avoid threads)
|
||||
mock_start.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user