[PD Disaggregation] support different tp_size for prefill and decode (#5296)

* up

* up

* up

* fix
This commit is contained in:
Juncai
2025-12-01 17:50:20 +08:00
committed by GitHub
parent 54119cf07e
commit 0925d44f18
13 changed files with 584 additions and 36 deletions
@@ -34,6 +34,8 @@ class RDMACommManager:
max_block_num,
block_bytes,
rdma_port,
prefill_tp_size,
prefill_tp_idx,
):
try:
import rdma_comm
@@ -51,12 +53,16 @@ class RDMACommManager:
cache_v_ptr_list,
max_block_num,
block_bytes,
prefill_tp_size,
prefill_tp_idx,
)
self.splitwise_role = splitwise_role
self.connected_rdma = set()
logger.info(f"init rdma messager {gpu_id} {rdma_port}")
logger.info(
f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
)
def connect(self, ip, port):
def connect(self, ip, port, tp_size):
"""
Connect to remote gpu and write cache.
"""
@@ -65,7 +71,7 @@ class RDMACommManager:
if ret:
return True
ret = self.messager.connect(ip, str(port))
ret = self.messager.connect(ip, str(port), tp_size)
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
return ret == 0