[PD Disaggregation] support different tp_size for prefill and decode (#5296)

* up

* up

* up

* fix
This commit is contained in:
Juncai
2025-12-01 17:50:20 +08:00
committed by GitHub
parent 54119cf07e
commit 0925d44f18
13 changed files with 584 additions and 36 deletions
@@ -41,7 +41,7 @@
* @param local_key_cache Vector of local key cache pointers
* @param local_value_cache Vector of local value cache pointers
* @param block_number Number of blocks in cache
* @param block_bytes Size of each block in bytes
* @param block_bytes Size in bytes of each block on a single tp rank
*
* @throws std::runtime_error If initialization fails
*/
@@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role,
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number,
int block_bytes)
int block_bytes,
int prefill_tp_size,
int prefill_tp_idx)
: splitwise_role(role),
gpu_idx(gpu_idx),
port(port),
@@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role,
local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
block_number(block_number),
block_size_byte(block_bytes),
prefill_tp_size(prefill_tp_size),
prefill_tp_idx(prefill_tp_idx),
RDMACommunicator_status(0),
rdma_event_channel_epoll_fd(-1) {
try {
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
*
* @param dst_ip Destination IP address
* @param dst_port Destination port
* @param dest_tp_size tp_size of the destination (decode) side. Values <= 0
* (default 0) mean the destination uses the same tp_size as this side.
* @return ConnStatus::kConnected on success, ConnStatus::kError on failure
*/
int RDMACommunicator::connect(const std::string& dst_ip,
const std::string& dst_port) {
const std::string& dst_port,
int dest_tp_size = 0) {
std::string url = dst_ip + ":" + dst_port;
// Initialize IB devices if not already done
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ctx->conn.layer_number = layer_number;
ctx->conn.block_number = block_number;
ctx->conn.block_byte_size = block_size_byte;
if (dest_tp_size > 0)
ctx->conn.decode_tp_size = dest_tp_size;
else
ctx->conn.decode_tp_size = prefill_tp_size;
// Get port information for the connection
if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ERR("Couldn't getexchange port infodestinations");
return static_cast<int>(ConnStatus::kError);
} else {
std::lock_guard<std::mutex> lock(mutex_);
ctx->conn.connected = 1;
conn_map[url] = ctx;
client_exchange_mr(ctx);
}
@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
}
}
std::lock_guard<std::mutex> lock(mutex_);
ctx->conn.connected = 1;
conn_map[url] = ctx;
WARN("connect end ....");
return static_cast<int>(ConnStatus::kConnected);
}
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {
bool RDMACommunicator::is_connected(const std::string& dst_ip,
const std::string& dst_port) {
std::lock_guard<std::mutex> lock(mutex_);
std::string url = dst_ip + ":" + dst_port;
return conn_map.find(url) != conn_map.end();
}
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
uint32_t cache_value_rkey =
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
uint64_t offset_in_block =
pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
uint64_t total_block_size_byte =
pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;
for (size_t block_index = 0; block_index < block_num; ++block_index) {
char* char_ptr = static_cast<char*>(
ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
cache_key_remote_addr[block_index] =
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
cache_key_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
char_ptr = static_cast<char*>(
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
cache_value_remote_addr[block_index] =
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
cache_value_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
}
ctx->conn.wc_target_count = 0;
for (int i = 0; i < 2; ++i) {
bool is_key = (i == 0);
@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
std::vector<int64_t>,
std::vector<int64_t>,
int,
int>())
.def("connect", &RDMACommunicator::connect)
.def("is_connected", &RDMACommunicator::is_connected)
.def("write_cache", &RDMACommunicator::write_cache);
int,
int,
int>(),
py::arg("splitwise_role"),
py::arg("gpu_idx"),
py::arg("port"),
py::arg("key_cache_ptrs"),
py::arg("value_cache_ptrs"),
py::arg("block_number"),
py::arg("block_bytes"),
py::arg("prefill_tp_size") = 1,
py::arg("prefill_tp_idx") = 0)
.def("connect",
&RDMACommunicator::connect,
py::arg("dst_ip"),
py::arg("dst_port"),
py::arg("dst_tp_size") =
0, // Default 0: assumes dest has same tp_size as source;
// otherwise specifies decode tp_size
py::call_guard<py::gil_scoped_release>())
.def("is_connected",
&RDMACommunicator::is_connected,
py::arg("dst_ip"),
py::arg("dst_port"),
py::call_guard<py::gil_scoped_release>())
.def("write_cache",
&RDMACommunicator::write_cache,
py::arg("dst_ip"),
py::arg("dst_port"),
py::arg("local_block_ids"),
py::arg("remote_block_ids"),
py::arg("layer_idx"),
py::call_guard<py::gil_scoped_release>());
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;