[BugFix] fix dp scheduler bug in ep4tp1 when start by using multi_api_server (#6598)

* [BugFix] fix dp scheduler bug in ep4tp1 when start by using multi_api_server

* [BugFix] modify request_queue and result_queue of dp scheduler
This commit is contained in:
ddchenhao66
2026-03-05 10:04:12 +08:00
committed by GitHub
parent 56ceeda80c
commit fa4815b93a
6 changed files with 26 additions and 48 deletions
-8
View File
@@ -2116,19 +2116,11 @@ class EngineService:
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
request_queues_for_dp_ipc = None
result_queue_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, self.cfg.register_info)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
-11
View File
@@ -775,20 +775,11 @@ class LLMEngine:
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
request_queues_for_dp_ipc = None
result_queues_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queues_for_dp_ipc = []
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
result_queues_for_dp_ipc.append(multiprocessing.Queue())
self.engine.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
request_queues_for_dp_ipc,
result_queues_for_dp_ipc,
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
@@ -826,8 +817,6 @@ class LLMEngine:
cfg,
i,
None,
request_queues_for_dp_ipc,
result_queues_for_dp_ipc,
),
)
)
+2 -5
View File
@@ -96,9 +96,7 @@ class ExpertService:
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
):
def start(self, ipc_signal_suffix, local_data_parallel_id):
"""
Initializes the engine and starts its sub-services.
If `api_server_pid` is defined, will launch a thread
@@ -112,8 +110,7 @@ class ExpertService:
self.engine.create_data_processor()
if self.cfg.scheduler_config.name == "dp":
self.cfg.init_cache_info()
assert (request_queues_for_dp_ipc is not None) and (result_queues_for_dp_ipc is not None)
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc)
self.engine.scheduler.start(local_data_parallel_id)
if ipc_signal_suffix is not None:
self.api_server_pid = ipc_signal_suffix
@@ -156,6 +156,7 @@ def load_data_service():
engine_args = EngineArgs.from_cli_args(args)
config = engine_args.create_engine_config()
api_server_logger.info(f"local_data_parallel_id: {config.parallel_config}")
api_server_logger.info(f"local_data_parallel_id: {config.parallel_config.local_data_parallel_id}")
expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id)
if not expert_service.start(args.port, config.parallel_config.local_data_parallel_id):
api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!")
+8 -8
View File
@@ -15,9 +15,9 @@
"""
import logging
import multiprocessing
import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
@@ -207,10 +207,10 @@ class DPScheduler:
splitwise_role,
)
def start(self, dp_rank: int, request_queues: List[Queue], result_queues: Queue):
def start(self, dp_rank: int):
self.dp_rank = dp_rank
self.request_queues = request_queues
self.result_queues = result_queues
self.request_queues = multiprocessing.Queue()
self.result_queues = multiprocessing.Queue()
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
self._scheduler.scheduler_logger = self.scheduler_logger
threading.Thread(target=self._put_requests_to_local).start()
@@ -221,13 +221,13 @@ class DPScheduler:
for request in requests:
if not hasattr(request, "dp_rank"):
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
self.request_queues[request.dp_rank].put(request)
self.request_queues.put(request)
results.append((request.request_id, None))
return results
def _put_requests_to_local(self):
while True:
request = self.request_queues[self.dp_rank].get()
request = self.request_queues.get()
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
self._scheduler.put_requests([request])
@@ -236,7 +236,7 @@ class DPScheduler:
results = self._scheduler.get_results()
if len(results) == 0:
continue
self.result_queues[self.dp_rank].put(results)
self.result_queues.put(results)
def get_requests(
self,
@@ -257,4 +257,4 @@ class DPScheduler:
self._scheduler.put_results(results)
def get_results(self) -> Dict[str, List[RequestOutput]]:
return self.result_queues[self.dp_rank].get()
return self.result_queues.get()
+15 -16
View File
@@ -16,7 +16,7 @@ import sys
import time
import unittest
from multiprocessing import Queue
from unittest.mock import Mock, patch
from unittest.mock import Mock, call, patch
# Mock all external dependencies before importing anything
mock_logger = Mock()
@@ -577,12 +577,12 @@ class TestDPScheduler(unittest.TestCase):
@patch("threading.Thread")
def test_put_requests_success(self, mock_thread):
"""Test successful put_requests with dp_rank."""
# Create request queues - use Mock instead of real Queue to avoid threading issues
request_queues = [Mock(), Mock(), Mock()]
result_queue = Mock()
# Start the scheduler - this will create mocked threads
self.dp_scheduler.start(0, request_queues, result_queue)
self.dp_scheduler.start(0)
# Create request queues - use Mock instead of real Queue to avoid threading issues
self.dp_scheduler.request_queues = Mock()
self.dp_scheduler.result_queues = Mock()
# Create requests with dp_rank
mock_request1 = MockRequest("test_req1")
@@ -601,18 +601,17 @@ class TestDPScheduler(unittest.TestCase):
self.assertEqual(results[1], ("test_req2", None))
# Verify requests were put to the correct queues
request_queues[0].put.assert_called_once_with(mock_request1)
request_queues[1].put.assert_called_once_with(mock_request2)
expected_calls = [call(mock_request1), call(mock_request2)]
self.dp_scheduler.request_queues.put.assert_has_calls(expected_calls, any_order=False)
@patch("threading.Thread")
def test_start_creates_threads(self, mock_thread):
"""Test that start creates and starts threads."""
mock_thread.return_value = Mock()
request_queues = [Queue(), Queue()]
result_queue = Queue()
self.dp_scheduler.start(0, request_queues, result_queue)
self.dp_scheduler.start(0)
self.dp_scheduler.request_queues = Queue()
self.dp_scheduler.result_queues = Queue()
# Should create 2 threads
self.assertEqual(mock_thread.call_count, 2)
@@ -640,8 +639,8 @@ class TestDPIntegration(unittest.TestCase):
with patch.object(dp_scheduler, "start") as mock_start:
# Set up test data directly
dp_scheduler.dp_rank = 0
dp_scheduler.request_queues = [Mock(), Mock()]
dp_scheduler.result_queue = Mock()
dp_scheduler.request_queues = Mock()
dp_scheduler.result_queues = Mock()
dp_scheduler.scheduler_logger = mock_logger
dp_scheduler._scheduler.scheduler_logger = mock_logger
@@ -650,7 +649,7 @@ class TestDPIntegration(unittest.TestCase):
mock_request.dp_rank = 0
# Mock the request_queues to avoid real Queue operations
dp_scheduler.request_queues[0].put = Mock()
dp_scheduler.request_queues.put = Mock()
# Test put_requests functionality
results = dp_scheduler.put_requests([mock_request])
@@ -658,7 +657,7 @@ class TestDPIntegration(unittest.TestCase):
self.assertEqual(results[0], ("integration_req", None))
# Verify the request was put to the correct queue
dp_scheduler.request_queues[0].put.assert_called_once_with(mock_request)
dp_scheduler.request_queues.put.assert_called_once_with(mock_request)
# Verify start method was not called (to avoid threads)
mock_start.assert_not_called()