[PD Disaggregation] Unify the disaggregation info and the PD communication (#5438)

* Unify the disaggregation info and the PD communication

* up

* up

* fix

* fix conflict

* fix unittest
This commit is contained in:
Juncai
2025-12-09 14:44:59 +08:00
committed by GitHub
parent 8178e3fc6a
commit 83ea9646f9
10 changed files with 146 additions and 233 deletions
+20 -11
View File
@@ -14,7 +14,6 @@
# limitations under the License.
"""
import copy
import hashlib
import math
import pickle
@@ -533,16 +532,26 @@ class APIScheduler:
else:
dnodes.sort()
dnode = self.select_pd(req, dnodes, "decode")
disaggregated = copy.deepcopy(dnode.disaggregated)
transfer_protocol = disaggregated["transfer_protocol"]
if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
if pnode.host == dnode.host:
disaggregated["transfer_protocol"] = "ipc"
else:
disaggregated["transfer_protocol"] = "rdma"
else:
disaggregated["transfer_protocol"] = transfer_protocol[0]
req.disaggregate_info = disaggregated
is_same_node = pnode.disaggregated["host_ip"] == dnode.disaggregated["host_ip"]
is_support_ipc = (
"ipc" in pnode.disaggregated["transfer_protocol"] and "ipc" in dnode.disaggregated["transfer_protocol"]
)
is_same_tp_size = pnode.disaggregated["tp_size"] == dnode.disaggregated["tp_size"]
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
disaggregate_info = {
"prefill_ip": pnode.disaggregated["host_ip"],
"decode_ip": dnode.disaggregated["host_ip"],
"prefill_connector_port": pnode.disaggregated["connector_port"],
"decode_connector_port": dnode.disaggregated["connector_port"],
"decode_device_ids": dnode.disaggregated["device_ids"],
"decode_rdma_ports": dnode.disaggregated["rdma_ports"],
"transfer_protocol": "ipc" if use_ipc else "rdma",
"decode_tp_size": dnode.disaggregated["tp_size"],
}
req.disaggregate_info = disaggregate_info
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_dict = req.to_dict()
req_dict["group"] = group