mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix dp scheduler bug in ep4tp1 when started using multi_api_server (#6598)
* [BugFix] fix dp scheduler bug in ep4tp1 when started using multi_api_server
* [BugFix] modify request_queue and result_queue of dp scheduler
This commit is contained in:
@@ -2116,19 +2116,11 @@ class EngineService:
|
||||
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
request_queues_for_dp_ipc = None
|
||||
result_queue_for_dp_ipc = None
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queue_for_dp_ipc = multiprocessing.Queue()
|
||||
for i in range(self.cfg.parallel_config.data_parallel_size):
|
||||
request_queues_for_dp_ipc.append(multiprocessing.Queue())
|
||||
self.scheduler.start(
|
||||
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
|
||||
request_queues_for_dp_ipc,
|
||||
result_queue_for_dp_ipc,
|
||||
)
|
||||
|
||||
if not envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
|
||||
@@ -775,20 +775,11 @@ class LLMEngine:
|
||||
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
request_queues_for_dp_ipc = None
|
||||
result_queues_for_dp_ipc = None
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queues_for_dp_ipc = []
|
||||
for i in range(self.cfg.parallel_config.data_parallel_size):
|
||||
request_queues_for_dp_ipc.append(multiprocessing.Queue())
|
||||
result_queues_for_dp_ipc.append(multiprocessing.Queue())
|
||||
self.engine.scheduler.start(
|
||||
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
|
||||
request_queues_for_dp_ipc,
|
||||
result_queues_for_dp_ipc,
|
||||
)
|
||||
|
||||
if not envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
@@ -826,8 +817,6 @@ class LLMEngine:
|
||||
cfg,
|
||||
i,
|
||||
None,
|
||||
request_queues_for_dp_ipc,
|
||||
result_queues_for_dp_ipc,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -96,9 +96,7 @@ class ExpertService:
|
||||
|
||||
self._finalizer = weakref.finalize(self, self._exit_sub_services)
|
||||
|
||||
def start(
|
||||
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
|
||||
):
|
||||
def start(self, ipc_signal_suffix, local_data_parallel_id):
|
||||
"""
|
||||
Initializes the engine and starts its sub-services.
|
||||
If `api_server_pid` is defined, will launch a thread
|
||||
@@ -112,8 +110,7 @@ class ExpertService:
|
||||
self.engine.create_data_processor()
|
||||
if self.cfg.scheduler_config.name == "dp":
|
||||
self.cfg.init_cache_info()
|
||||
assert (request_queues_for_dp_ipc is not None) and (result_queues_for_dp_ipc is not None)
|
||||
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc)
|
||||
self.engine.scheduler.start(local_data_parallel_id)
|
||||
|
||||
if ipc_signal_suffix is not None:
|
||||
self.api_server_pid = ipc_signal_suffix
|
||||
|
||||
@@ -156,6 +156,7 @@ def load_data_service():
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
config = engine_args.create_engine_config()
|
||||
api_server_logger.info(f"local_data_parallel_id: {config.parallel_config}")
|
||||
api_server_logger.info(f"local_data_parallel_id: {config.parallel_config.local_data_parallel_id}")
|
||||
expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id)
|
||||
if not expert_service.start(args.port, config.parallel_config.local_data_parallel_id):
|
||||
api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!")
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import threading
|
||||
import time
|
||||
from multiprocessing import Queue
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
@@ -207,10 +207,10 @@ class DPScheduler:
|
||||
splitwise_role,
|
||||
)
|
||||
|
||||
def start(self, dp_rank: int, request_queues: List[Queue], result_queues: Queue):
|
||||
def start(self, dp_rank: int):
|
||||
self.dp_rank = dp_rank
|
||||
self.request_queues = request_queues
|
||||
self.result_queues = result_queues
|
||||
self.request_queues = multiprocessing.Queue()
|
||||
self.result_queues = multiprocessing.Queue()
|
||||
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
|
||||
self._scheduler.scheduler_logger = self.scheduler_logger
|
||||
threading.Thread(target=self._put_requests_to_local).start()
|
||||
@@ -221,13 +221,13 @@ class DPScheduler:
|
||||
for request in requests:
|
||||
if not hasattr(request, "dp_rank"):
|
||||
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
|
||||
self.request_queues[request.dp_rank].put(request)
|
||||
self.request_queues.put(request)
|
||||
results.append((request.request_id, None))
|
||||
return results
|
||||
|
||||
def _put_requests_to_local(self):
|
||||
while True:
|
||||
request = self.request_queues[self.dp_rank].get()
|
||||
request = self.request_queues.get()
|
||||
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
|
||||
self._scheduler.put_requests([request])
|
||||
|
||||
@@ -236,7 +236,7 @@ class DPScheduler:
|
||||
results = self._scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
continue
|
||||
self.result_queues[self.dp_rank].put(results)
|
||||
self.result_queues.put(results)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
@@ -257,4 +257,4 @@ class DPScheduler:
|
||||
self._scheduler.put_results(results)
|
||||
|
||||
def get_results(self) -> Dict[str, List[RequestOutput]]:
|
||||
return self.result_queues[self.dp_rank].get()
|
||||
return self.result_queues.get()
|
||||
|
||||
@@ -16,7 +16,7 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
from multiprocessing import Queue
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
# Mock all external dependencies before importing anything
|
||||
mock_logger = Mock()
|
||||
@@ -577,12 +577,12 @@ class TestDPScheduler(unittest.TestCase):
|
||||
@patch("threading.Thread")
|
||||
def test_put_requests_success(self, mock_thread):
|
||||
"""Test successful put_requests with dp_rank."""
|
||||
# Create request queues - use Mock instead of real Queue to avoid threading issues
|
||||
request_queues = [Mock(), Mock(), Mock()]
|
||||
result_queue = Mock()
|
||||
|
||||
# Start the scheduler - this will create mocked threads
|
||||
self.dp_scheduler.start(0, request_queues, result_queue)
|
||||
self.dp_scheduler.start(0)
|
||||
|
||||
# Create request queues - use Mock instead of real Queue to avoid threading issues
|
||||
self.dp_scheduler.request_queues = Mock()
|
||||
self.dp_scheduler.result_queues = Mock()
|
||||
|
||||
# Create requests with dp_rank
|
||||
mock_request1 = MockRequest("test_req1")
|
||||
@@ -601,18 +601,17 @@ class TestDPScheduler(unittest.TestCase):
|
||||
self.assertEqual(results[1], ("test_req2", None))
|
||||
|
||||
# Verify requests were put to the correct queues
|
||||
request_queues[0].put.assert_called_once_with(mock_request1)
|
||||
request_queues[1].put.assert_called_once_with(mock_request2)
|
||||
expected_calls = [call(mock_request1), call(mock_request2)]
|
||||
self.dp_scheduler.request_queues.put.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
@patch("threading.Thread")
|
||||
def test_start_creates_threads(self, mock_thread):
|
||||
"""Test that start creates and starts threads."""
|
||||
mock_thread.return_value = Mock()
|
||||
|
||||
request_queues = [Queue(), Queue()]
|
||||
result_queue = Queue()
|
||||
|
||||
self.dp_scheduler.start(0, request_queues, result_queue)
|
||||
self.dp_scheduler.start(0)
|
||||
self.dp_scheduler.request_queues = Queue()
|
||||
self.dp_scheduler.result_queues = Queue()
|
||||
|
||||
# Should create 2 threads
|
||||
self.assertEqual(mock_thread.call_count, 2)
|
||||
@@ -640,8 +639,8 @@ class TestDPIntegration(unittest.TestCase):
|
||||
with patch.object(dp_scheduler, "start") as mock_start:
|
||||
# Set up test data directly
|
||||
dp_scheduler.dp_rank = 0
|
||||
dp_scheduler.request_queues = [Mock(), Mock()]
|
||||
dp_scheduler.result_queue = Mock()
|
||||
dp_scheduler.request_queues = Mock()
|
||||
dp_scheduler.result_queues = Mock()
|
||||
dp_scheduler.scheduler_logger = mock_logger
|
||||
dp_scheduler._scheduler.scheduler_logger = mock_logger
|
||||
|
||||
@@ -650,7 +649,7 @@ class TestDPIntegration(unittest.TestCase):
|
||||
mock_request.dp_rank = 0
|
||||
|
||||
# Mock the request_queues to avoid real Queue operations
|
||||
dp_scheduler.request_queues[0].put = Mock()
|
||||
dp_scheduler.request_queues.put = Mock()
|
||||
|
||||
# Test put_requests functionality
|
||||
results = dp_scheduler.put_requests([mock_request])
|
||||
@@ -658,7 +657,7 @@ class TestDPIntegration(unittest.TestCase):
|
||||
self.assertEqual(results[0], ("integration_req", None))
|
||||
|
||||
# Verify the request was put to the correct queue
|
||||
dp_scheduler.request_queues[0].put.assert_called_once_with(mock_request)
|
||||
dp_scheduler.request_queues.put.assert_called_once_with(mock_request)
|
||||
|
||||
# Verify start method was not called (to avoid threads)
|
||||
mock_start.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user