mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up * up * up * fix
This commit is contained in:
@@ -34,6 +34,8 @@ class RDMACommManager:
         max_block_num,
         block_bytes,
         rdma_port,
+        prefill_tp_size,
+        prefill_tp_idx,
     ):
         try:
             import rdma_comm
@@ -51,12 +53,16 @@ class RDMACommManager:
             cache_v_ptr_list,
             max_block_num,
             block_bytes,
+            prefill_tp_size,
+            prefill_tp_idx,
         )
         self.splitwise_role = splitwise_role
         self.connected_rdma = set()
-        logger.info(f"init rdma messager {gpu_id} {rdma_port}")
+        logger.info(
+            f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
+        )

-    def connect(self, ip, port):
+    def connect(self, ip, port, tp_size):
         """
         Connect to remote gpu and write cache.
         """
@@ -65,7 +71,7 @@ class RDMACommManager:
         if ret:
             return True

-        ret = self.messager.connect(ip, str(port))
+        ret = self.messager.connect(ip, str(port), tp_size)
         logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
         return ret == 0

Reference in New Issue
Block a user