diff --git a/easytier/docs/relay_peer_manager_design.md b/easytier/docs/relay_peer_manager_design.md
new file mode 100644
index 00000000..3be4fb0d
--- /dev/null
+++ b/easytier/docs/relay_peer_manager_design.md
@@ -0,0 +1,177 @@
+# Relay Peer Management Module Design
+
+## Background and Current State
+
+On the current outbound forwarding path, PeerManager picks the next hop directly from the routing table and sends; the core flow is "look up next hop → send":
+
+- Internal send path: [peer_manager.rs:L1053-L1082](file:///data/project/EasyTier/easytier/src/peers/peer_manager.rs#L1053-L1082)
+- Data-plane send entry: [peer_manager.rs:L1187-L1238](file:///data/project/EasyTier/easytier/src/peers/peer_manager.rs#L1187-L1238)
+
+There is currently no unified module for managing non-directly-connected targets, so relay peers cannot be governed at the session, state, or policy level.
+
+## Design Goals
+
+- Manage the lifecycle of non-directly-connected relay peers
+- Provide a unified entry point for sessions (e.g. PeerSession) and path selection
+- Stay decoupled from the existing routing module; only consume next-hop candidates and route-change notifications
+- Leave the existing data-plane main path unchanged
+
+## Architecture
+
+### Module Name
+
+**RelayPeerMap**
+
+### Ownership Relations
+
+- **PeerManager**: the top-level coordinator; holds both `Arc<PeerMap>` and `Arc<RelayPeerMap>`.
+- **RelayPeerMap**: holds an `Arc<PeerMap>` (or `Weak<PeerMap>`) so it can invoke the underlying send capability after making a decision.
+- **PeerMap**: focuses on direct-peer management and basic routing-table maintenance; it does not hold a RelayPeerMap (avoiding a circular dependency).
+
+### Responsibilities
+
+- **PeerManager**:
+  - Send entry point.
+  - Decides whether the target is directly connected:
+    - If the target is in PeerMap: send via `PeerMap` directly.
+    - If the target is not in PeerMap: delegate to `RelayPeerMap`.
+- **RelayPeerMap**:
+  - Maintains state for non-directly-connected peers (session, health).
+  - Decides the next hop.
+  - Calls `PeerMap` to send packets to the next hop.
+- **ForeignNetworkManager**:
+  - Owns an independent RelayPeerMap instance for non-direct forwarding within foreign networks.
+- **PeerMap**:
+  - Maintains direct peer connections.
+  - Provides basic routing-table queries.
+  - Performs the physical send to directly connected neighbors.
+
+## Data Model
+
+### RelayPeerKey
+
+- **dst_peer_id** (PeerId)
+- Note: a RelayPeerMap instance belongs to a specific network context, so the key only needs a PeerId.
+
+### RelayPeerState
+
+- selected_next_hop: PeerId
+- session: Option<Arc<PeerSession>>
+- last_active_at: Instant
+- path_metrics: latency, loss, hop_count (optional)
+
+### RelayPathCandidate
+
+- next_hop_peer_id
+- cost / latency / availability
+
+## Simplified State Management
+
+No complex state machine (Establishing/Suspect, etc.) is introduced; only the following checks are used:
+
+- **Does a session exist**: `session.is_some()`
+- **Is the session valid**: check the session's expiry time or generation
+- **Is the route reachable**: check whether the routing table has a next hop
+
+## Key Flows
+
+### Outbound Send Flow (non-direct)
+
+1. **PeerManager** receives a send request (target `dst_peer_id`).
+2. **PeerManager** checks `PeerMap` for a direct connection to `dst_peer_id`.
+3. If not directly connected, **PeerManager** hands the request over to **RelayPeerMap**.
+4. **RelayPeerMap** processing:
+   - Look up the `RelayPeerState`.
+   - If this is the first communication with the relay peer, create a RelayPeerState and enter the handshake flow.
+   - Ensure a session exists (trigger handshake and sync if not).
+   - Select the next hop (decided by RelayPeerMap).
+   - Call `send_msg_directly(next_hop, packet)` on **PeerMap**.
+
+### Outbound Relay Data-Plane Handshake Flow (relay-peer special case)
+
+Note: before first communication, a relay peer must complete a Noise handshake carried over data-plane messages; otherwise encrypted data-plane packets cannot be sent securely. Handshake messages are forwarded along the normal data-plane path, but their purpose is to establish a session, not to carry application data.
+
+Key steps (from the initiator's perspective):
+
+1. Once the send path sees `dst_peer_id` as a non-direct target, it enters the RelayPeerMap flow.
+2. If no session exists for the target, or the session has expired, send a **RelayHandshake** message (carrying `m1`) to the remote peer via `send_msg_directly(next_hop, packet)`.
+3. The remote peer replies with a **RelayHandshakeAck** (carrying `m2`) along the reverse path; both sides derive the session and persist it.
+4. After the handshake completes, data-plane packets are encrypted/authenticated with the established session keys and then follow the normal forwarding flow.
+5. If the handshake fails or the control-plane public-key information is missing, do not proceed to data sending; return a retryable error (the upper layer decides the retry cadence).
+
+### Relay Session Establishment (data plane + 1-RTT Noise)
+
+Background: for direct peers, the Noise handshake happens inside `PeerConn`; relay peers have no `PeerConn`, so that handshake logic cannot be reused. A relay session must complete the handshake and key derivation via **data-plane handshake messages**, and store the result in `PeerSessionStore` (or an equivalent session store) for the data plane to reuse.
+
+Key assumption: the remote static public key is available before the handshake (propagated via the control plane, e.g. OSPF), so a **1-RTT Noise handshake pattern** (a two-message handshake such as IK or KK) can be used, with the two messages mapped to the **RelayHandshake / RelayHandshakeAck** data-plane messages.
+
+Suggested flow (with the local side as initiator):
+
+1. `ensure_session(dst_peer_id)` finds no usable session and triggers a handshake (optionally deduplicating concurrent in-flight requests).
+2. Read the static public key of `dst_peer_id` from the control-plane cache (if absent, wait for control-plane convergence, or fall back to a non-1-RTT handshake pattern).
+3. Generate the first Noise handshake message `m1` (including the necessary authentication and anti-replay fields, e.g. session generation / nonce / time window).
+4. Send `RelayHandshake(m1)`; the remote peer replies with `RelayHandshakeAck(m2)`.
+5. The initiator processes `m2`; both sides derive the same session keys and session identifier, and the session is written to `PeerSessionStore` for subsequent sends to reuse.
+6. Subsequent relay data-plane packets are encrypted/decrypted and authenticated with these session keys (the exact packet format is not defined at this layer and stays semantically consistent with direct-connection sessions).
+
+Implementation notes:
+
+- **Role selection**: to avoid races from concurrent bidirectional handshakes, pick the initiator with a deterministic rule (e.g. the side with `min(peer_id)` initiates), or let the first sender initiate and merge idempotently on conflict.
+- **Idempotency and retries**: the data-plane handshake should support retries (replays with the same generation/nonce can be safely rejected or reused) and stay decoupled from route convergence.
+- **Session binding**: the handshake must bind `dst_peer_id` to its static public-key fingerprint, so that transient control-plane inconsistency cannot cause key mix-ups.
+
+### Session Management
+
+- PeerSessionStore is used only in secure mode; session creation and key derivation take effect in that mode.
+- If no session is found at send time, trigger the Create/Join/Sync logic.
+- For relay peers, the session-creation phase carries the Noise handshake in **data-plane handshake messages** (see the previous section), replacing the handshake done inside a direct `PeerConn`.
+
+### PacketType Additions
+
+- New PacketTypes:
+  - `RelayHandshake`: carries `m1` (initiator -> responder)
+  - `RelayHandshakeAck`: carries `m2` (responder -> initiator)
+- Suggested payloads:
+  - `RelayHandshake`: `RelayNoiseMsg1Pb` (fields such as a_session_generation/conn_id/algorithms)
+  - `RelayHandshakeAck`: `RelayNoiseMsg2Pb` (fields such as b_session_generation/root_key/initial_epoch/algorithms)
+- Constraints:
+  - Both packet types must be forwardable like ordinary Data packets, but must never be consumed as application data.
+  - The routing/forwarding path must recognize them as handshake control messages.
+
+## Policy Design
+
+- The next-hop policy is decided by RelayPeerMap; combined with latency_first it selects LeastHop or LatencyFirst.
+- Handshake policy: prefer the **1-RTT Noise handshake** with a known remote static public key, triggering session establishment via the **RelayHandshake/RelayHandshakeAck** messages.
+- Failure handling: rely on upper-layer retries or lower-layer route convergence; no complex failover state transitions at this layer for now.
+- Public-key source: the remote static public key comes from control-plane propagation; when that information is missing or changes, reuse of the old session must be blocked or a re-handshake triggered.
+
+## Draft Interface
+
+### RelayPeerMap Interface
+
+- `send_msg(packet, dst_peer_id)`: handle non-direct sending.
+- `ensure_session(dst_peer_id)`: ensure a usable session.
+- `handshake_session(dst_peer_id)`: complete the relay session handshake via handshake messages (transparent to upper layers; may be called internally by `ensure_session`).
+- `remove_peer(dst_peer_id)`: remove a peer that is no longer valid.
+
+## Monitoring and Metrics Suggestions
+
+- Number of relay sessions
+- Relay send success/failure counters
+
+## Incremental Rollout Plan
+
+### Phase 1: Basic Capabilities
+
+- Introduce the RelayPeerMap structure.
+- Integrate RelayPeerMap into PeerManager.
+- Implement the basic "non-direct forwarding" delegation logic.
+
+## Compatibility Notes
+
+- New PacketTypes are required for RelayHandshake/RelayHandshakeAck.
+- In secure mode, compression is done by PeerManager; encryption is done by PeerConn (direct) or RelayPeer (non-direct).
+- In secure mode, RelayPeer must provide session-level encrypt/decrypt entry points:
+  - Send: after RelayPeerMap has made its decision and before `send_msg_directly` is called, encrypt with the relay session key.
+  - Receive: before a data-plane packet enters application processing, look up the session by `from_peer_id/to_peer_id` and decrypt.
+- PeerSessionStore is kept for secure-mode session compatibility; in non-secure mode, existing behavior is preserved.
+- The routing module's computed results are unchanged.
diff --git
a/easytier/src/common/network.rs b/easytier/src/common/network.rs
index 78bd4a5a..3343b182 100644
--- a/easytier/src/common/network.rs
+++ b/easytier/src/common/network.rs
@@ -18,7 +18,8 @@ struct InterfaceFilter {
 #[cfg(any(
     target_os = "android",
-    any(target_os = "ios", feature = "macos-ne"),
+    target_os = "ios",
+    all(target_os = "macos", feature = "macos-ne"),
     target_env = "ohos"
 ))]
 impl InterfaceFilter {
diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs
index cfe5ec0b..61988ffa 100644
--- a/easytier/src/common/stun.rs
+++ b/easytier/src/common/stun.rs
@@ -25,6 +25,25 @@ use crate::common::error::Error;
 use super::dns::resolve_txt_record;
 use super::stun_codec_ext::*;
 
+const DEFAULT_UDP_STUN_SERVERS: &[&str] = &[
+    "txt:stun.easytier.cn",
+    "stun.miwifi.com",
+    "stun.chat.bilibili.com",
+    "stun.hitv.com",
+];
+
+const DEFAULT_TCP_STUN_SERVERS: &[&str] = &[
+    "stun.hot-chilli.net",
+    "stun.fitauto.ru",
+    "fwa.lifesizecloud.com",
+    "global.turn.twilio.com",
+    "turn.cloudflare.com",
+    "stun.voip.blackberry.com",
+    "stun.radiojar.com",
+];
+
+const DEFAULT_UDP_V6_STUN_SERVERS: &[&str] = &["txt:stun-v6.easytier.cn"];
+
 struct HostResolverIter {
     hostnames: Vec,
     ips: Vec,
@@ -1100,39 +1119,39 @@ impl StunInfoCollector {
     }
 
     pub fn get_default_servers() -> Vec<String> {
-        // NOTICE: we may need to choose stun server based on geolocation
-        // stun server cross nation may return an external ip address with high latency and loss rate
-        [
-            "txt:stun.easytier.cn",
-            "stun.miwifi.com",
-            "stun.chat.bilibili.com",
-            "stun.hitv.com",
-        ]
-        .iter()
-        .map(|x| x.to_string())
-        .collect()
+        if cfg!(test) {
+            Vec::new()
+        } else {
+            // NOTICE: we may need to choose stun server based on geolocation
+            // stun server cross nation may return an external ip address with high latency and loss rate
+            DEFAULT_UDP_STUN_SERVERS
+                .iter()
+                .map(ToString::to_string)
+                .collect()
+        }
     }
 
     pub fn get_default_tcp_servers() -> Vec<String> {
-        [
-            "stun.hot-chilli.net",
-            "stun.fitauto.ru",
-            "fwa.lifesizecloud.com",
-            "global.turn.twilio.com",
-            "turn.cloudflare.com",
-            "stun.voip.blackberry.com",
-            "stun.radiojar.com",
-        ]
-        .iter()
-        .map(|x| x.to_string())
-        .collect()
+        // if test, return empty vector
+        if cfg!(test) {
+            Vec::new()
+        } else {
+            DEFAULT_TCP_STUN_SERVERS
+                .iter()
+                .map(ToString::to_string)
+                .collect()
+        }
     }
 
     pub fn get_default_servers_v6() -> Vec<String> {
-        ["txt:stun-v6.easytier.cn"]
-            .iter()
-            .map(|x| x.to_string())
-            .collect()
+        if cfg!(test) {
+            Vec::new()
+        } else {
+            DEFAULT_UDP_V6_STUN_SERVERS
+                .iter()
+                .map(ToString::to_string)
+                .collect()
+        }
     }
 
     async fn get_public_ipv6(servers: &[String]) -> Option {
@@ -1328,7 +1347,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_udp_nat_type_detector() {
-        let collector = StunInfoCollector::new_with_default_servers();
+        let collector = StunInfoCollector::new(
+            DEFAULT_UDP_STUN_SERVERS
+                .iter()
+                .map(ToString::to_string)
+                .collect(),
+            vec![],
+            vec![],
+        );
         collector.update_stun_info();
         loop {
             let ret = collector.get_stun_info();
diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs
index 3feb65a2..0ace0f55 100644
--- a/easytier/src/peers/foreign_network_manager.rs
+++ b/easytier/src/peers/foreign_network_manager.rs
@@ -47,7 +47,9 @@ use super::{
     peer_ospf_route::PeerRoute,
     peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
     peer_rpc_service::DirectConnectorManagerRpcServer,
+    peer_session::PeerSessionStore,
     recv_packet_from_chan,
+    relay_peer_map::RelayPeerMap,
     route_trait::NextHopPolicy,
    PacketRecvChan, PacketRecvChanReceiver, PUBLIC_SERVER_HOSTNAME_PREFIX,
 };
@@ -64,6 +66,8 @@ struct ForeignNetworkEntry {
     global_ctx: ArcGlobalCtx,
     network: NetworkIdentity,
     peer_map: Arc,
+    relay_peer_map: Arc,
+    peer_session_store: Arc,
     relay_data: bool,
     pm_packet_sender: Mutex>,
@@ -103,6 +107,13 @@ impl ForeignNetworkEntry {
             foreign_global_ctx.clone(),
             my_peer_id,
         ));
+        let peer_session_store = Arc::new(PeerSessionStore::new());
+        let
relay_peer_map = RelayPeerMap::new( + peer_map.clone(), + foreign_global_ctx.clone(), + my_peer_id, + peer_session_store.clone(), + ); let (peer_rpc, rpc_transport_sender) = Self::build_rpc_tspt(my_peer_id, peer_map.clone()); @@ -136,6 +147,8 @@ impl ForeignNetworkEntry { global_ctx: foreign_global_ctx, network, peer_map, + relay_peer_map, + peer_session_store, relay_data, pm_packet_sender: Mutex::new(Some(pm_packet_sender)), @@ -314,6 +327,7 @@ impl ForeignNetworkEntry { let my_node_id = self.my_peer_id; let rpc_sender = self.rpc_sender.clone(); let peer_map = self.peer_map.clone(); + let relay_peer_map = self.relay_peer_map.clone(); let relay_data = self.relay_data; let pm_sender = self.pm_packet_sender.lock().await.take().unwrap(); let network_name = self.network.network_name.clone(); @@ -344,6 +358,12 @@ impl ForeignNetworkEntry { tracing::trace!(?hdr, "recv packet in foreign network manager"); let to_peer_id = hdr.to_peer_id.get(); if to_peer_id == my_node_id { + if hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 + { + let _ = relay_peer_map.handle_handshake_packet(zc_packet).await; + continue; + } if hdr.packet_type == PacketType::TaRpc as u8 || hdr.packet_type == PacketType::RpcReq as u8 || hdr.packet_type == PacketType::RpcResp as u8 @@ -358,6 +378,8 @@ impl ForeignNetworkEntry { if hdr.packet_type == PacketType::Data as u8 || hdr.packet_type == PacketType::KcpSrc as u8 || hdr.packet_type == PacketType::KcpDst as u8 + || hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 { if !relay_data { continue; @@ -405,9 +427,31 @@ impl ForeignNetworkEntry { }); } + async fn run_relay_session_gc_routine(&self) { + let relay_peer_map = self.relay_peer_map.clone(); + self.tasks.lock().await.spawn(async move { + loop { + relay_peer_map.evict_idle_sessions(std::time::Duration::from_secs(60)); + 
tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } + }); + } + + async fn run_peer_session_gc_routine(&self) { + let peer_session_store = self.peer_session_store.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + peer_session_store.evict_unused_sessions(); + } + }); + } + async fn prepare(&self, accessor: Box) { self.prepare_route(accessor).await; self.start_packet_recv().await; + self.run_relay_session_gc_routine().await; + self.run_peer_session_gc_routine().await; self.peer_rpc.run(); self.peer_center.init().await; } @@ -734,7 +778,7 @@ impl ForeignNetworkManager { ) -> Result<(), Error> { if let Some(entry) = self.data.get_network_entry(network_name) { entry - .peer_map + .relay_peer_map .send_msg(msg, dst_peer_id, NextHopPolicy::LeastHop) .await } else { diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index 9b3b0ee0..5e30a04a 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -10,6 +10,7 @@ pub mod peer_ospf_route; pub mod peer_rpc; pub mod peer_rpc_service; pub mod peer_session; +pub mod relay_peer_map; pub mod route_trait; pub mod rpc_service; diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 5213d036..4b91fbc9 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -138,6 +138,8 @@ impl PeerSessionTunnelFilter { hdr.packet_type == PacketType::NoiseHandshakeMsg1 as u8 || hdr.packet_type == PacketType::NoiseHandshakeMsg2 as u8 || hdr.packet_type == PacketType::NoiseHandshakeMsg3 as u8 + || hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 || hdr.packet_type == PacketType::Ping as u8 || hdr.packet_type == PacketType::Pong as u8 } @@ -169,9 +171,19 @@ impl TunnelFilter for PeerSessionTunnelFilter { }; let my_peer_id = self.my_peer_id.load(); - session - .encrypt_payload(my_peer_id, peer_id, &mut data) - 
.ok()?; + if my_peer_id != hdr.from_peer_id.get() { + return Some(data); + } + + if let Err(e) = session.encrypt_payload(my_peer_id, peer_id, &mut data) { + tracing::warn!( + ?my_peer_id, + ?peer_id, + ?e, + "PeerSessionTunnelFilter: encrypt failed, dropping packet" + ); + return None; + } Some(data) } @@ -198,7 +210,14 @@ impl TunnelFilter for PeerSessionTunnelFilter { if from_peer_id == 0 { return Some(Ok(data)); } - self.peer_id.store(Some(from_peer_id)); + + let Some(peer_id) = self.peer_id.load() else { + return Some(Ok(data)); + }; + + if from_peer_id != peer_id { + return Some(Ok(data)); + } let mut guard = self.session.lock().unwrap(); let Some(session) = guard.as_mut() else { @@ -206,7 +225,22 @@ impl TunnelFilter for PeerSessionTunnelFilter { }; let my_peer_id = self.my_peer_id.load(); - let _ = session.decrypt_payload(from_peer_id, my_peer_id, &mut data); + if hdr.to_peer_id.get() != my_peer_id { + return Some(Ok(data)); + } + + if let Err(e) = session.decrypt_payload(from_peer_id, my_peer_id, &mut data) { + if !session.is_valid() { + // Session auto-invalidated after too many consecutive failures. + // Close the connection to trigger reconnection with a fresh handshake. + tracing::error!(?e, "session invalidated, closing connection"); + return Some(Err(TunnelError::InternalError( + "session invalidated due to consecutive decrypt failures".to_string(), + ))); + } + // Transient failure, drop this packet but keep the connection alive. 
+ return None; + } Some(Ok(data)) } @@ -775,6 +809,13 @@ impl PeerConn { .get_remote_static() .map(|x: &[u8]| x.to_vec()) .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; if let Some(pinned) = pinned_remote_pubkey.as_ref() { if pinned.as_slice() == remote_static.as_slice() { @@ -812,6 +853,7 @@ impl PeerConn { msg2_pb.initial_epoch, algo, msg2_pb.server_encryption_algorithm.clone(), + remote_static_key, )?; Ok(NoiseHandshakeResult { @@ -949,6 +991,7 @@ impl PeerConn { msg1_pb.a_session_generation, algo.clone(), msg1_pb.client_encryption_algorithm.clone(), + None, )?; let b_conn_id = uuid::Uuid::new_v4(); @@ -1022,6 +1065,14 @@ impl PeerConn { .get_remote_static() .map(|x: &[u8]| x.to_vec()) .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + session.check_or_set_peer_static_pubkey(remote_static_key)?; let handshake_hash = hs.get_handshake_hash().to_vec(); diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 59672db9..5e960d84 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -62,6 +62,7 @@ use super::{ peer_map::PeerMap, peer_ospf_route::PeerRoute, peer_rpc::PeerRpcManager, + relay_peer_map::RelayPeerMap, route_trait::{ArcRoute, Route}, BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChan, PacketRecvChanReceiver, }; @@ -76,6 +77,7 @@ struct RpcTransport { peer_rpc_tspt_sender: UnboundedSender, encryptor: Arc, + is_secure_mode_enabled: bool, } #[async_trait::async_trait] @@ -93,7 +95,7 @@ impl PeerRpcManagerTransport for RpcTransport { .and_then(|x| x.feature_flag.map(|x| x.is_public_server)) // if dst is directly connected, it's must not public server .unwrap_or(!peers.has_peer(dst_peer_id)); - if !is_dst_peer_public_server { + 
if !is_dst_peer_public_server && !self.is_secure_mode_enabled { self.encryptor .encrypt(&mut msg) .with_context(|| "encrypt failed")?; @@ -150,6 +152,7 @@ pub struct PeerManager { foreign_network_manager: Arc, foreign_network_client: Arc, + relay_peer_map: Arc, encryptor: Arc, data_compress_algo: CompressorAlgo, @@ -163,6 +166,7 @@ pub struct PeerManager { self_tx_counters: SelfTxCounters, peer_session_store: Arc, + is_secure_mode_enabled: bool, } impl Debug for PeerManager { @@ -189,6 +193,13 @@ impl PeerManager { global_ctx.clone(), my_peer_id, )); + let peer_session_store = Arc::new(PeerSessionStore::new()); + let relay_peer_map = RelayPeerMap::new( + peers.clone(), + global_ctx.clone(), + my_peer_id, + peer_session_store.clone(), + ); let encryptor = if global_ctx.get_flags().enable_encryption { // 只有在启用加密时才使用工厂函数选择算法 @@ -213,6 +224,12 @@ impl PeerManager { global_ctx.set_feature_flags(f); } + let is_secure_mode_enabled = global_ctx + .config + .get_secure_mode() + .map(|cfg| cfg.enabled) + .unwrap_or(false); + // TODO: remove these because we have impl pipeline processor. 
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel(); let rpc_tspt = Arc::new(RpcTransport { @@ -222,6 +239,7 @@ impl PeerManager { packet_recv: Mutex::new(peer_rpc_tspt_recv), peer_rpc_tspt_sender, encryptor: encryptor.clone(), + is_secure_mode_enabled, }); let peer_rpc_mgr = Arc::new(PeerRpcManager::new_with_stats_manager( rpc_tspt.clone(), @@ -304,6 +322,7 @@ impl PeerManager { foreign_network_manager, foreign_network_client, + relay_peer_map, encryptor, data_compress_algo, @@ -316,7 +335,8 @@ impl PeerManager { self_tx_counters, - peer_session_store: Arc::new(PeerSessionStore::new()), + peer_session_store, + is_secure_mode_enabled, } } @@ -645,11 +665,13 @@ impl PeerManager { let peers = self.peers.clone(); let pipe_line = self.peer_packet_process_pipeline.clone(); let foreign_client = self.foreign_network_client.clone(); + let relay_peer_map = self.relay_peer_map.clone(); let foreign_mgr = self.foreign_network_manager.clone(); let encryptor = self.encryptor.clone(); let compress_algo = self.data_compress_algo; let acl_filter = self.global_ctx.get_acl_filter().clone(); let global_ctx = self.global_ctx.clone(); + let secure_mode_enabled = self.is_secure_mode_enabled; let stats_mgr = self.global_ctx.stats_manager().clone(); let route = self.get_route(); @@ -713,9 +735,13 @@ impl PeerManager { || hdr.packet_type == PacketType::KcpSrc as u8 || hdr.packet_type == PacketType::KcpDst as u8 { - let _ = - Self::try_compress_and_encrypt(compress_algo, &encryptor, &mut ret) - .await; + let _ = Self::try_compress_and_encrypt( + compress_algo, + &encryptor, + &mut ret, + secure_mode_enabled, + ) + .await; } compress_tx_bytes_after.add(ret.buf_len() as u64); @@ -727,16 +753,42 @@ impl PeerManager { } tracing::trace!(?to_peer_id, ?my_peer_id, "need forward"); - let ret = - Self::send_msg_internal(&peers, &foreign_client, ret, to_peer_id).await; + let ret = Self::send_msg_internal( + &peers, + &foreign_client, + &relay_peer_map, + ret, + to_peer_id, + ) + 
.await; if ret.is_err() { tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } } else { - if let Err(e) = encryptor.decrypt(&mut ret) { - tracing::error!(?e, "decrypt failed"); + if hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 + { + let _ = relay_peer_map.handle_handshake_packet(ret).await; continue; } + if !secure_mode_enabled { + if let Err(e) = encryptor.decrypt(&mut ret) { + tracing::error!(?e, "decrypt failed"); + continue; + } + } else if !peers.has_peer(from_peer_id) { + match relay_peer_map.decrypt_if_needed(&mut ret) { + Ok(true) => {} + Ok(false) => { + tracing::error!("relay session not found"); + continue; + } + Err(e) => { + tracing::error!(?e, "relay decrypt failed"); + continue; + } + } + } self_rx_bytes.add(buf_len as u64); self_rx_packets.inc(); @@ -1033,16 +1085,27 @@ impl PeerManager { .compress_tx_bytes_before .add(msg.buf_len() as u64); - Self::try_compress_and_encrypt(self.data_compress_algo, &self.encryptor, &mut msg).await?; + Self::try_compress_and_encrypt( + self.data_compress_algo, + &self.encryptor, + &mut msg, + self.is_secure_mode_enabled, + ) + .await?; self.self_tx_counters .compress_tx_bytes_after .add(msg.buf_len() as u64); let msg_len = msg.buf_len() as u64; - let result = - Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, dst_peer_id) - .await; + let result = Self::send_msg_internal( + &self.peers, + &self.foreign_network_client, + &self.relay_peer_map, + msg, + dst_peer_id, + ) + .await; if result.is_ok() { self.self_tx_counters.self_tx_bytes.add(msg_len); self.self_tx_counters.self_tx_packets.inc(); @@ -1053,15 +1116,20 @@ impl PeerManager { async fn send_msg_internal( peers: &Arc, foreign_network_client: &Arc, + relay_peer_map: &Arc, msg: ZCPacket, dst_peer_id: PeerId, ) -> Result<(), Error> { let policy = Self::get_next_hop_policy(msg.peer_manager_header().unwrap().is_latency_first()); + if 
peers.has_peer(dst_peer_id) { + return peers.send_msg_directly(msg, dst_peer_id).await; + } + if let Some(gateway) = peers.get_gateway_peer_id(dst_peer_id, policy.clone()).await { if peers.has_peer(gateway) { - peers.send_msg_directly(msg, gateway).await + relay_peer_map.send_msg(msg, dst_peer_id, policy).await } else if foreign_network_client.has_next_hop(gateway) { foreign_network_client.send_msg(msg, gateway).await } else { @@ -1174,13 +1242,16 @@ impl PeerManager { compress_algo: CompressorAlgo, encryptor: &Arc, msg: &mut ZCPacket, + secure_mode_enabled: bool, ) -> Result<(), Error> { let compressor = DefaultCompressor {}; compressor .compress(msg, compress_algo) .await .with_context(|| "compress failed")?; - encryptor.encrypt(msg).with_context(|| "encrypt failed")?; + if !secure_mode_enabled { + encryptor.encrypt(msg).with_context(|| "encrypt failed")?; + } Ok(()) } @@ -1209,6 +1280,7 @@ impl PeerManager { return Self::send_msg_internal( &self.peers, &self.foreign_network_client, + &self.relay_peer_map, msg, cur_to_peer_id, ) @@ -1229,7 +1301,13 @@ impl PeerManager { .compress_tx_bytes_before .add(msg.buf_len() as u64); - Self::try_compress_and_encrypt(self.data_compress_algo, &self.encryptor, &mut msg).await?; + Self::try_compress_and_encrypt( + self.data_compress_algo, + &self.encryptor, + &mut msg, + self.is_secure_mode_enabled, + ) + .await?; self.self_tx_counters .compress_tx_bytes_after @@ -1273,9 +1351,14 @@ impl PeerManager { .add(msg.buf_len() as u64); self.self_tx_counters.self_tx_packets.inc(); - if let Err(e) = - Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, *peer_id) - .await + if let Err(e) = Self::send_msg_internal( + &self.peers, + &self.foreign_network_client, + &self.relay_peer_map, + msg, + *peer_id, + ) + .await { errs.push(e); } @@ -1301,6 +1384,26 @@ impl PeerManager { }); } + async fn run_relay_session_gc_routine(&self) { + let relay_peer_map = self.relay_peer_map.clone(); + self.tasks.lock().await.spawn(async 
move { + loop { + relay_peer_map.evict_idle_sessions(std::time::Duration::from_secs(60)); + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } + }); + } + + async fn run_peer_session_gc_routine(&self) { + let peer_session_store = self.peer_session_store.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + peer_session_store.evict_unused_sessions(); + } + }); + } + async fn run_foriegn_network(&self) { self.peer_rpc_tspt .foreign_peers @@ -1322,6 +1425,8 @@ impl PeerManager { self.start_peer_recv().await; self.run_clean_peer_without_conn_routine().await; + self.run_relay_session_gc_routine().await; + self.run_peer_session_gc_routine().await; self.run_foriegn_network().await; @@ -1332,10 +1437,18 @@ impl PeerManager { self.peers.clone() } + pub fn get_relay_peer_map(&self) -> Arc { + self.relay_peer_map.clone() + } + pub fn get_peer_rpc_mgr(&self) -> Arc { self.peer_rpc_mgr.clone() } + pub fn get_peer_session_store(&self) -> Arc { + self.peer_session_store.clone() + } + pub fn my_node_id(&self) -> uuid::Uuid { self.global_ctx.get_id() } diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index bdefcbf3..03c5ce7b 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -146,6 +146,7 @@ impl RoutePeerInfo { groups: Vec::new(), quic_port: None, + noise_static_pubkey: Vec::new(), } } @@ -164,6 +165,12 @@ impl RoutePeerInfo { global_ctx: &ArcGlobalCtx, ) -> Self { let stun_info = global_ctx.get_stun_info_collector().get_stun_info(); + let noise_static_pubkey = global_ctx + .config + .get_secure_mode() + .and_then(|cfg| cfg.public_key().ok()) + .map(|pk| pk.as_bytes().to_vec()) + .unwrap_or_default(); Self { peer_id: my_peer_id, inst_id: Some(global_ctx.get_id().into()), @@ -197,6 +204,8 @@ impl RoutePeerInfo { groups: global_ctx.get_acl_groups(my_peer_id), + noise_static_pubkey, + ..Default::default() } 
} @@ -1842,15 +1851,6 @@ impl PeerRouteServiceImpl { if let Some(last_update) = peer_info.last_update { let last_update = TryInto::::try_into(last_update).unwrap(); if last_sync_succ_timestamp.is_some_and(|t| last_update < t) { - tracing::debug!( - "ignore peer_info {:?} because last_update: {:?} is older than last_sync_succ_timestamp: {:?}, peer_infos_count: {}, my_peer_id: {:?}, session: {:?}", - peer_info, - last_update, - last_sync_succ_timestamp, - peer_infos.len(), - self.my_peer_id, - session - ); break; } } @@ -2556,6 +2556,7 @@ impl RouteSessionManager { continue; }; session.update_initiator_flag(true); + self.sync_now("update_initiator_flag"); } // clear sessions that are neither dst_initiator or we_are_initiator. diff --git a/easytier/src/peers/peer_session.rs b/easytier/src/peers/peer_session.rs index fc766cad..f22ffcb3 100644 --- a/easytier/src/peers/peer_session.rs +++ b/easytier/src/peers/peer_session.rs @@ -1,6 +1,6 @@ use std::{ sync::{ - atomic::{AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Arc, Mutex, RwLock, }, time::{SystemTime, UNIX_EPOCH}, @@ -70,7 +70,29 @@ impl PeerSessionStore { } pub fn get(&self, key: &SessionKey) -> Option> { - self.sessions.get(key).map(|v| v.clone()) + let session = self.sessions.get(key)?.clone(); + if session.is_valid() { + Some(session) + } else { + self.sessions.remove(key); + None + } + } + + pub fn remove(&self, key: &SessionKey) { + self.sessions.remove(key); + } + + pub fn insert_session(&self, key: SessionKey, session: Arc) { + self.sessions.insert(key, session); + } + + /// Remove sessions that are no longer referenced by any PeerConn or RelayPeerMap. + /// A session with strong_count == 1 means only the store holds it — no active + /// connection is using it, so it can be safely cleaned up. 
+ pub fn evict_unused_sessions(&self) { + self.sessions + .retain(|_key, session| Arc::strong_count(session) > 1); } pub fn upsert_responder_session( @@ -79,8 +101,13 @@ impl PeerSessionStore { a_session_generation: Option, send_algorithm: String, recv_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Result { - let existing = self.sessions.get(key).map(|v| v.clone()); + let existing = self + .sessions + .get(key) + .map(|v| v.clone()) + .filter(|s| s.is_valid()); match existing { None => { let root_key = PeerSession::new_root_key(); @@ -93,6 +120,7 @@ impl PeerSessionStore { initial_epoch, send_algorithm, recv_algorithm, + peer_static_pubkey, )); self.sessions.insert(key.clone(), session.clone()); Ok(UpsertResponderSessionReturn { @@ -105,6 +133,7 @@ impl PeerSessionStore { } Some(session) => { session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; let local_gen = session.session_generation(); if a_session_generation.is_some_and(|g| g == local_gen) { Ok(UpsertResponderSessionReturn { @@ -139,6 +168,7 @@ impl PeerSessionStore { initial_epoch: u32, send_algorithm: String, recv_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Result, anyhow::Error> { tracing::info!( "apply_initiator_action {:?}, send_algorithm: {}, recv_algorithm: {}", @@ -152,6 +182,7 @@ impl PeerSessionStore { return Err(anyhow!("no local session for JOIN")); }; session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; if session.session_generation() != b_session_generation { return Err(anyhow!("JOIN generation mismatch")); } @@ -159,6 +190,13 @@ impl PeerSessionStore { } PeerSessionAction::Sync | PeerSessionAction::Create => { let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?; + // If the existing session is invalidated, remove it so we create a fresh one + if let Some(existing) = 
self.sessions.get(key) { + if !existing.is_valid() { + drop(existing); + self.sessions.remove(key); + } + } let session = self .sessions .entry(key.clone()) @@ -170,10 +208,12 @@ impl PeerSessionStore { initial_epoch, send_algorithm.clone(), recv_algorithm.clone(), + peer_static_pubkey, )) }) .clone(); session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; session.sync_root_key(root_key, b_session_generation, initial_epoch); Ok(session) } @@ -318,6 +358,7 @@ pub struct PeerSession { peer_id: PeerId, root_key: RwLock<[u8; 32]>, session_generation: AtomicU32, + peer_static_pubkey: RwLock>, send_epoch: AtomicU32, send_seq: [AtomicU64; 2], @@ -329,6 +370,12 @@ pub struct PeerSession { send_cipher_algorithm: String, recv_cipher_algorithm: String, + + /// Set to true when the session is detected as corrupted (persistent decrypt failures). + /// Holders of Arc can check this to know the session should be discarded. + invalidated: AtomicBool, + /// Consecutive decrypt failure counter. Auto-invalidates when threshold is reached. + decrypt_fail_count: AtomicU32, } impl std::fmt::Debug for PeerSession { @@ -337,6 +384,7 @@ impl std::fmt::Debug for PeerSession { .field("peer_id", &self.peer_id) .field("root_key", &self.root_key) .field("session_generation", &self.session_generation) + .field("peer_static_pubkey", &self.peer_static_pubkey) .field("send_epoch", &self.send_epoch) .field("send_seq", &self.send_seq) .field("send_epoch_started_ms", &self.send_epoch_started_ms) @@ -381,6 +429,7 @@ impl PeerSession { /// stricter security requirements may decrease it. 
const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000; const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3; + const DECRYPT_FAIL_THRESHOLD: u32 = 10; pub fn new( peer_id: PeerId, @@ -389,11 +438,8 @@ impl PeerSession { initial_epoch: u32, send_cipher_algorithm: String, recv_cipher_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Self { - // let mut root_key_128 = [0u8; 16]; - // root_key_128.copy_from_slice(&root_key[..16]); - // let send_cipher = create_encryptor(&send_algorithm, root_key_128, root_key); - // let recv_cipher = create_encryptor(&recv_algorithm, root_key_128, root_key); let rx_slots = [ [EpochRxSlot::default(), EpochRxSlot::default()], [EpochRxSlot::default(), EpochRxSlot::default()], @@ -407,6 +453,7 @@ impl PeerSession { peer_id, root_key: RwLock::new(root_key), session_generation: AtomicU32::new(session_generation), + peer_static_pubkey: RwLock::new(peer_static_pubkey), send_epoch: AtomicU32::new(initial_epoch), send_seq: [AtomicU64::new(0), AtomicU64::new(0)], send_epoch_started_ms: AtomicU64::new(now_ms), @@ -415,6 +462,8 @@ impl PeerSession { key_cache: Mutex::new(key_cache), send_cipher_algorithm, recv_cipher_algorithm, + invalidated: AtomicBool::new(false), + decrypt_fail_count: AtomicU32::new(0), } } @@ -422,6 +471,15 @@ impl PeerSession { self.peer_id } + /// Mark this session as invalid. All holders of Arc will see this. 
+ pub fn invalidate(&self) { + self.invalidated.store(true, Ordering::Relaxed); + } + + pub fn is_valid(&self) -> bool { + !self.invalidated.load(Ordering::Relaxed) + } + pub fn session_generation(&self) -> u32 { self.session_generation.load(Ordering::Relaxed) } @@ -466,6 +524,24 @@ impl PeerSession { Ok(()) } + pub fn check_or_set_peer_static_pubkey( + &self, + peer_static_pubkey: Option<[u8; 32]>, + ) -> Result<(), anyhow::Error> { + let Some(peer_static_pubkey) = peer_static_pubkey else { + return Ok(()); + }; + let mut guard = self.peer_static_pubkey.write().unwrap(); + if let Some(existing) = *guard { + if existing != peer_static_pubkey { + return Err(anyhow!("peer static pubkey mismatch")); + } + return Ok(()); + } + *guard = Some(peer_static_pubkey); + Ok(()) + } + pub fn sync_root_key(&self, root_key: [u8; 32], session_generation: u32, initial_epoch: u32) { { let mut g = self.root_key.write().unwrap(); @@ -703,6 +779,9 @@ impl PeerSession { receiver_peer_id: PeerId, pkt: &mut ZCPacket, ) -> Result<(), anyhow::Error> { + if !self.is_valid() { + return Err(anyhow!("session invalidated")); + } let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id); let (epoch, _seq, nonce_bytes) = self.next_nonce(dir); let encryptor = self @@ -718,6 +797,9 @@ impl PeerSession { receiver_peer_id: PeerId, ciphertext_with_tail: &mut ZCPacket, ) -> Result<(), anyhow::Error> { + if !self.is_valid() { + return Err(anyhow!("session invalidated")); + } let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id); let nonce_bytes = Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?; @@ -732,7 +814,19 @@ impl PeerSession { let encryptor = self .get_encryptor(epoch, dir, false) .ok_or_else(|| anyhow!("no key for epoch"))?; - encryptor.decrypt(ciphertext_with_tail)?; + if let Err(e) = encryptor.decrypt(ciphertext_with_tail) { + let count = self.decrypt_fail_count.fetch_add(1, Ordering::Relaxed) + 1; + if count >= Self::DECRYPT_FAIL_THRESHOLD 
{ + self.invalidate(); + tracing::warn!( + peer_id = ?self.peer_id, + count, + "session auto-invalidated after consecutive decrypt failures" + ); + } + return Err(e.into()); + } + self.decrypt_fail_count.store(0, Ordering::Relaxed); Ok(()) } @@ -764,6 +858,7 @@ mod tests { initial_epoch, "aes-256-gcm".to_string(), "chacha20-poly1305".to_string(), + None, ); let sb = PeerSession::new( a, @@ -772,6 +867,7 @@ initial_epoch, "chacha20-poly1305".to_string(), "aes-256-gcm".to_string(), + None, ); let plaintext1 = b"hello from a"; @@ -802,6 +898,7 @@ initial_epoch, "aes-256-gcm".to_string(), "aes-256-gcm".to_string(), + None, ); let now = now_ms(); diff --git a/easytier/src/peers/relay_peer_map.rs b/easytier/src/peers/relay_peer_map.rs new file mode 100644 index 00000000..be8e274e --- /dev/null +++ b/easytier/src/peers/relay_peer_map.rs @@ -0,0 +1,692 @@ +use std::{sync::Arc, time::Instant}; + +use dashmap::DashMap; +use hmac::Mac; +use prost::Message; +use snow::params::NoiseParams; +use tokio::sync::{oneshot, Mutex, OwnedMutexGuard}; +use tokio::time::{timeout, Duration}; + +use crate::{ + common::error::Error, + common::{global_ctx::ArcGlobalCtx, PeerId}, + peers::peer_map::PeerMap, + peers::peer_session::{PeerSession, PeerSessionAction, PeerSessionStore, SessionKey}, + peers::route_trait::NextHopPolicy, + proto::peer_rpc::{PeerConnSessionActionPb, RelayNoiseMsg1Pb, RelayNoiseMsg2Pb}, + tunnel::packet_def::{PacketType, ZCPacket}, +}; + +const RELAY_NOISE_VERSION: u32 = 1; +const RELAY_NOISE_PROLOGUE: &[u8] = b"easytier-relay-noise"; +const HANDSHAKE_TIMEOUT_SECS: u64 = 5; +const HANDSHAKE_RETRY_BASE_MS: u64 = 200; +const HANDSHAKE_MAX_ATTEMPTS: u32 = 3; +const MAX_PENDING_PACKETS_PER_PEER: usize = 32; + +#[derive(Clone)] +pub struct RelayPeerState { + pub last_active_at: Instant, + pub failure_count: u32, + pub next_retry_at: Option<Instant>, + } + +impl Default for RelayPeerState { + fn default() -> Self { + Self { + last_active_at: Instant::now(), + 
failure_count: 0, + next_retry_at: None, + } + } +} + +pub struct RelayPeerMap { + peer_map: Arc<PeerMap>, + global_ctx: ArcGlobalCtx, + my_peer_id: PeerId, + peer_session_store: Arc<PeerSessionStore>, + states: DashMap<PeerId, RelayPeerState>, + pending_handshakes: DashMap<PeerId, oneshot::Sender<ZCPacket>>, + handshake_locks: DashMap<PeerId, Arc<Mutex<()>>>, + pub(crate) pending_packets: DashMap<PeerId, Vec<(ZCPacket, NextHopPolicy)>>, +} + +impl RelayPeerMap { + pub fn new( + peer_map: Arc<PeerMap>, + global_ctx: ArcGlobalCtx, + my_peer_id: PeerId, + peer_session_store: Arc<PeerSessionStore>, + ) -> Arc<Self> { + Arc::new(Self { + peer_map, + global_ctx, + my_peer_id, + peer_session_store, + states: DashMap::new(), + pending_handshakes: DashMap::new(), + handshake_locks: DashMap::new(), + pending_packets: DashMap::new(), + }) + } + + fn is_secure_mode_enabled(&self) -> bool { + self.global_ctx + .config + .get_secure_mode() + .map(|cfg| cfg.enabled) + .unwrap_or(false) + } + + fn get_local_keypair(&self) -> Result<(Vec<u8>, Vec<u8>), Error> { + let cfg = self + .global_ctx + .config + .get_secure_mode() + .ok_or_else(|| Error::RouteError(Some("secure mode config not set".to_string())))?; + let private = cfg + .private_key() + .map_err(|e| Error::RouteError(Some(format!("invalid private key: {e:?}"))))?; + let public = cfg + .public_key() + .map_err(|e| Error::RouteError(Some(format!("invalid public key: {e:?}"))))?; + Ok((private.as_bytes().to_vec(), public.as_bytes().to_vec())) + } + + async fn get_remote_static_pubkey(&self, peer_id: PeerId) -> Result<Vec<u8>, Error> { + let info = self + .peer_map + .get_route_peer_info(peer_id) + .await + .ok_or_else(|| Error::RouteError(Some("route peer info not found".to_string())))?; + if info.noise_static_pubkey.is_empty() { + return Err(Error::RouteError(Some( + "remote static pubkey not found".to_string(), + ))); + } + Ok(info.noise_static_pubkey) + } + + fn get_handshake_lock(&self, peer_id: PeerId) -> Arc<Mutex<()>> { + self.handshake_locks + .entry(peer_id) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + + async fn send_handshake_packet( + &self, + payload: Vec<u8>, + packet_type: PacketType, 
dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let mut pkt = ZCPacket::new_with_payload(&payload); + pkt.fill_peer_manager_hdr(self.my_peer_id, dst_peer_id, packet_type as u8); + self.send_via_next_hop(pkt, dst_peer_id, policy).await + } + + async fn send_via_next_hop( + &self, + msg: ZCPacket, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let Some(next_hop) = self.peer_map.get_gateway_peer_id(dst_peer_id, policy).await else { + return Err(Error::RouteError(None)); + }; + if !self.peer_map.has_peer(next_hop) { + return Err(Error::RouteError(None)); + } + self.peer_map.send_msg_directly(msg, next_hop).await + } + + pub async fn send_msg( + self: &Arc<Self>, + mut msg: ZCPacket, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let now = Instant::now(); + + self.states.entry(dst_peer_id).or_default().last_active_at = now; + + if self.is_secure_mode_enabled() { + match self.ensure_session(dst_peer_id, policy.clone()).await { + Ok(session) => { + let my_peer_id = self.my_peer_id; + session + .encrypt_payload(my_peer_id, dst_peer_id, &mut msg) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + } + Err(_) => { + // Handshake in progress, buffer the packet instead of dropping it + self.buffer_pending_packet(dst_peer_id, msg, policy); + return Ok(()); + } + } + } + + self.send_via_next_hop(msg, dst_peer_id, policy).await + } + + fn buffer_pending_packet(&self, dst_peer_id: PeerId, pkt: ZCPacket, policy: NextHopPolicy) { + let mut entry = self.pending_packets.entry(dst_peer_id).or_default(); + if entry.len() < MAX_PENDING_PACKETS_PER_PEER { + entry.push((pkt, policy)); + } + // silently drop when buffer is full + } + + async fn flush_pending_packets(&self, dst_peer_id: PeerId, session: Arc<PeerSession>) { + let packets = self.pending_packets.remove(&dst_peer_id).map(|(_, v)| v); + let Some(packets) = packets else { return }; + if packets.is_empty() { + return; + } + + tracing::debug!( 
?dst_peer_id, + count = packets.len(), + "flushing pending packets after relay handshake" + ); + + for (mut pkt, policy) in packets { + if session + .encrypt_payload(self.my_peer_id, dst_peer_id, &mut pkt) + .is_err() + { + continue; + } + let _ = self.send_via_next_hop(pkt, dst_peer_id, policy).await; + } + } + + pub fn has_session(&self, dst_peer_id: PeerId) -> bool { + self.peer_session_store + .get(&SessionKey::new( + self.global_ctx.get_network_identity().network_name.clone(), + dst_peer_id, + )) + .is_some() + } + + pub async fn ensure_session( + self: &Arc<Self>, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<Arc<PeerSession>, Error> { + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), dst_peer_id); + if let Some(session) = self.peer_session_store.get(&key) { + return Ok(session); + } + + let lock = self.get_handshake_lock(dst_peer_id); + if let Ok(guard) = lock.try_lock_owned() { + let self_clone = self.clone(); + tokio::spawn(async move { + self_clone + .handshake_session(dst_peer_id, policy, Some(guard)) + .await + }); + }; + Err(Error::RouteError(Some( + "relay handshake in progress".to_string(), + ))) + } + + #[tracing::instrument(skip(self, _lock_guard), level = "debug", ret)] + pub async fn handshake_session( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + _lock_guard: Option<OwnedMutexGuard<()>>, + ) -> Result<(), Error> { + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), dst_peer_id); + if let Some(session) = self.peer_session_store.get(&key) { + self.flush_pending_packets(dst_peer_id, session).await; + return Ok(()); + } + + if let Some(next_retry_at) = self.states.get(&dst_peer_id).and_then(|v| v.next_retry_at) { + if Instant::now() < next_retry_at { + self.pending_packets.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake backoff".to_string(), + ))); + } + } + + let mut last_err = None; + for attempt in 
0..HANDSHAKE_MAX_ATTEMPTS { + let ret = self + .handshake_session_once(dst_peer_id, policy.clone()) + .await; + match ret { + Ok(session) => { + self.register_handshake_success(dst_peer_id); + self.flush_pending_packets(dst_peer_id, session).await; + return Ok(()); + } + Err(e) => { + last_err = Some(e); + self.register_handshake_failure(dst_peer_id, attempt); + if attempt + 1 < HANDSHAKE_MAX_ATTEMPTS { + let backoff = HANDSHAKE_RETRY_BASE_MS.saturating_mul(1 << attempt); + tokio::time::sleep(Duration::from_millis(backoff)).await; + } + } + } + } + + // All attempts failed, drop buffered packets + self.pending_packets.remove(&dst_peer_id); + + Err(last_err + .unwrap_or_else(|| Error::RouteError(Some("relay handshake failed".to_string())))) + } + + #[tracing::instrument(skip(self), level = "debug", ret)] + async fn handshake_session_once( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<Arc<PeerSession>, Error> { + let network = self.global_ctx.get_network_identity(); + let session_key = SessionKey::new(network.network_name.clone(), dst_peer_id); + let (local_private_key, _local_public_key) = self.get_local_keypair()?; + let remote_static = self.get_remote_static_pubkey(dst_peer_id).await?; + let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_SHA256" + .parse() + .map_err(|e| Error::RouteError(Some(format!("parse noise params failed: {e:?}"))))?; + + let builder = snow::Builder::new(params); + let mut hs = builder + .prologue(RELAY_NOISE_PROLOGUE) + .map_err(|e| Error::RouteError(Some(format!("set prologue failed: {e:?}"))))? + .local_private_key(&local_private_key) + .map_err(|e| Error::RouteError(Some(format!("set local key failed: {e:?}"))))? + .remote_public_key(&remote_static) + .map_err(|e| Error::RouteError(Some(format!("set remote key failed: {e:?}"))))? 
+ .build_initiator() + .map_err(|e| Error::RouteError(Some(format!("build initiator failed: {e:?}"))))?; + + let a_session_generation = self + .peer_session_store + .get(&session_key) + .map(|s| s.session_generation()); + let a_conn_id = uuid::Uuid::new_v4(); + let msg1_pb = RelayNoiseMsg1Pb { + version: RELAY_NOISE_VERSION, + a_network_name: network.network_name.clone(), + a_session_generation, + a_conn_id: Some(a_conn_id.into()), + client_encryption_algorithm: self.global_ctx.get_flags().encryption_algorithm.clone(), + }; + let payload = msg1_pb.encode_to_vec(); + let mut out = vec![0u8; 4096]; + let out_len = hs + .write_message(&payload, &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise write msg1 failed: {e:?}"))))?; + let server_handshake_hash = hs.get_handshake_hash().to_vec(); + let (tx, rx) = oneshot::channel(); + self.pending_handshakes.insert(dst_peer_id, tx); + + let send_res = self + .send_handshake_packet( + out[..out_len].to_vec(), + PacketType::RelayHandshake, + dst_peer_id, + policy, + ) + .await; + + if send_res.is_err() { + self.pending_handshakes.remove(&dst_peer_id); + } + send_res?; + let msg2_pkt = match timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), rx).await { + Ok(Ok(pkt)) => pkt, + Ok(Err(_)) => { + self.pending_handshakes.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake canceled".to_string(), + ))); + } + Err(_) => { + self.pending_handshakes.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake timeout".to_string(), + ))); + } + }; + + let msg2_pb = self.decode_handshake_message::<RelayNoiseMsg2Pb>( + PacketType::RelayHandshakeAck, + &mut hs, + msg2_pkt, + )?; + if msg2_pb.a_conn_id_echo != Some(a_conn_id.into()) { + return Err(Error::RouteError(Some( + "relay msg2 conn_id_echo mismatch".to_string(), + ))); + } + if msg2_pb.b_network_name == network.network_name { + if msg2_pb.role_hint != 1 { + return Err(Error::RouteError(Some( + "role_hint must be 1 when network_name is 
same".to_string(), + ))); + } + let Some(secret_proof_32) = msg2_pb.secret_proof_32 else { + return Err(Error::RouteError(Some( + "secret_proof_32 must be present when role_hint is 1".to_string(), + ))); + }; + let verify_result = self + .global_ctx + .get_secret_proof(&server_handshake_hash) + .map(|mac| mac.verify_slice(&secret_proof_32).is_ok()); + if verify_result != Some(true) { + return Err(Error::RouteError(Some( + "secret_proof_32 verify failed".to_string(), + ))); + } + } + + let action = PeerConnSessionActionPb::try_from(msg2_pb.action) + .map_err(|_| Error::RouteError(Some("invalid session action".to_string())))?; + let session_action = match action { + PeerConnSessionActionPb::Join => PeerSessionAction::Join, + PeerConnSessionActionPb::Sync => PeerSessionAction::Sync, + PeerConnSessionActionPb::Create => PeerSessionAction::Create, + }; + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + let root_key_bytes = msg2_pb + .root_key_32 + .as_deref() + .filter(|v| v.len() == 32) + .map(|v| { + let mut key_bytes = [0u8; 32]; + key_bytes.copy_from_slice(v); + key_bytes + }); + let algo = self.global_ctx.get_flags().encryption_algorithm.clone(); + let session = self + .peer_session_store + .apply_initiator_action( + &session_key, + session_action, + msg2_pb.b_session_generation, + root_key_bytes, + msg2_pb.initial_epoch, + algo, + msg2_pb.server_encryption_algorithm.clone(), + remote_static_key, + ) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + + Ok(session) + } + + fn register_handshake_success(&self, dst_peer_id: PeerId) { + let mut entry = self.states.entry(dst_peer_id).or_default(); + entry.failure_count = 0; + entry.next_retry_at = None; + } + + fn register_handshake_failure(&self, dst_peer_id: PeerId, attempt: u32) { + let mut entry = self.states.entry(dst_peer_id).or_default(); + entry.failure_count = 
entry.failure_count.saturating_add(1); + let backoff = HANDSHAKE_RETRY_BASE_MS.saturating_mul(1 << attempt); + entry.next_retry_at = Some(Instant::now() + Duration::from_millis(backoff)); + } + + fn decode_handshake_message<MsgT: Message + Default>( + &self, + expected_type: PacketType, + hs: &mut snow::HandshakeState, + pkt: ZCPacket, + ) -> Result<MsgT, Error> { + let hdr = pkt.peer_manager_header().ok_or_else(|| { + Error::RouteError(Some("packet without peer manager header".to_string())) + })?; + if hdr.packet_type != expected_type as u8 { + return Err(Error::RouteError(Some("packet type mismatch".to_string()))); + } + let mut out = vec![0u8; 4096]; + let out_len = hs + .read_message(pkt.payload(), &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise read msg failed: {e:?}"))))?; + let msg = MsgT::decode(&out[..out_len]) + .map_err(|e| Error::RouteError(Some(format!("decode message failed: {e:?}"))))?; + Ok(msg) + } + + pub async fn handle_handshake_packet(&self, packet: ZCPacket) -> Result<(), Error> { + let hdr = packet + .peer_manager_header() + .ok_or_else(|| Error::RouteError(Some("packet without header".to_string())))?; + let src_peer_id = hdr.from_peer_id.get(); + match hdr.packet_type { + x if x == PacketType::RelayHandshake as u8 => { + tracing::debug!("handle_relay_msg1 from {:?}", src_peer_id); + self.handle_relay_msg1(packet, src_peer_id).await + } + x if x == PacketType::RelayHandshakeAck as u8 => { + if let Some((_, sender)) = self.pending_handshakes.remove(&src_peer_id) { + let _ = sender.send(packet); + } + Ok(()) + } + _ => Ok(()), + } + } + + async fn handle_relay_msg1(&self, msg1: ZCPacket, remote_peer_id: PeerId) -> Result<(), Error> { + // Check for bidirectional handshake race condition. + // If we are also waiting for a RelayHandshakeAck from this peer, + // use deterministic rule: the peer with smaller peer_id becomes initiator. + if self.pending_handshakes.contains_key(&remote_peer_id) { + // We have a pending handshake as initiator. 
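`register_handshake_failure` above derives the retry delay as `HANDSHAKE_RETRY_BASE_MS.saturating_mul(1 << attempt)`, which is safe for the small `HANDSHAKE_MAX_ATTEMPTS` used here. A standalone sketch of that backoff, generalized with an overflow-safe shift (the `checked_shl` guard is an addition for illustration, not part of the patch):

```rust
use std::time::Duration;

const HANDSHAKE_RETRY_BASE_MS: u64 = 200;

/// Exponential backoff for handshake attempt `attempt` (0-based):
/// base delay doubled per attempt, saturating instead of overflowing.
fn handshake_backoff(attempt: u32) -> Duration {
    // `1u64 << attempt` would overflow for attempt >= 64; cap the shift
    // so the result saturates instead of panicking in debug builds.
    let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    Duration::from_millis(HANDSHAKE_RETRY_BASE_MS.saturating_mul(factor))
}
```

With the patch's defaults (base 200 ms, 3 attempts) the observed delays are 200 ms, 400 ms, 800 ms before the handshake is abandoned and buffered packets are dropped.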
+ // If remote_peer_id < my_peer_id, remote should be initiator, we should be responder. + // Cancel our pending handshake and proceed as responder. + if remote_peer_id < self.my_peer_id { + tracing::debug!( + ?remote_peer_id, + my_peer_id = ?self.my_peer_id, + "bidirectional handshake race: yielding initiator role to smaller peer_id" + ); + // Remove our pending handshake + self.pending_handshakes.remove(&remote_peer_id); + } else { + // We have smaller peer_id, we should remain initiator. + // Ignore this RelayHandshake and let our initiator flow complete. + tracing::debug!( + ?remote_peer_id, + my_peer_id = ?self.my_peer_id, + "bidirectional handshake race: keeping initiator role due to smaller peer_id" + ); + return Err(Error::RouteError(Some( + "bidirectional handshake race: we are initiator".to_string(), + ))); + } + } + + let (local_private_key, _local_public_key) = self.get_local_keypair()?; + let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_SHA256" + .parse() + .map_err(|e| Error::RouteError(Some(format!("parse noise params failed: {e:?}"))))?; + let builder = snow::Builder::new(params); + let mut hs = builder + .prologue(RELAY_NOISE_PROLOGUE) + .map_err(|e| Error::RouteError(Some(format!("set prologue failed: {e:?}"))))? + .local_private_key(&local_private_key) + .map_err(|e| Error::RouteError(Some(format!("set local key failed: {e:?}"))))? 
+ .build_responder() + .map_err(|e| Error::RouteError(Some(format!("build responder failed: {e:?}"))))?; + + let msg1_pb = self.decode_handshake_message::<RelayNoiseMsg1Pb>( + PacketType::RelayHandshake, + &mut hs, + msg1, + )?; + let remote_network_name = msg1_pb.a_network_name.clone(); + let remote_static = hs + .get_remote_static() + .map(|x: &[u8]| x.to_vec()) + .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + + // Verify initiator's static public key matches the expected key from route info + let expected_pubkey = self.get_remote_static_pubkey(remote_peer_id).await?; + if remote_static != expected_pubkey { + return Err(Error::RouteError(Some(format!( + "responder: initiator static pubkey mismatch for peer {}, expected {} bytes, got {} bytes", + remote_peer_id, + expected_pubkey.len(), + remote_static.len() + )))); + } + + let server_network_name = self.global_ctx.get_network_name(); + let (role_hint, secret_proof_32) = if remote_network_name == server_network_name { + let proof = self + .global_ctx + .get_secret_proof(hs.get_handshake_hash()) + .map(|mac| mac.finalize().into_bytes().to_vec()); + (1, proof) + } else { + (2, None) + }; + + let algo = self.global_ctx.get_flags().encryption_algorithm.clone(); + let key = SessionKey::new(server_network_name.clone(), remote_peer_id); + let upsert = self + .peer_session_store + .upsert_responder_session( + &key, + msg1_pb.a_session_generation, + algo.clone(), + msg1_pb.client_encryption_algorithm.clone(), + remote_static_key, + ) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + let msg2_pb = RelayNoiseMsg2Pb { + b_network_name: server_network_name, + role_hint, + action: match upsert.action { + PeerSessionAction::Join => PeerConnSessionActionPb::Join as i32, + PeerSessionAction::Sync => PeerConnSessionActionPb::Sync as i32, + PeerSessionAction::Create => PeerConnSessionActionPb::Create as 
i32, + }, + b_session_generation: upsert.session_generation, + root_key_32: upsert.root_key.map(|k| k.to_vec()), + initial_epoch: upsert.initial_epoch, + b_conn_id: Some(uuid::Uuid::new_v4().into()), + a_conn_id_echo: msg1_pb.a_conn_id, + secret_proof_32, + server_encryption_algorithm: algo, + }; + let payload = msg2_pb.encode_to_vec(); + let mut out = vec![0u8; 4096]; + let out_len = hs + .write_message(&payload, &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise write msg2 failed: {e:?}"))))?; + + self.register_handshake_success(remote_peer_id); + + self.send_handshake_packet( + out[..out_len].to_vec(), + PacketType::RelayHandshakeAck, + remote_peer_id, + NextHopPolicy::LeastHop, + ) + .await?; + + // Flush any packets buffered while waiting for the handshake to complete + self.flush_pending_packets(remote_peer_id, upsert.session) + .await; + + Ok(()) + } + + pub fn decrypt_if_needed(&self, packet: &mut ZCPacket) -> Result<bool, Error> { + if !self.is_secure_mode_enabled() { + return Ok(false); + } + let hdr = packet + .peer_manager_header() + .ok_or_else(|| Error::RouteError(Some("packet without header".to_string())))?; + let from_peer_id = hdr.from_peer_id.get(); + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), from_peer_id); + let Some(session) = self.peer_session_store.get(&key) else { + return Ok(false); + }; + let now = Instant::now(); + let mut entry = self.states.entry(from_peer_id).or_default(); + entry.last_active_at = now; + session.decrypt_payload(from_peer_id, self.my_peer_id, packet)?; + Ok(true) + } + + pub fn evict_idle_sessions(&self, idle: Duration) { + let now = Instant::now(); + let mut to_remove = Vec::new(); + for entry in self.states.iter() { + if now.duration_since(entry.last_active_at) > idle { + to_remove.push(*entry.key()); + } + } + for peer_id in to_remove { + self.states.remove(&peer_id); + self.pending_handshakes.remove(&peer_id); + 
self.handshake_locks.remove(&peer_id); + self.pending_packets.remove(&peer_id); + } + } + + pub fn has_state(&self, peer_id: PeerId) -> bool { + self.states.contains_key(&peer_id) + } + + pub fn failure_count(&self, peer_id: PeerId) -> Option<u32> { + self.states.get(&peer_id).map(|v| v.failure_count) + } + + pub fn is_backoff_active(&self, peer_id: PeerId) -> bool { + self.states + .get(&peer_id) + .and_then(|v| v.next_retry_at) + .is_some_and(|ts| Instant::now() < ts) + } + + /// Remove relay-specific state for a specific peer. + /// This does NOT remove the session from PeerSessionStore, because the + /// session lifecycle is independent of any particular connection type + /// (relay or direct). The session may still be used by direct connections + /// or for fast reconnection (Join instead of Create). + pub fn remove_peer(&self, peer_id: PeerId) { + self.states.remove(&peer_id); + self.pending_handshakes.remove(&peer_id); + self.handshake_locks.remove(&peer_id); + self.pending_packets.remove(&peer_id); + + tracing::debug!(?peer_id, "RelayPeerMap removed peer relay state"); + } +} diff --git a/easytier/src/peers/tests.rs b/easytier/src/peers/tests.rs index b09b400f..3bf690ce 100644 --- a/easytier/src/peers/tests.rs +++ b/easytier/src/peers/tests.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use crate::{ common::{ @@ -9,12 +10,21 @@ use crate::{ }, PeerId, }, - tunnel::ring::create_ring_tunnel_pair, + tunnel::{ + common::tests::wait_for_condition, + packet_def::{PacketType, ZCPacket}, + ring::create_ring_tunnel_pair, + }, }; use super::{ create_packet_recv_chan, + peer_conn::tests::set_secure_mode_cfg, peer_manager::{PeerManager, RouteAlgoType}, + peer_map::PeerMap, + peer_session::{PeerSession, PeerSessionStore, SessionKey}, + relay_peer_map::RelayPeerMap, + route_trait::NextHopPolicy, }; pub async fn create_mock_peer_manager() -> Arc<PeerManager> { @@ -37,6 +47,19 @@ pub async fn create_mock_peer_manager_with_name(network_name: String) -> Arc<PeerManager> { +pub async fn create_mock_peer_manager_secure( + network_name: String, + network_secret: String, +) -> Arc<PeerManager> { + let (s, 
_r) = create_packet_recv_chan(); + let g = + get_mock_global_ctx_with_network(Some(NetworkIdentity::new(network_name, network_secret))); + set_secure_mode_cfg(&g, true); + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, g, s)); + peer_mgr.run().await.unwrap(); + peer_mgr +} + pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) { let (a_ring, b_ring) = create_ring_tunnel_pair(); let a_mgr_copy = client; @@ -127,3 +150,560 @@ async fn foreign_mgr_stress_test() { } } } + +#[tokio::test] +async fn relay_peer_map_secure_session_decrypt() { + let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, ctx.clone(), 10, store.clone()); + + let algo = ctx.get_flags().encryption_algorithm.clone(); + let root_key = [7u8; 32]; + let session = Arc::new(PeerSession::new( + 20, + root_key, + 1, + 1, + algo.clone(), + algo.clone(), + None, + )); + let key = SessionKey::new(ctx.get_network_identity().network_name, 20); + store.insert_session(key.clone(), session.clone()); + + relay_map + .ensure_session(20, NextHopPolicy::LeastHop) + .await + .unwrap(); + assert!(relay_map.has_session(20)); + + let mut packet = ZCPacket::new_with_payload(b"relay-hello"); + packet.fill_peer_manager_hdr(20, 10, PacketType::Data as u8); + session.encrypt_payload(20, 10, &mut packet).unwrap(); + assert!(relay_map.decrypt_if_needed(&mut packet).unwrap()); + assert_eq!(packet.payload(), b"relay-hello"); +} + +#[tokio::test] +async fn relay_peer_map_retry_backoff_and_evict() { + let (s, _r) = create_packet_recv_chan(); + let ctx_secure = get_mock_global_ctx(); + set_secure_mode_cfg(&ctx_secure, true); + let peer_map = Arc::new(PeerMap::new(s, ctx_secure.clone(), 10)); + let relay_map = 
RelayPeerMap::new( + peer_map, + ctx_secure.clone(), + 10, + Arc::new(PeerSessionStore::new()), + ); + + let ret = relay_map + .handshake_session(20, NextHopPolicy::LeastHop, None) + .await; + assert!(ret.is_err()); + assert!(relay_map.failure_count(20).unwrap_or(0) >= 1); + assert!(relay_map.is_backoff_active(20)); + + let (s2, _r2) = create_packet_recv_chan(); + let ctx_plain = get_mock_global_ctx(); + let peer_map_plain = Arc::new(PeerMap::new(s2, ctx_plain.clone(), 30)); + let relay_map_plain = RelayPeerMap::new( + peer_map_plain, + ctx_plain.clone(), + 30, + Arc::new(PeerSessionStore::new()), + ); + + let mut pkt = ZCPacket::new_with_payload(b"evict"); + pkt.fill_peer_manager_hdr(30, 40, PacketType::Data as u8); + let _ = relay_map_plain + .send_msg(pkt, 40, NextHopPolicy::LeastHop) + .await; + assert!(relay_map_plain.has_state(40)); + relay_map_plain.evict_idle_sessions(Duration::from_millis(0)); + assert!(!relay_map_plain.has_state(40)); +} + +#[tokio::test] +async fn relay_peer_map_pending_packet_buffer() { + // Verify that packets sent during handshake are buffered (not dropped), + // and flushed after handshake completes. 
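The buffering behavior exercised by the next test (queue packets per destination while the handshake is in flight, drop silently once a per-peer cap is reached, drain on success) can be sketched std-only, with a `HashMap` standing in for the `DashMap` and raw byte payloads standing in for `ZCPacket` (names illustrative):

```rust
use std::collections::HashMap;

const MAX_PENDING_PACKETS_PER_PEER: usize = 32;

type PeerId = u32;

/// Per-destination queue of packets awaiting a session.
struct PendingBuffer {
    pending: HashMap<PeerId, Vec<Vec<u8>>>,
}

impl PendingBuffer {
    fn new() -> Self {
        Self { pending: HashMap::new() }
    }

    /// Returns true if buffered, false if silently dropped (cap reached).
    fn buffer(&mut self, dst: PeerId, pkt: Vec<u8>) -> bool {
        let q = self.pending.entry(dst).or_default();
        if q.len() < MAX_PENDING_PACKETS_PER_PEER {
            q.push(pkt);
            true
        } else {
            false
        }
    }

    /// Drain everything queued for `dst`, e.g. after a handshake completes
    /// (the real flush then encrypts and forwards each packet).
    fn flush(&mut self, dst: PeerId) -> Vec<Vec<u8>> {
        self.pending.remove(&dst).unwrap_or_default()
    }
}
```

The cap bounds memory per unresponsive peer; on handshake failure the whole queue is simply removed, matching the `pending_packets.remove(&dst_peer_id)` calls in the patch.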
+ let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, ctx.clone(), 10, store.clone()); + + // Send multiple packets while no session exists (handshake will fail, but packets should be buffered) + for i in 0..5u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(10, 20, PacketType::Data as u8); + let _ = relay_map.send_msg(pkt, 20, NextHopPolicy::LeastHop).await; + } + + // Verify packets were buffered + assert_eq!( + relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0), + 5, + "5 packets should be buffered during handshake" + ); + + // Verify buffer respects capacity limit + for i in 0..50u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(10, 20, PacketType::Data as u8); + let _ = relay_map.send_msg(pkt, 20, NextHopPolicy::LeastHop).await; + } + + let buffered = relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0); + assert!( + buffered <= 32, + "buffer should not exceed MAX_PENDING_PACKETS_PER_PEER, got {buffered}" + ); + + // Verify remove_peer clears pending packets + relay_map.remove_peer(20); + assert_eq!( + relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0), + 0, + "pending packets should be cleared on peer removal" + ); +} + +#[tokio::test] +async fn relay_peer_map_pending_packets_flushed_on_handshake_success() { + // Test that pending packets are flushed after handshake succeeds. + // We pre-populate the buffer, then run handshake, and verify it's cleared. 
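A related corner covered by `handle_relay_msg1` above is the bidirectional handshake race: when both sides initiate at once, the deterministic tie-break is that the smaller peer_id keeps the initiator role. The decision itself is a pure function and can be stated on its own (names illustrative):

```rust
type PeerId = u32;

#[derive(Debug, PartialEq)]
enum RaceAction {
    /// Cancel our pending handshake and answer the remote's msg1 as responder.
    YieldToRemote,
    /// Keep our initiator flow and reject the incoming msg1.
    StayInitiator,
}

/// Decide what to do when a RelayHandshake arrives while we already
/// have a pending outgoing handshake to the same peer.
fn resolve_handshake_race(my_peer_id: PeerId, remote_peer_id: PeerId) -> RaceAction {
    if remote_peer_id < my_peer_id {
        RaceAction::YieldToRemote // smaller peer_id wins the initiator role
    } else {
        RaceAction::StayInitiator
    }
}
```

Because both sides evaluate the same comparison with the roles swapped, exactly one of them yields, so the race resolves without extra signaling.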
+ let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available on both sides + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + let relay_a = peer_a.get_relay_peer_map(); + + // Pre-populate pending packets buffer (simulating what send_msg does during handshake) + for i in 0..3u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(peer_a_id, peer_c_id, PacketType::Data as u8); + relay_a + .pending_packets + .entry(peer_c_id) + .or_default() + .push((pkt, NextHopPolicy::LeastHop)); + } + + assert_eq!( + relay_a + .pending_packets + .get(&peer_c_id) + .map(|v| v.len()) + .unwrap_or(0), + 3, + "3 packets should be in the buffer" + ); + + // Run handshake — on success it should flush the buffer + relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await + .unwrap(); + + // Verify session established and buffer cleared + assert!(relay_a.has_session(peer_c_id)); + assert_eq!( + relay_a + .pending_packets + .get(&peer_c_id) + .map(|v| v.len()) + .unwrap_or(0), + 0, + "pending packets should be 
flushed after successful handshake"
+    );
+}
+
+#[tokio::test]
+async fn relay_peer_map_real_link_handshake_success() {
+    let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await;
+    let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await;
+    let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await;
+
+    connect_peer_manager(peer_a.clone(), peer_b.clone()).await;
+    connect_peer_manager(peer_b.clone(), peer_c.clone()).await;
+
+    let peer_a_id = peer_a.my_peer_id();
+    let peer_b_id = peer_b.my_peer_id();
+    let peer_c_id = peer_c.my_peer_id();
+
+    wait_for_condition(
+        || {
+            let peer_a = peer_a.clone();
+            let peer_c = peer_c.clone();
+            async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() }
+        },
+        Duration::from_secs(10),
+    )
+    .await;
+
+    wait_for_condition(
+        || {
+            let peer_a = peer_a.clone();
+            async move {
+                peer_a
+                    .get_peer_map()
+                    .get_gateway_peer_id(peer_c_id, NextHopPolicy::LeastHop)
+                    .await
+                    == Some(peer_b_id)
+            }
+        },
+        Duration::from_secs(5),
+    )
+    .await;
+
+    wait_for_condition(
+        || {
+            let peer_a = peer_a.clone();
+            async move {
+                peer_a
+                    .get_peer_map()
+                    .get_route_peer_info(peer_c_id)
+                    .await
+                    .map(|info| !info.noise_static_pubkey.is_empty())
+                    .unwrap_or(false)
+            }
+        },
+        Duration::from_secs(10),
+    )
+    .await;
+
+    let relay_a = peer_a.get_relay_peer_map();
+    let relay_c = peer_c.get_relay_peer_map();
+
+    relay_a
+        .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None)
+        .await
+        .unwrap();
+
+    wait_for_condition(
+        || {
+            let relay_a = relay_a.clone();
+            async move { relay_a.has_session(peer_c_id) }
+        },
+        Duration::from_secs(5),
+    )
+    .await;
+
+    wait_for_condition(
+        || {
+            let relay_c = relay_c.clone();
+            async move { relay_c.has_session(peer_a_id) }
+        },
+        Duration::from_secs(5),
+    )
+    .await;
+}
+
+#[tokio::test]
+async fn relay_peer_map_handshake_succeeds_with_matching_pubkey() {
+    // Create three peers: A -> B -> C
+
let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + // Get the original correct pubkey to verify it exists + let original_info = peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .expect("should have route info for peer_c"); + assert!( + !original_info.noise_static_pubkey.is_empty(), + "noise_static_pubkey should be present" + ); + + // Attempt handshake - this should succeed because pubkeys match + let relay_a = peer_a.get_relay_peer_map(); + let result = relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await; + + // The handshake should succeed because the pubkeys match + assert!( + result.is_ok(), + "handshake should succeed with matching pubkeys" + ); + + // Verify session was established on both sides + wait_for_condition( + || { + let relay_a = relay_a.clone(); + async move { relay_a.has_session(peer_c_id) } + }, + Duration::from_secs(5), + ) + .await; + + let relay_c = peer_c.get_relay_peer_map(); + wait_for_condition( + || { + 
let relay_c = relay_c.clone(); + async move { relay_c.has_session(peer_a_id) } + }, + Duration::from_secs(5), + ) + .await; +} + +#[tokio::test] +async fn relay_peer_map_remove_peer() { + let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, ctx.clone(), 10, store.clone()); + + let peer_1: PeerId = 100; + + // Add session for peer_1 + let root_key = [1u8; 32]; + let session = Arc::new(PeerSession::new( + peer_1, + root_key, + 1, + 0, + "aes-256-gcm".to_string(), + "aes-256-gcm".to_string(), + None, + )); + let key = SessionKey::new(ctx.get_network_name(), peer_1); + store.insert_session(key.clone(), session); + + assert!(store.get(&key).is_some()); + + // Remove the peer relay state + relay_map.remove_peer(peer_1); + + // Session should still be in the store (lifecycle is independent of relay state) + assert!( + store.get(&key).is_some(), + "session should persist after relay peer removal" + ); +} + +/// Test bidirectional handshake race resolution. +/// When both peers simultaneously initiate handshake, the one with smaller peer_id +/// should become initiator, and the other should yield and become responder. 
+#[tokio::test] +async fn relay_peer_map_bidirectional_handshake_race() { + // Create three peers: A -> B -> C + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || { + let peer_c = peer_c.clone(); + async move { + peer_c + .get_peer_map() + .get_route_peer_info(peer_a_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + // Simulate bidirectional handshake race by having both sides initiate simultaneously + let relay_a = peer_a.get_relay_peer_map(); + let relay_c = peer_c.get_relay_peer_map(); + + // Both sides initiate handshake at the same time + let handle_a = tokio::spawn({ + let relay_a = relay_a.clone(); + async move { + relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await + } + }); + + let handle_c = tokio::spawn({ + let relay_c = relay_c.clone(); + async move { + relay_c + .handshake_session(peer_a_id, NextHopPolicy::LeastHop, None) + 
.await + } + }); + + // Wait for both handshakes to complete + let (result_a, result_c) = tokio::join!(handle_a, handle_c); + + // At least one should succeed (the initiator with smaller peer_id) + // Both could succeed if race resolution worked correctly + tracing::info!( + ?peer_a_id, + ?peer_c_id, + ?result_a, + ?result_c, + "bidirectional handshake results" + ); + + // Wait for sessions to be established + wait_for_condition( + || { + let relay_a = relay_a.clone(); + async move { relay_a.has_session(peer_c_id) } + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || { + let relay_c = relay_c.clone(); + async move { relay_c.has_session(peer_a_id) } + }, + Duration::from_secs(5), + ) + .await; + + // Both sides should have sessions after race resolution + assert!( + relay_a.has_session(peer_c_id), + "peer_a should have session with peer_c" + ); + assert!( + relay_c.has_session(peer_a_id), + "peer_c should have session with peer_a" + ); +} diff --git a/easytier/src/proto/peer_rpc.proto b/easytier/src/proto/peer_rpc.proto index 77861f74..9792dbb4 100644 --- a/easytier/src/proto/peer_rpc.proto +++ b/easytier/src/proto/peer_rpc.proto @@ -29,6 +29,7 @@ message RoutePeerInfo { repeated PeerGroupInfo groups = 16; common.NatType tcp_nat_type = 17; + bytes noise_static_pubkey = 18; } message PeerIdVersion { @@ -293,6 +294,27 @@ message PeerConnNoiseMsg2Pb { string server_encryption_algorithm = 10; } +message RelayNoiseMsg1Pb { + uint32 version = 1; + string a_network_name = 2; + optional uint32 a_session_generation = 3; + common.UUID a_conn_id = 4; + string client_encryption_algorithm = 5; +} + +message RelayNoiseMsg2Pb { + string b_network_name = 1; + uint32 role_hint = 2; + PeerConnSessionActionPb action = 3; + uint32 b_session_generation = 4; + optional bytes root_key_32 = 5; + uint32 initial_epoch = 6; + common.UUID b_conn_id = 7; + common.UUID a_conn_id_echo = 8; + optional bytes secret_proof_32 = 9; + string server_encryption_algorithm = 10; +} + 
message PeerConnNoiseMsg3Pb {
     common.UUID a_conn_id_echo = 1;
     common.UUID b_conn_id_echo = 2;
diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs
index 1b014a00..aa8d4214 100644
--- a/easytier/src/tests/three_node.rs
+++ b/easytier/src/tests/three_node.rs
@@ -21,7 +21,10 @@ use crate::{
         stats_manager::{LabelType, MetricName},
     },
     instance::instance::Instance,
-    proto::{api::instance::TcpProxyEntryTransportType, common::CompressionAlgoPb},
+    proto::{
+        api::instance::TcpProxyEntryTransportType,
+        common::{CompressionAlgoPb, SecureModeConfig},
+    },
     tunnel::{
         common::tests::{_tunnel_bench_netns, wait_for_condition},
         ring::RingTunnelConnector,
@@ -2759,3 +2762,201 @@ pub async fn config_patch_test() {
 
     drop_insts(insts).await;
 }
+
+/// Generate SecureModeConfig with random x25519 keypair
+fn generate_secure_mode_config() -> SecureModeConfig {
+    use base64::{prelude::BASE64_STANDARD, Engine};
+    use rand::rngs::OsRng;
+    use x25519_dalek::{PublicKey, StaticSecret};
+
+    let private = StaticSecret::random_from_rng(OsRng);
+    let public = PublicKey::from(&private);
+
+    SecureModeConfig {
+        enabled: true,
+        local_private_key: Some(BASE64_STANDARD.encode(private.as_bytes())),
+        local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())),
+    }
+}
+/// Test relay peer end-to-end encryption over TCP and UDP tunnels
+#[rstest::rstest]
+#[tokio::test]
+#[serial_test::serial]
+pub async fn relay_peer_e2e_encryption(#[values("tcp", "udp")] proto: &str) {
+    use crate::peers::route_trait::NextHopPolicy;
+
+    let insts = init_three_node_ex(
+        proto,
+        |cfg| {
+            cfg.set_secure_mode(Some(generate_secure_mode_config()));
+            cfg
+        },
+        false,
+    )
+    .await;
+
+    let inst1_peer_id = insts[0].peer_id();
+    let inst2_peer_id = insts[1].peer_id();
+    let inst3_peer_id = insts[2].peer_id();
+
+    println!(
+        "Test topology: inst1({}) <-> inst2({}) <-> inst3({})",
+        inst1_peer_id, inst2_peer_id, inst3_peer_id
+    );
+
+    // Check secure mode is enabled
+    let secure_mode_1 =
insts[0].get_global_ctx().config.get_secure_mode(); + let secure_mode_2 = insts[1].get_global_ctx().config.get_secure_mode(); + let secure_mode_3 = insts[2].get_global_ctx().config.get_secure_mode(); + println!( + "Secure mode enabled: inst1={}, inst2={}, inst3={}", + secure_mode_1.is_some(), + secure_mode_2.is_some(), + secure_mode_3.is_some() + ); + + // Wait for routes to be established + wait_for_condition( + || async { + let routes = insts[0].get_peer_manager().list_routes().await; + routes.len() == 2 + }, + Duration::from_secs(10), + ) + .await; + + // Verify inst1 sees inst3 via inst2 (non-direct path) + let next_hop_to_inst3 = insts[0] + .get_peer_manager() + .get_peer_map() + .get_gateway_peer_id(inst3_peer_id, NextHopPolicy::LeastHop) + .await; + println!("Next hop from inst1 to inst3: {:?}", next_hop_to_inst3); + assert_eq!( + next_hop_to_inst3, + Some(inst2_peer_id), + "inst1 should reach inst3 via inst2 (relay)" + ); + + // Verify inst1 has no direct connection to inst3 + assert!( + !insts[0] + .get_peer_manager() + .get_peer_map() + .has_peer(inst3_peer_id), + "inst1 should NOT have direct connection to inst3" + ); + + // Check if noise_static_pubkey is available for relay handshake + let route_info_inst3 = insts[0] + .get_peer_manager() + .get_peer_map() + .get_route_peer_info(inst3_peer_id) + .await; + println!( + "Route info for inst3 on inst1: noise_static_pubkey len = {:?}", + route_info_inst3 + .as_ref() + .map(|i| i.noise_static_pubkey.len()) + ); + + // Test basic connectivity through relay + println!("Starting ping test from net_a to 10.144.144.3..."); + + assert!( + ping_test("net_a", "10.144.144.3", None).await, + "Ping from net_a to inst3 should succeed" + ); + + // Verify relay sessions are established + let relay_map_1 = insts[0].get_peer_manager().get_relay_peer_map(); + let relay_map_3 = insts[2].get_peer_manager().get_relay_peer_map(); + + println!( + "Relay states after ping: inst1->inst3: {}, inst3->inst1: {}", + 
relay_map_1.has_state(inst3_peer_id), + relay_map_3.has_state(inst1_peer_id) + ); + + // Test bidirectional connectivity + assert!( + ping_test("net_a", "10.144.144.3", None).await, + "Ping from net_a to inst3 should work" + ); + assert!( + ping_test("net_c", "10.144.144.1", None).await, + "Ping from net_c to inst1 should work" + ); + + println!("Test completed successfully!"); + drop_insts(insts).await; +} + +/// Test Relay Peer session cleanup on relay failure - TCP +#[tokio::test] +#[serial_test::serial] +pub async fn relay_peer_session_cleanup() { + use crate::peers::route_trait::NextHopPolicy; + + let mut insts = init_three_node_ex( + "tcp", + |cfg| { + cfg.set_secure_mode(Some(generate_secure_mode_config())); + cfg + }, + false, + ) + .await; + + let inst2_peer_id = insts[1].peer_id(); + let inst3_peer_id = insts[2].peer_id(); + let relay_map_1 = insts[0].get_peer_manager().get_relay_peer_map(); + + wait_for_condition( + || async { ping_test("net_a", "10.144.144.3", None).await }, + Duration::from_secs(6), + ) + .await; + + wait_for_condition( + || async { relay_map_1.has_state(inst3_peer_id) && relay_map_1.has_session(inst3_peer_id) }, + Duration::from_secs(3), + ) + .await; + + let next_hop = insts[0] + .get_peer_manager() + .get_peer_map() + .get_gateway_peer_id(inst3_peer_id, NextHopPolicy::LeastHop) + .await; + assert_eq!(next_hop, Some(inst2_peer_id)); + + let mut inst2 = insts.remove(1); + inst2.clear_resources().await; + drop(inst2); + + wait_for_condition( + || async { + let routes = insts[0].get_peer_manager().list_routes().await; + !routes.iter().any(|r| r.peer_id == inst3_peer_id) + }, + Duration::from_secs(6), + ) + .await; + + relay_map_1.evict_idle_sessions(Duration::from_millis(0)); + assert!(!relay_map_1.has_state(inst3_peer_id)); + + insts[0] + .get_peer_manager() + .get_peer_session_store() + .evict_unused_sessions(); + + wait_for_condition( + || async { !relay_map_1.has_session(inst3_peer_id) }, + Duration::from_secs(1), + ) + .await; + + 
drop_insts(insts).await; +} diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index d5a464f1..c79b6543 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -77,6 +77,8 @@ pub enum PacketType { NoiseHandshakeMsg1 = 13, NoiseHandshakeMsg2 = 14, NoiseHandshakeMsg3 = 15, + RelayHandshake = 20, + RelayHandshakeAck = 21, // used internally, DataWithKcpSrcModified = 18,