[BugFix] fix dp scheduler bug in ep4tp1 when started using multi_api_server (#6598)

* [BugFix] fix dp scheduler bug in ep4tp1 when started using multi_api_server

* [BugFix] modify request_queue and result_queue of dp scheduler
This commit is contained in:
ddchenhao66
2026-03-05 10:04:12 +08:00
committed by GitHub
parent 56ceeda80c
commit fa4815b93a
6 changed files with 26 additions and 48 deletions
+15 -16
View File
@@ -16,7 +16,7 @@ import sys
import time
import unittest
from multiprocessing import Queue
from unittest.mock import Mock, patch
from unittest.mock import Mock, call, patch
# Mock all external dependencies before importing anything
mock_logger = Mock()
@@ -577,12 +577,12 @@ class TestDPScheduler(unittest.TestCase):
@patch("threading.Thread")
def test_put_requests_success(self, mock_thread):
"""Test successful put_requests with dp_rank."""
# Create request queues - use Mock instead of real Queue to avoid threading issues
request_queues = [Mock(), Mock(), Mock()]
result_queue = Mock()
# Start the scheduler - this will create mocked threads
self.dp_scheduler.start(0, request_queues, result_queue)
self.dp_scheduler.start(0)
# Create request queues - use Mock instead of real Queue to avoid threading issues
self.dp_scheduler.request_queues = Mock()
self.dp_scheduler.result_queues = Mock()
# Create requests with dp_rank
mock_request1 = MockRequest("test_req1")
@@ -601,18 +601,17 @@ class TestDPScheduler(unittest.TestCase):
self.assertEqual(results[1], ("test_req2", None))
# Verify requests were put to the correct queues
request_queues[0].put.assert_called_once_with(mock_request1)
request_queues[1].put.assert_called_once_with(mock_request2)
expected_calls = [call(mock_request1), call(mock_request2)]
self.dp_scheduler.request_queues.put.assert_has_calls(expected_calls, any_order=False)
@patch("threading.Thread")
def test_start_creates_threads(self, mock_thread):
"""Test that start creates and starts threads."""
mock_thread.return_value = Mock()
request_queues = [Queue(), Queue()]
result_queue = Queue()
self.dp_scheduler.start(0, request_queues, result_queue)
self.dp_scheduler.start(0)
self.dp_scheduler.request_queues = Queue()
self.dp_scheduler.result_queues = Queue()
# Should create 2 threads
self.assertEqual(mock_thread.call_count, 2)
@@ -640,8 +639,8 @@ class TestDPIntegration(unittest.TestCase):
with patch.object(dp_scheduler, "start") as mock_start:
# Set up test data directly
dp_scheduler.dp_rank = 0
dp_scheduler.request_queues = [Mock(), Mock()]
dp_scheduler.result_queue = Mock()
dp_scheduler.request_queues = Mock()
dp_scheduler.result_queues = Mock()
dp_scheduler.scheduler_logger = mock_logger
dp_scheduler._scheduler.scheduler_logger = mock_logger
@@ -650,7 +649,7 @@ class TestDPIntegration(unittest.TestCase):
mock_request.dp_rank = 0
# Mock the request_queues to avoid real Queue operations
dp_scheduler.request_queues[0].put = Mock()
dp_scheduler.request_queues.put = Mock()
# Test put_requests functionality
results = dp_scheduler.put_requests([mock_request])
@@ -658,7 +657,7 @@ class TestDPIntegration(unittest.TestCase):
self.assertEqual(results[0], ("integration_req", None))
# Verify the request was put to the correct queue
dp_scheduler.request_queues[0].put.assert_called_once_with(mock_request)
dp_scheduler.request_queues.put.assert_called_once_with(mock_request)
# Verify start method was not called (to avoid threads)
mock_start.assert_not_called()