mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up * up * up * fix
This commit is contained in:
@@ -41,7 +41,7 @@
|
||||
* @param local_key_cache Vector of local key cache pointers
|
||||
* @param local_value_cache Vector of local value cache pointers
|
||||
* @param block_number Number of blocks in cache
|
||||
* @param block_bytes Size of each block in bytes
|
||||
* @param block_bytes Bytes of each block in each tp rank
|
||||
*
|
||||
* @throws std::runtime_error If initialization fails
|
||||
*/
|
||||
@@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number,
|
||||
int block_bytes)
|
||||
int block_bytes,
|
||||
int prefill_tp_size,
|
||||
int prefill_tp_idx)
|
||||
: splitwise_role(role),
|
||||
gpu_idx(gpu_idx),
|
||||
port(port),
|
||||
@@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
||||
local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
|
||||
block_number(block_number),
|
||||
block_size_byte(block_bytes),
|
||||
prefill_tp_size(prefill_tp_size),
|
||||
prefill_tp_idx(prefill_tp_idx),
|
||||
RDMACommunicator_status(0),
|
||||
rdma_event_channel_epoll_fd(-1) {
|
||||
try {
|
||||
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
|
||||
*
|
||||
* @param dst_ip Destination IP address
|
||||
* @param dst_port Destination port
|
||||
* @param dest_tp_size Default 0: assumes dest has same tp_size as source;
|
||||
* otherwise specifies decode tp_size
|
||||
* @return ConnStatus::kConnected ConnStatus::kError;
|
||||
*/
|
||||
|
||||
int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
const std::string& dst_port) {
|
||||
const std::string& dst_port,
|
||||
int dest_tp_size = 0) {
|
||||
std::string url = dst_ip + ":" + dst_port;
|
||||
|
||||
// Initialize IB devices if not already done
|
||||
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
ctx->conn.layer_number = layer_number;
|
||||
ctx->conn.block_number = block_number;
|
||||
ctx->conn.block_byte_size = block_size_byte;
|
||||
if (dest_tp_size > 0)
|
||||
ctx->conn.decode_tp_size = dest_tp_size;
|
||||
else
|
||||
ctx->conn.decode_tp_size = prefill_tp_size;
|
||||
|
||||
// Get port information for the connection
|
||||
if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
|
||||
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
ERR("Couldn't getexchange port infodestinations");
|
||||
return static_cast<int>(ConnStatus::kError);
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
ctx->conn.connected = 1;
|
||||
conn_map[url] = ctx;
|
||||
client_exchange_mr(ctx);
|
||||
}
|
||||
|
||||
@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
ctx->conn.connected = 1;
|
||||
conn_map[url] = ctx;
|
||||
|
||||
WARN("connect end ....");
|
||||
return static_cast<int>(ConnStatus::kConnected);
|
||||
}
|
||||
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {
|
||||
|
||||
bool RDMACommunicator::is_connected(const std::string& dst_ip,
|
||||
const std::string& dst_port) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::string url = dst_ip + ":" + dst_port;
|
||||
return conn_map.find(url) != conn_map.end();
|
||||
}
|
||||
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
|
||||
uint32_t cache_value_rkey =
|
||||
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
|
||||
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
|
||||
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
|
||||
uint64_t offset_in_block =
|
||||
pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
|
||||
uint64_t total_block_size_byte =
|
||||
pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;
|
||||
|
||||
for (size_t block_index = 0; block_index < block_num; ++block_index) {
|
||||
char* char_ptr = static_cast<char*>(
|
||||
ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
|
||||
cache_key_remote_addr[block_index] =
|
||||
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
|
||||
cache_key_remote_addr[block_index] = (uint64_t(
|
||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||
offset_in_block));
|
||||
char_ptr = static_cast<char*>(
|
||||
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
|
||||
cache_value_remote_addr[block_index] =
|
||||
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
|
||||
cache_value_remote_addr[block_index] = (uint64_t(
|
||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||
offset_in_block));
|
||||
}
|
||||
|
||||
ctx->conn.wc_target_count = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
bool is_key = (i == 0);
|
||||
|
||||
@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
int,
|
||||
int>())
|
||||
.def("connect", &RDMACommunicator::connect)
|
||||
.def("is_connected", &RDMACommunicator::is_connected)
|
||||
.def("write_cache", &RDMACommunicator::write_cache);
|
||||
int,
|
||||
int,
|
||||
int>(),
|
||||
py::arg("splitwise_role"),
|
||||
py::arg("gpu_idx"),
|
||||
py::arg("port"),
|
||||
py::arg("key_cache_ptrs"),
|
||||
py::arg("value_cache_ptrs"),
|
||||
py::arg("block_number"),
|
||||
py::arg("block_bytes"),
|
||||
py::arg("prefill_tp_size") = 1,
|
||||
py::arg("prefill_tp_idx") = 0)
|
||||
.def("connect",
|
||||
&RDMACommunicator::connect,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::arg("dst_tp_size") =
|
||||
0, // Default 0: assumes dest has same tp_size as source;
|
||||
// otherwise specifies decode tp_size
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_connected",
|
||||
&RDMACommunicator::is_connected,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("write_cache",
|
||||
&RDMACommunicator::write_cache,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::arg("local_block_ids"),
|
||||
py::arg("remote_block_ids"),
|
||||
py::arg("layer_idx"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
|
||||
Reference in New Issue
Block a user