mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-25 01:55:45 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
+27
-27
@@ -3,13 +3,13 @@
|
||||
* @brief RDMA connection implementation for key-value cache
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@@ -32,7 +32,7 @@ std::vector<IbDeviceInfo> g_ib_all_devs;
|
||||
static int64_t get_ib_busid(const char *dev_name) {
|
||||
char dev_path[PATH_MAX];
|
||||
snprintf(dev_path, PATH_MAX, "/sys/class/infiniband/%s/device", dev_name);
|
||||
|
||||
|
||||
char *p = realpath(dev_path, NULL);
|
||||
if (p == NULL) {
|
||||
WARN("Failed to get realpath for device %s: %s", dev_name, strerror(errno));
|
||||
@@ -63,7 +63,7 @@ static int64_t get_ib_busid(const char *dev_name) {
|
||||
/**
|
||||
* @brief Parse and cache IB device information
|
||||
* @return Number of IB devices found, negative on error
|
||||
*
|
||||
*
|
||||
* @note This function is thread-safe and will only parse once
|
||||
*/
|
||||
int parse_port_ib_info() {
|
||||
@@ -448,7 +448,7 @@ bool poll_cq_with_timeout(struct RdmaContext *ctx, int timeout_seconds, int cqe_
|
||||
if ((current_time.tv_sec - start_time.tv_sec) >= timeout_seconds) {
|
||||
ERR("Timeout occurred after %d seconds", timeout_seconds);
|
||||
free(wc_array);
|
||||
return false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@@ -468,7 +468,7 @@ bool clear_qp_info(struct RdmaContext* ctx) {
|
||||
success = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (ctx->cq) {
|
||||
if (ibv_destroy_cq(ctx->cq)) {
|
||||
ERR("Failed to deallocate cq Domain.");
|
||||
@@ -565,7 +565,7 @@ struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd)
|
||||
return NULL;
|
||||
}
|
||||
|
||||
INFO("Successfully created QP 0x%x on device %s",
|
||||
INFO("Successfully created QP 0x%x on device %s",
|
||||
ctx->qp->qp_num, ib_dev->devName);
|
||||
|
||||
return ctx;
|
||||
@@ -601,10 +601,10 @@ bool client_exchange_destinations(
|
||||
ERR("Failed to get port info for port %d", ib_port);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
my_dest.lid = ctx->portinfo.lid;
|
||||
my_dest.mtu = ctx->portinfo.active_mtu;
|
||||
|
||||
|
||||
// Validate LID for InfiniBand
|
||||
if (ctx->portinfo.link_layer != IBV_LINK_LAYER_ETHERNET && !my_dest.lid) {
|
||||
ERR("Invalid LID 0x%04x for non-Ethernet link layer", my_dest.lid);
|
||||
@@ -722,24 +722,24 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
||||
auto layer_num = ctx->conn.layer_number;
|
||||
auto& key_mrs = ctx->conn.write_cache_key_server_mr_list;
|
||||
auto& val_mrs = ctx->conn.write_cache_value_server_mr_list;
|
||||
|
||||
|
||||
// Verify that server memory regions are properly initialized
|
||||
if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) {
|
||||
ERR("server write cache memory region size error");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Prepare memory region information to send
|
||||
std::vector<uint64_t> send_key_ptrs;
|
||||
std::vector<uint32_t> send_key_rkeys;
|
||||
std::vector<uint64_t> send_val_ptrs;
|
||||
std::vector<uint32_t> send_val_rkeys;
|
||||
|
||||
|
||||
send_key_ptrs.reserve(layer_num);
|
||||
send_key_rkeys.reserve(layer_num);
|
||||
send_val_ptrs.reserve(layer_num);
|
||||
send_val_rkeys.reserve(layer_num);
|
||||
|
||||
|
||||
// Collect memory region information from local MRs
|
||||
for (int i = 0; i < layer_num; ++i) {
|
||||
send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
|
||||
@@ -753,13 +753,13 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
||||
if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send memory region information from server to client
|
||||
*
|
||||
*
|
||||
* @param ctx The RDMA context
|
||||
* @param local_mr Pointer to the local memory region to be sent
|
||||
* @param byte_num Size of the memory region in bytes
|
||||
@@ -796,16 +796,16 @@ bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte
|
||||
ibv_dereg_mr(ctx->conn.send_mr);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Wait for completion
|
||||
struct ibv_wc wc;
|
||||
ctx->conn.wc_count = 0;
|
||||
ctx->conn.wc_target_count = 0;
|
||||
|
||||
|
||||
if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Deregister the memory region
|
||||
ibv_dereg_mr(ctx->conn.send_mr);
|
||||
return true;
|
||||
@@ -813,7 +813,7 @@ bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte
|
||||
|
||||
/**
|
||||
* Receive memory region information on the client side
|
||||
*
|
||||
*
|
||||
* @param ctx The RDMA context
|
||||
* @param remote_mr Pointer to the buffer where remote memory region info will be stored
|
||||
* @param byte_num Size of the memory region in bytes
|
||||
@@ -863,17 +863,17 @@ bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int
|
||||
|
||||
/**
|
||||
* Sets up a listening socket on the specified port
|
||||
*
|
||||
*
|
||||
* @param port The port number to listen on
|
||||
* @return The socket file descriptor on success, -1 on failure
|
||||
*/
|
||||
int setup_listening_socket(int port) {
|
||||
int sockfd = -1;
|
||||
struct addrinfo hints = {0};
|
||||
|
||||
|
||||
// Set up hints for getaddrinfo
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
struct addrinfo *res = nullptr;
|
||||
@@ -881,14 +881,14 @@ int setup_listening_socket(int port) {
|
||||
// Convert port to string for getaddrinfo
|
||||
std::ostringstream service;
|
||||
service << port;
|
||||
|
||||
|
||||
// Get address info for the specified port
|
||||
int n = getaddrinfo(nullptr, service.str().c_str(), &hints, &res);
|
||||
if (n != 0) {
|
||||
ERR("getaddrinfo failed for port %d: %s", port, gai_strerror(n));
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
// Check if a specific network interface is specified
|
||||
const char *ifname = KVCacheConfig::getInstance().get_socket_interface();
|
||||
// Try each address until we successfully bind to one
|
||||
@@ -913,7 +913,7 @@ int setup_listening_socket(int port) {
|
||||
// Enable address reuse
|
||||
n = 1;
|
||||
setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &n, sizeof(n));
|
||||
|
||||
|
||||
// Attempt to bind to the address
|
||||
if (bind(sockfd, t->ai_addr, t->ai_addrlen) == 0) {
|
||||
break; // Successful bind
|
||||
@@ -948,7 +948,7 @@ int setup_listening_socket(int port) {
|
||||
close(sockfd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
// Enable TCP keep-alive
|
||||
int enable = 1;
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)) < 0) {
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
* @brief RDMA-based Key-Value Cache Communication Implementation
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@@ -34,15 +34,15 @@
|
||||
|
||||
/**
|
||||
* @brief Construct a new RDMACommunicator object
|
||||
*
|
||||
*
|
||||
* @param role Role in distributed system ("decode" or "prefill")
|
||||
* @param gpu_idx GPU device index to use
|
||||
* @param port Communication port number
|
||||
* @param local_key_cache Vector of local key cache pointers
|
||||
* @param local_value_cache Vector of local value cache pointers
|
||||
* @param local_value_cache Vector of local value cache pointers
|
||||
* @param block_number Number of blocks in cache
|
||||
* @param block_bytes Size of each block in bytes
|
||||
*
|
||||
*
|
||||
* @throws std::runtime_error If initialization fails
|
||||
*/
|
||||
RDMACommunicator::RDMACommunicator(std::string &role, int gpu_idx,
|
||||
@@ -50,16 +50,16 @@ RDMACommunicator::RDMACommunicator(std::string &role, int gpu_idx,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number, int block_bytes)
|
||||
: splitwise_role(role),
|
||||
gpu_idx(gpu_idx),
|
||||
: splitwise_role(role),
|
||||
gpu_idx(gpu_idx),
|
||||
port(port),
|
||||
local_cache_key_ptr_layer_head_(std::move(local_key_cache)),
|
||||
local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
|
||||
block_number(block_number),
|
||||
block_number(block_number),
|
||||
block_size_byte(block_bytes),
|
||||
RDMACommunicator_status(0),
|
||||
rdma_event_channel_epoll_fd(-1) {
|
||||
|
||||
|
||||
try {
|
||||
WARN("Initializing RDMA communicator for role: %s", role.c_str());
|
||||
|
||||
@@ -80,7 +80,7 @@ RDMACommunicator::RDMACommunicator(std::string &role, int gpu_idx,
|
||||
// Step 3: Initialize the event channel
|
||||
rdma_event_channel_epoll_fd = epoll_create1(EPOLL_CLOEXEC);
|
||||
if (rdma_event_channel_epoll_fd < 0) {
|
||||
throw std::runtime_error("Failed to create epoll fd: " +
|
||||
throw std::runtime_error("Failed to create epoll fd: " +
|
||||
std::string(strerror(errno)));
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ void RDMACommunicator::resize_vectors() {
|
||||
if (layer_number <= 0) {
|
||||
throw std::runtime_error("Invalid layer number");
|
||||
}
|
||||
|
||||
|
||||
local_cache_key_ptr_per_layer.resize(layer_number);
|
||||
local_cache_value_ptr_per_layer.resize(layer_number);
|
||||
}
|
||||
@@ -126,9 +126,9 @@ void RDMACommunicator::assign_pointers() {
|
||||
// Assign pointers for each layer and block
|
||||
for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
|
||||
// Validate layer head pointers
|
||||
if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
|
||||
if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
|
||||
local_cache_value_ptr_layer_head_[layer_idx] == 0) {
|
||||
throw std::runtime_error("Invalid cache pointer for layer " +
|
||||
throw std::runtime_error("Invalid cache pointer for layer " +
|
||||
std::to_string(layer_idx));
|
||||
}
|
||||
|
||||
@@ -140,12 +140,12 @@ void RDMACommunicator::assign_pointers() {
|
||||
for (int block_idx = 0; block_idx < block_number; ++block_idx) {
|
||||
local_cache_key_ptr_per_layer[layer_idx][block_idx] =
|
||||
reinterpret_cast<void*>(
|
||||
local_cache_key_ptr_layer_head_[layer_idx] +
|
||||
local_cache_key_ptr_layer_head_[layer_idx] +
|
||||
block_idx * block_size_byte);
|
||||
|
||||
|
||||
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
|
||||
reinterpret_cast<void*>(
|
||||
local_cache_value_ptr_layer_head_[layer_idx] +
|
||||
local_cache_value_ptr_layer_head_[layer_idx] +
|
||||
block_idx * block_size_byte);
|
||||
}
|
||||
}
|
||||
@@ -214,7 +214,7 @@ RDMACommunicator::~RDMACommunicator() {
|
||||
|
||||
int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
|
||||
WARN("verbs server starting …");
|
||||
|
||||
|
||||
int sockfd = setup_listening_socket(sport);
|
||||
if (sockfd < 0) {
|
||||
ERR("Failed to set up listening socket");
|
||||
@@ -244,7 +244,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
|
||||
struct RdmaContext* contexts[RDMA_TCP_CONNECT_SIZE] = {nullptr};
|
||||
|
||||
while (RDMACommunicator_status == 1) {
|
||||
int nfds = epoll_wait(epollfd, events, 10, -1);
|
||||
int nfds = epoll_wait(epollfd, events, 10, -1);
|
||||
if (nfds < 0) {
|
||||
if (errno == EINTR) continue;
|
||||
ERR("epoll_wait failed: %s", strerror(errno));
|
||||
@@ -292,7 +292,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
|
||||
ctx->conn.block_byte_size = block_size_byte;
|
||||
ctx->conn.local_cache_key_ptr_per_layer = local_cache_key_ptr_per_layer;
|
||||
ctx->conn.local_cache_value_ptr_per_layer = local_cache_value_ptr_per_layer;
|
||||
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if(!server_mr_register_per_layer(ctx)){
|
||||
ERR("server_mr_register_per_layer failed");
|
||||
@@ -394,7 +394,7 @@ void RDMACommunicator::close_client_connection(int fd, struct RdmaContext* ctx,
|
||||
}
|
||||
|
||||
conn_map.erase(ctx->conn.url);
|
||||
|
||||
|
||||
for (size_t i = 0; i < ctx->conn.read_bufs.size(); ++i) {
|
||||
if (ctx->conn.read_mrs[i]) ibv_dereg_mr(ctx->conn.read_mrs[i]);
|
||||
if (ctx->conn.read_bufs[i]) free(ctx->conn.read_bufs[i]);
|
||||
@@ -402,7 +402,7 @@ void RDMACommunicator::close_client_connection(int fd, struct RdmaContext* ctx,
|
||||
ctx->conn.read_bufs.clear();
|
||||
ctx->conn.read_mrs.clear();
|
||||
|
||||
|
||||
|
||||
ctx->conn.connected = 0;
|
||||
if (!clear_qp_info(ctx)) {
|
||||
LOGD("Failed to clear memory regions for Connection fd %d", fd);
|
||||
@@ -465,7 +465,7 @@ std::string RDMACommunicator::fetch_local_ip() {
|
||||
* Connect to a remote RDMA endpoint
|
||||
*
|
||||
* Establishes an RDMA connection with the specified destination IP and port.
|
||||
*
|
||||
*
|
||||
* @param dst_ip Destination IP address
|
||||
* @param dst_port Destination port
|
||||
* @return ConnStatus::kConnected on success, ConnStatus::kError on failure
|
||||
@@ -503,7 +503,7 @@ int RDMACommunicator::connect(const std::string &dst_ip,
|
||||
ctx->conn.layer_number = layer_number;
|
||||
ctx->conn.block_number = block_number;
|
||||
ctx->conn.block_byte_size = block_size_byte;
|
||||
|
||||
|
||||
// Get port information for the connection
|
||||
if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
|
||||
ERR("Couldn't get port info");
|
||||
@@ -516,7 +516,7 @@ int RDMACommunicator::connect(const std::string &dst_ip,
|
||||
}
|
||||
|
||||
// Exchange connection information with remote peer
|
||||
if (!client_exchange_destinations(ctx, ib_dev->port, KVCacheConfig::getInstance().resolve_rdma_dest_port(dst_port),
|
||||
if (!client_exchange_destinations(ctx, ib_dev->port, KVCacheConfig::getInstance().resolve_rdma_dest_port(dst_port),
|
||||
KVCacheConfig::getInstance().get_rdma_gid_index(), dst_ip)) {
|
||||
ERR("Couldn't getexchange port infodestinations");
|
||||
return static_cast<int>(ConnStatus::kError);
|
||||
@@ -641,7 +641,7 @@ void RDMACommunicator::remove_conn(const std::string& url) {
|
||||
}
|
||||
|
||||
struct RdmaContext *RDMACommunicator::get_conn(const std::string &ip,
|
||||
const std::string &port) {
|
||||
const std::string &port) {
|
||||
std::string url = ip + ":" + port;
|
||||
if (conn_map.find(url) == conn_map.end()) {
|
||||
return NULL;
|
||||
@@ -660,9 +660,9 @@ struct RdmaContext *RDMACommunicator::get_conn(const std::string &ip,
|
||||
* @throws std::runtime_error Throws an exception if registration fails
|
||||
*/
|
||||
struct ibv_mr* RDMACommunicator::register_memory_region(
|
||||
ibv_pd* pd, void* addr, size_t size,
|
||||
ibv_pd* pd, void* addr, size_t size,
|
||||
const std::string& desc, uint32_t access_flags) {
|
||||
|
||||
|
||||
if (!pd || !addr || size == 0) {
|
||||
throw std::invalid_argument("Invalid memory region parameters");
|
||||
}
|
||||
@@ -675,11 +675,11 @@ struct ibv_mr* RDMACommunicator::register_memory_region(
|
||||
|
||||
struct ibv_mr* mr = ibv_reg_mr(pd, addr, size, access_flags);
|
||||
if (!mr) {
|
||||
throw std::runtime_error("Failed to register memory region " + desc +
|
||||
throw std::runtime_error("Failed to register memory region " + desc +
|
||||
": " + strerror(errno));
|
||||
}
|
||||
|
||||
LOGD("Registered %s MR: addr=%p, size=%zu, flags=0x%x, lkey=0x%x",
|
||||
LOGD("Registered %s MR: addr=%p, size=%zu, flags=0x%x, lkey=0x%x",
|
||||
desc.c_str(), addr, size, access_flags, mr->lkey);
|
||||
return mr;
|
||||
}
|
||||
@@ -744,7 +744,7 @@ fail:
|
||||
/**
|
||||
* @brief Register server-side memory regions for RDMA operations
|
||||
* @param ctx RDMA context containing protection domain and other resources
|
||||
*
|
||||
*
|
||||
* @details This method registers memory regions for both keys and values
|
||||
* for each layer, enabling remote read/write access.
|
||||
*/
|
||||
@@ -850,7 +850,7 @@ int RDMACommunicator::write_cache(const std::string &ip,
|
||||
|
||||
for (size_t block_index = 0; block_index < block_num; ++block_index) {
|
||||
char* char_ptr = static_cast<char*>(ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
|
||||
cache_key_remote_addr[block_index] =
|
||||
cache_key_remote_addr[block_index] =
|
||||
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
|
||||
char_ptr = static_cast<char*>(ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
|
||||
cache_value_remote_addr[block_index] =
|
||||
@@ -869,28 +869,28 @@ int RDMACommunicator::write_cache(const std::string &ip,
|
||||
if (KVCacheConfig::getInstance().is_debug_mode_enabled()) {
|
||||
auto duration_us = std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::steady_clock::now() - start_time).count();
|
||||
|
||||
|
||||
DEBUG("Write cache completed - IP: %s, Port: %s, Layer: %d, BlockSize: %d, Blocks: %lu, Duration: %ld us",
|
||||
ip.c_str(), port.c_str(), layer_idx, block_size_byte, block_num, duration_us);
|
||||
}
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool RDMACommunicator::post_block_send(struct RdmaContext* ctx, int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey, const std::string &ip,
|
||||
bool RDMACommunicator::post_block_send(struct RdmaContext* ctx, int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey, const std::string &ip,
|
||||
const std::string &port) {
|
||||
auto block_num = local_block_ids.size();
|
||||
assert(block_num > 0 && "block_num must be > 0");
|
||||
|
||||
bool success = execute_rdma_writes(ctx, layer_idx, local_block_ids,
|
||||
bool success = execute_rdma_writes(ctx, layer_idx, local_block_ids,
|
||||
is_key, remote_addr, rkey);
|
||||
|
||||
|
||||
if (success) {
|
||||
if (KVCacheConfig::getInstance().is_gdrcopy_flush_enabled()) {
|
||||
const size_t last_idx = block_num - 1;
|
||||
success = execute_read_verification(ctx, last_idx, remote_addr[last_idx],
|
||||
success = execute_read_verification(ctx, last_idx, remote_addr[last_idx],
|
||||
rkey, layer_idx, ip, port);
|
||||
}
|
||||
}
|
||||
@@ -905,22 +905,22 @@ bool RDMACommunicator::execute_rdma_writes(struct RdmaContext* ctx, int layer_id
|
||||
auto block_num = local_block_ids.size();
|
||||
struct ibv_sge* sge_list = new ibv_sge[block_num];
|
||||
struct ibv_send_wr* send_wr_list = new ibv_send_wr[block_num];
|
||||
|
||||
prepare_write_requests(sge_list, send_wr_list, layer_idx,
|
||||
|
||||
prepare_write_requests(sge_list, send_wr_list, layer_idx,
|
||||
local_block_ids, is_key, remote_addr, rkey);
|
||||
|
||||
|
||||
bool success = true;
|
||||
size_t inflight_wr = 0;
|
||||
|
||||
|
||||
for (size_t scnt = 0; scnt < block_num; ++scnt) {
|
||||
size_t idx = scnt % RDMA_WR_LIST_MAX_SIZE;
|
||||
inflight_wr++;
|
||||
|
||||
|
||||
bool is_batch_end = (idx == RDMA_WR_LIST_MAX_SIZE - 1 || scnt == block_num - 1);
|
||||
bool need_poll = (inflight_wr >= RDMA_SQ_MAX_SIZE || scnt == block_num - 1);
|
||||
|
||||
|
||||
if (is_batch_end) {
|
||||
if (!post_send_with_retry(ctx, &send_wr_list[scnt - idx],
|
||||
if (!post_send_with_retry(ctx, &send_wr_list[scnt - idx],
|
||||
inflight_wr, need_poll)) {
|
||||
success = false;
|
||||
break;
|
||||
@@ -930,7 +930,7 @@ bool RDMACommunicator::execute_rdma_writes(struct RdmaContext* ctx, int layer_id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
delete[] sge_list;
|
||||
delete[] send_wr_list;
|
||||
return success;
|
||||
@@ -944,19 +944,19 @@ void RDMACommunicator::prepare_write_requests(struct ibv_sge* sge_list,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey) {
|
||||
auto block_num = local_block_ids.size();
|
||||
|
||||
|
||||
for (size_t i = 0; i < block_num; ++i) {
|
||||
sge_list[i].addr = (uintptr_t)(is_key ?
|
||||
local_cache_key_ptr_per_layer[layer_idx][local_block_ids[i]] :
|
||||
sge_list[i].addr = (uintptr_t)(is_key ?
|
||||
local_cache_key_ptr_per_layer[layer_idx][local_block_ids[i]] :
|
||||
local_cache_value_ptr_per_layer[layer_idx][local_block_ids[i]]);
|
||||
sge_list[i].length = block_size_byte;
|
||||
sge_list[i].lkey = (is_key ?
|
||||
write_mr_key_list[layer_idx]->lkey :
|
||||
sge_list[i].lkey = (is_key ?
|
||||
write_mr_key_list[layer_idx]->lkey :
|
||||
write_mr_value_list[layer_idx]->lkey);
|
||||
|
||||
|
||||
size_t idx = i % RDMA_WR_LIST_MAX_SIZE;
|
||||
send_wr_list[i].wr_id = i;
|
||||
send_wr_list[i].next = (idx == RDMA_WR_LIST_MAX_SIZE - 1 || i == block_num - 1) ?
|
||||
send_wr_list[i].next = (idx == RDMA_WR_LIST_MAX_SIZE - 1 || i == block_num - 1) ?
|
||||
nullptr : &send_wr_list[i + 1];
|
||||
send_wr_list[i].sg_list = &sge_list[i];
|
||||
send_wr_list[i].num_sge = 1;
|
||||
@@ -975,7 +975,7 @@ bool RDMACommunicator::post_send_with_retry(struct RdmaContext* ctx,
|
||||
int retries = 0;
|
||||
int ret = 0;
|
||||
struct ibv_send_wr* bad_wr = nullptr;
|
||||
|
||||
|
||||
if (inflight_wr >= RDMA_SQ_MAX_SIZE && wr_list) {
|
||||
struct ibv_send_wr* last_wr = wr_list;
|
||||
while (last_wr->next) {
|
||||
@@ -983,7 +983,7 @@ bool RDMACommunicator::post_send_with_retry(struct RdmaContext* ctx,
|
||||
}
|
||||
last_wr->send_flags |= IBV_SEND_SIGNALED;
|
||||
}
|
||||
|
||||
|
||||
do {
|
||||
ret = ibv_post_send(ctx->qp, wr_list, &bad_wr);
|
||||
if (ret == 0) {
|
||||
@@ -997,14 +997,14 @@ bool RDMACommunicator::post_send_with_retry(struct RdmaContext* ctx,
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
ERR("ibv_post_send failed: %s (errno: %d), retry %d/%d",
|
||||
ERR("ibv_post_send failed: %s (errno: %d), retry %d/%d",
|
||||
strerror(errno), errno, retries + 1, max_retries);
|
||||
usleep(1000);
|
||||
retries++;
|
||||
}
|
||||
} while (retries < max_retries);
|
||||
|
||||
ERR("ibv_post_send failed after %d retries: %s (errno: %d)",
|
||||
|
||||
ERR("ibv_post_send failed after %d retries: %s (errno: %d)",
|
||||
retries, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
@@ -1053,4 +1053,4 @@ bool RDMACommunicator::execute_read_verification(struct RdmaContext* ctx,
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
* @brief Logging module implementation for key-value cache system
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@@ -134,7 +134,7 @@ void debug_init() {
|
||||
buffer[len++] = '\n';
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
__atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
|
||||
Reference in New Issue
Block a user