fix: refresh ACL groups and enable TCP_NODELAY for WebSocket (#2118)

* fix: refresh ACL groups and enable TCP_NODELAY for WebSocket
* add remove_peers to remove list of peer id in ospf route
* fix secure tunnel for unreliable udp tunnel
* fix(web-client): timeout secure tunnel handshake
* fix(web-server): tolerate delayed secure hello
* fix quic endpoint panic
* fix replay check
This commit is contained in:
KKRainbow
2026-04-19 10:37:39 +08:00
committed by GitHub
parent c49c56612b
commit 2db655bd6d
14 changed files with 7824 additions and 1038 deletions
Generated
+5772 -118
View File
File diff suppressed because it is too large Load Diff
+14 -7
View File
@@ -379,19 +379,26 @@ mod tests {
let req = tokio::time::timeout(Duration::from_secs(12), async {
loop {
let session = mgr
let sessions = mgr
.client_sessions
.iter()
.next()
.map(|item| item.value().clone());
let Some(session) = session else {
.map(|item| item.value().clone())
.collect::<Vec<_>>();
if sessions.is_empty() {
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let mut waiter = session.data().read().await.heartbeat_waiter();
if let Ok(req) = waiter.recv().await {
}
let mut found_req = None;
for session in sessions {
if let Some(req) = session.data().read().await.req() {
found_req = Some(req);
break;
}
}
if let Some(req) = found_req {
break req;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await
+1 -1
View File
@@ -386,7 +386,7 @@ impl WebServerService for SessionRpcService {
_: easytier::proto::web::GetFeatureRequest,
) -> rpc_types::error::Result<easytier::proto::web::GetFeatureResponse> {
Ok(easytier::proto::web::GetFeatureResponse {
support_encryption: true,
support_encryption: easytier::web_client::security::web_secure_tunnel_supported(),
})
}
}
+5
View File
@@ -611,6 +611,11 @@ impl QuicStreamReceiver {
}
_ = self.tasks.join_next(), if !self.tasks.is_empty() => {}
else => {
info!("quic stream receiver endpoint closed, exiting");
break;
}
}
}
}
+4
View File
@@ -383,6 +383,10 @@ impl InstanceConfigPatcher {
global_ctx
.get_acl_filter()
.reload_rules(AclRuleBuilder::build(&global_ctx)?.as_ref());
weak_upgrade(&self.peer_manager)?
.get_route()
.refresh_acl_groups()
.await;
Ok(())
}
+1
View File
@@ -20,6 +20,7 @@ pub mod foreign_network_client;
pub mod foreign_network_manager;
pub mod encrypt;
pub(crate) mod secure_datagram;
pub mod peer_task;
+660 -74
View File
@@ -418,6 +418,32 @@ impl Debug for SyncedRouteInfo {
}
impl SyncedRouteInfo {
fn set_peer_groups(&self, peer_id: PeerId, groups: HashMap<String, Vec<u8>>) {
if groups.is_empty() {
self.group_trust_map.remove(&peer_id);
self.group_trust_map_cache.remove(&peer_id);
return;
}
let group_names = groups.keys().cloned().collect();
self.group_trust_map.insert(peer_id, groups);
self.group_trust_map_cache
.insert(peer_id, Arc::new(group_names));
}
fn get_proof_groups(&self, peer_id: PeerId) -> HashMap<String, Vec<u8>> {
self.group_trust_map
.get(&peer_id)
.map(|groups| {
groups
.iter()
.filter(|(_, proof)| !proof.is_empty())
.map(|(group, proof)| (group.clone(), proof.clone()))
.collect()
})
.unwrap_or_default()
}
fn mark_credential_peer(info: &mut RoutePeerInfo, is_credential_peer: bool) {
let mut feature_flag = info.feature_flag.unwrap_or_default();
feature_flag.is_credential_peer = is_credential_peer;
@@ -439,13 +465,38 @@ impl SyncedRouteInfo {
}
fn remove_peer(&self, peer_id: PeerId) {
tracing::warn!(?peer_id, "remove_peer from synced_route_info");
self.peer_infos.write().remove(&peer_id);
self.raw_peer_infos.remove(&peer_id);
self.conn_map.write().remove(&peer_id);
self.foreign_network.retain(|k, _| k.peer_id != peer_id);
self.group_trust_map.remove(&peer_id);
self.group_trust_map_cache.remove(&peer_id);
self.remove_peers([peer_id]);
}
fn remove_peers<I>(&self, peer_ids: I)
where
I: IntoIterator<Item = PeerId>,
{
let peer_ids: HashSet<_> = peer_ids.into_iter().collect();
if peer_ids.is_empty() {
return;
}
for peer_id in &peer_ids {
tracing::warn!(?peer_id, "remove_peer from synced_route_info");
}
{
let mut peer_infos = self.peer_infos.write();
let mut conn_map = self.conn_map.write();
for peer_id in &peer_ids {
peer_infos.remove(peer_id);
conn_map.remove(peer_id);
}
}
for peer_id in &peer_ids {
self.raw_peer_infos.remove(peer_id);
self.group_trust_map.remove(peer_id);
self.group_trust_map_cache.remove(peer_id);
}
self.foreign_network
.retain(|k, _| !peer_ids.contains(&k.peer_id));
shrink_dashmap(&self.raw_peer_infos, None);
shrink_dashmap(&self.foreign_network, None);
@@ -827,31 +878,20 @@ impl SyncedRouteInfo {
&self,
peer_infos: &[RoutePeerInfo],
local_group_declarations: &[GroupIdentity],
trust_admin_groups_without_proof: bool,
) {
let local_group_declarations = local_group_declarations
.iter()
.map(|g| (g.group_name.as_str(), g.group_secret.as_str()))
.collect::<std::collections::HashMap<&str, &str>>();
let verify_groups = |old_trusted_groups: Option<&HashMap<String, Vec<u8>>>,
info: &RoutePeerInfo|
-> HashMap<String, Vec<u8>> {
let verify_groups = |info: &RoutePeerInfo| -> HashMap<String, Vec<u8>> {
let mut trusted_groups_for_peer: HashMap<String, Vec<u8>> = HashMap::new();
for group_proof in &info.groups {
let name = &group_proof.group_name;
let proof_bytes = group_proof.group_proof.clone();
// If we already trusted this group and the proof hasn't changed, reuse it.
if old_trusted_groups
.and_then(|g| g.get(name))
.map(|old| old == &proof_bytes)
.unwrap_or(false)
{
trusted_groups_for_peer.insert(name.clone(), proof_bytes);
continue;
}
if let Some(&local_secret) =
local_group_declarations.get(group_proof.group_name.as_str())
{
@@ -867,34 +907,39 @@ impl SyncedRouteInfo {
}
}
if trust_admin_groups_without_proof && self.is_admin_peer(info) {
for group_proof in &info.groups {
trusted_groups_for_peer
.entry(group_proof.group_name.clone())
.or_default();
}
}
trusted_groups_for_peer
};
for info in peer_infos {
match self.group_trust_map.entry(info.peer_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let old_trusted_groups = entry.get().clone();
let trusted_groups_for_peer = verify_groups(Some(&old_trusted_groups), info);
let trusted_groups_for_peer = verify_groups(info);
if trusted_groups_for_peer.is_empty() {
entry.remove();
self.group_trust_map_cache.remove(&info.peer_id);
} else {
self.group_trust_map_cache.insert(
info.peer_id,
Arc::new(trusted_groups_for_peer.keys().cloned().collect()),
);
let group_names = trusted_groups_for_peer.keys().cloned().collect();
self.group_trust_map_cache
.insert(info.peer_id, Arc::new(group_names));
*entry.get_mut() = trusted_groups_for_peer;
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let trusted_groups_for_peer = verify_groups(None, info);
let trusted_groups_for_peer = verify_groups(info);
if !trusted_groups_for_peer.is_empty() {
self.group_trust_map_cache.insert(
info.peer_id,
Arc::new(trusted_groups_for_peer.keys().cloned().collect()),
);
let group_names = trusted_groups_for_peer.keys().cloned().collect();
self.group_trust_map_cache
.insert(info.peer_id, Arc::new(group_names));
entry.insert(trusted_groups_for_peer);
}
}
@@ -904,16 +949,12 @@ impl SyncedRouteInfo {
fn update_my_group_trusts(&self, my_peer_id: PeerId, groups: &[PeerGroupInfo]) {
let mut my_group_map = HashMap::new();
let mut my_group_names = Vec::new();
for group in groups.iter() {
my_group_map.insert(group.group_name.clone(), group.group_proof.clone());
my_group_names.push(group.group_name.clone());
}
self.group_trust_map.insert(my_peer_id, my_group_map);
self.group_trust_map_cache
.insert(my_peer_id, Arc::new(my_group_names));
self.set_peer_groups(my_peer_id, my_group_map);
}
/// Collect trusted credential pubkeys from admin nodes (network_secret holders)
@@ -1004,18 +1045,13 @@ impl SyncedRouteInfo {
continue;
}
if let Some(tc) = all_trusted.get(&info.noise_static_pubkey) {
// This peer is a credential peer, assign groups from credential declaration
if !tc.groups.is_empty() {
let mut group_map = HashMap::new();
let mut group_names = Vec::new();
for g in &tc.groups {
group_map.insert(g.clone(), Vec::new()); // no proof needed, admin-declared
group_names.push(g.clone());
}
self.group_trust_map.insert(info.peer_id, group_map);
self.group_trust_map_cache
.insert(info.peer_id, Arc::new(group_names));
// Start from proof-backed groups so credential-declared groups can coexist
// without leaving stale credential-only entries behind after refreshes.
let mut group_map = self.get_proof_groups(info.peer_id);
for g in &tc.groups {
group_map.entry(g.clone()).or_default();
}
self.set_peer_groups(info.peer_id, group_map);
}
}
@@ -1039,19 +1075,10 @@ impl SyncedRouteInfo {
// Remove untrusted peers from peer_infos so they won't appear in route graph
if !untrusted_peers.is_empty() {
drop(peer_infos); // release read lock before writing
let mut peer_infos_write = self.peer_infos.write();
for peer_id in &untrusted_peers {
tracing::warn!(?peer_id, "removing untrusted peer from route info");
peer_infos_write.remove(peer_id);
self.raw_peer_infos.remove(peer_id);
}
drop(peer_infos_write);
// Also remove from conn_map
let mut conn_map = self.conn_map.write();
for peer_id in &untrusted_peers {
conn_map.remove(peer_id);
}
self.version.inc();
self.remove_peers(untrusted_peers.iter().copied());
}
(untrusted_peers, global_trusted_keys)
@@ -2444,14 +2471,53 @@ impl PeerRouteServiceImpl {
my_peer_info_updated || my_conn_info_updated || my_foreign_network_updated
}
async fn refresh_credential_trusts_and_disconnect(&self) -> bool {
async fn refresh_acl_groups(&self) -> bool {
let my_peer_info_updated = self.update_my_peer_info();
let trust_admin_groups_without_proof = self
.global_ctx
.get_network_identity()
.network_secret
.is_none();
let peer_infos: Vec<_> = self
.synced_route_info
.peer_infos
.read()
.iter()
.map(|(_, info)| info.clone())
.collect();
self.synced_route_info.verify_and_update_group_trusts(
&peer_infos,
&self.global_ctx.get_acl_group_declarations(),
trust_admin_groups_without_proof,
);
let untrusted = self.refresh_credential_trusts();
self.disconnect_untrusted_peers(&untrusted).await;
if my_peer_info_updated || !untrusted.is_empty() {
self.update_route_table_and_cached_local_conn_bitmap();
self.update_foreign_network_owner_map();
}
if my_peer_info_updated {
self.update_peer_info_last_update();
}
my_peer_info_updated || !untrusted.is_empty()
}
fn refresh_credential_trusts(&self) -> Vec<PeerId> {
let network_identity = self.global_ctx.get_network_identity();
let network_secret = network_identity.network_secret.as_deref();
let (untrusted, global_trusted_keys) = self
.synced_route_info
.verify_and_update_credential_trusts(network_secret);
.verify_and_update_credential_trusts(network_identity.network_secret.as_deref());
self.global_ctx
.update_trusted_keys(global_trusted_keys, &network_identity.network_name);
untrusted
}
async fn refresh_credential_trusts_and_disconnect(&self) -> bool {
let untrusted = self.refresh_credential_trusts();
self.disconnect_untrusted_peers(&untrusted).await;
!untrusted.is_empty()
}
@@ -2529,9 +2595,8 @@ impl PeerRouteServiceImpl {
}
}
for p in to_remove.iter() {
self.synced_route_info.remove_peer(*p);
}
self.synced_route_info
.remove_peers(to_remove.iter().copied());
// clear expired foreign network info
let mut to_remove = Vec::new();
@@ -3197,6 +3262,11 @@ impl RouteSessionManager {
(peer_infos, raw_peer_infos.as_ref().unwrap())
};
if !pi.is_empty() {
let trust_admin_groups_without_proof = service_impl
.global_ctx
.get_network_identity()
.network_secret
.is_none();
service_impl.synced_route_info.update_peer_infos(
my_peer_id,
service_impl.my_peer_route_id,
@@ -3209,6 +3279,7 @@ impl RouteSessionManager {
.verify_and_update_group_trusts(
pi,
&service_impl.global_ctx.get_acl_group_declarations(),
trust_admin_groups_without_proof,
);
session.update_dst_saved_peer_info_version(pi, from_peer_id);
need_update_route_table = true;
@@ -3228,15 +3299,7 @@ impl RouteSessionManager {
if need_update_route_table {
// Run credential verification and update route table
let network_identity = service_impl.global_ctx.get_network_identity();
let (untrusted, global_trusted_keys) = service_impl
.synced_route_info
.verify_and_update_credential_trusts(network_identity.network_secret.as_deref());
untrusted_peers = untrusted;
// Sync trusted keys to GlobalCtx for handshake verification
service_impl
.global_ctx
.update_trusted_keys(global_trusted_keys, &network_identity.network_name);
untrusted_peers = service_impl.refresh_credential_trusts();
service_impl.update_route_table_and_cached_local_conn_bitmap();
}
@@ -3644,6 +3707,12 @@ impl Route for PeerRoute {
fn get_peer_groups(&self, peer_id: PeerId) -> Arc<Vec<String>> {
self.service_impl.get_peer_groups(peer_id)
}
async fn refresh_acl_groups(&self) {
if self.service_impl.refresh_acl_groups().await {
self.session_mgr.sync_now("refresh_acl_groups");
}
}
}
impl PeerPacketFilter for Arc<PeerRoute> {}
@@ -3665,11 +3734,14 @@ mod tests {
time::{Duration, SystemTime},
};
use super::{PeerRoute, REMOVE_DEAD_PEER_INFO_AFTER};
use super::{PeerRoute, REMOVE_DEAD_PEER_INFO_AFTER, RouteConnInfo};
use crate::{
common::{
PeerId,
global_ctx::{GlobalCtxEvent, TrustedKeySource, tests::get_mock_global_ctx},
global_ctx::{
GlobalCtxEvent, TrustedKeySource,
tests::{get_mock_global_ctx, get_mock_global_ctx_with_network},
},
},
connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{
@@ -3680,8 +3752,10 @@ mod tests {
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
},
proto::{
acl::{Acl, AclV1, GroupIdentity, GroupInfo},
common::{NatType, PeerFeatureFlag},
peer_rpc::{
ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, PeerGroupInfo,
PeerIdentityType, RoutePeerInfo, RoutePeerInfos, SyncRouteInfoRequest,
TrustedCredentialPubkey, TrustedCredentialPubkeyProof,
},
@@ -4040,6 +4114,219 @@ mod tests {
);
}
#[tokio::test]
async fn credential_groups_merge_with_proof_groups_and_recompute_cleanly() {
let service_impl = PeerRouteServiceImpl::new(1, get_mock_global_ctx());
let network_secret = "sec1";
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let credential_peer_id = 31;
let credential_pubkey = vec![7; 32];
let mut credential_info = RoutePeerInfo::new();
credential_info.peer_id = credential_peer_id;
credential_info.version = 1;
credential_info.noise_static_pubkey = credential_pubkey.clone();
credential_info.groups = vec![PeerGroupInfo::generate_with_proof(
"proof-group".to_string(),
"proof-secret".to_string(),
credential_peer_id,
)];
let mut admin_info = RoutePeerInfo::new();
admin_info.peer_id = 32;
admin_info.version = 1;
admin_info.feature_flag = Some(PeerFeatureFlag {
is_credential_peer: false,
..Default::default()
});
admin_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof::new_signed(
TrustedCredentialPubkey {
pubkey: credential_pubkey.clone(),
groups: vec!["cred-group".to_string()],
expiry_unix: now + 600,
..Default::default()
},
network_secret,
)];
{
let mut guard = service_impl.synced_route_info.peer_infos.write();
guard.insert(admin_info.peer_id, admin_info.clone());
guard.insert(credential_peer_id, credential_info.clone());
}
service_impl
.synced_route_info
.verify_and_update_group_trusts(
&[credential_info],
&[GroupIdentity {
group_name: "proof-group".to_string(),
group_secret: "proof-secret".to_string(),
}],
false,
);
service_impl
.synced_route_info
.verify_and_update_credential_trusts(Some(network_secret));
let groups = service_impl.get_peer_groups(credential_peer_id);
assert!(groups.contains(&"proof-group".to_string()));
assert!(groups.contains(&"cred-group".to_string()));
let guard = service_impl.synced_route_info.peer_infos.write();
let admin_info = guard.get(&32).unwrap().clone();
drop(guard);
let mut updated_admin = admin_info;
updated_admin.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof::new_signed(
TrustedCredentialPubkey {
pubkey: credential_pubkey.clone(),
groups: vec!["replacement-group".to_string()],
expiry_unix: now + 600,
..Default::default()
},
network_secret,
)];
service_impl
.synced_route_info
.peer_infos
.write()
.insert(updated_admin.peer_id, updated_admin);
service_impl
.synced_route_info
.verify_and_update_credential_trusts(Some(network_secret));
let groups = service_impl.get_peer_groups(credential_peer_id);
assert!(groups.contains(&"proof-group".to_string()));
assert!(groups.contains(&"replacement-group".to_string()));
assert!(!groups.contains(&"cred-group".to_string()));
}
#[tokio::test]
async fn remove_peers_batches_cleanup_and_version_increment() {
let service_impl = PeerRouteServiceImpl::new(1, get_mock_global_ctx());
let removed_peer_ids = [41, 42];
let retained_peer_id = 43;
{
let mut peer_infos = service_impl.synced_route_info.peer_infos.write();
let mut conn_map = service_impl.synced_route_info.conn_map.write();
for peer_id in removed_peer_ids {
let mut info = RoutePeerInfo::new();
info.peer_id = peer_id;
info.version = 1;
peer_infos.insert(peer_id, info);
conn_map.insert(peer_id, RouteConnInfo::default());
}
let mut retained_info = RoutePeerInfo::new();
retained_info.peer_id = retained_peer_id;
retained_info.version = 1;
peer_infos.insert(retained_peer_id, retained_info);
conn_map.insert(retained_peer_id, RouteConnInfo::default());
}
for peer_id in removed_peer_ids {
service_impl.synced_route_info.raw_peer_infos.insert(
peer_id,
DynamicMessage::new(RoutePeerInfo::default().descriptor()),
);
service_impl.synced_route_info.group_trust_map.insert(
peer_id,
HashMap::from([("guest".to_string(), vec![1, 2, 3])]),
);
service_impl
.synced_route_info
.group_trust_map_cache
.insert(peer_id, Arc::new(vec!["guest".to_string()]));
service_impl.synced_route_info.foreign_network.insert(
ForeignNetworkRouteInfoKey {
peer_id,
..Default::default()
},
ForeignNetworkRouteInfoEntry::default(),
);
}
service_impl.synced_route_info.foreign_network.insert(
ForeignNetworkRouteInfoKey {
peer_id: retained_peer_id,
..Default::default()
},
ForeignNetworkRouteInfoEntry::default(),
);
let initial_version = service_impl.synced_route_info.version.get();
service_impl
.synced_route_info
.remove_peers(removed_peer_ids);
assert_eq!(
service_impl.synced_route_info.version.get(),
initial_version + 1
);
for peer_id in removed_peer_ids {
assert!(
!service_impl
.synced_route_info
.peer_infos
.read()
.contains_key(&peer_id)
);
assert!(
!service_impl
.synced_route_info
.conn_map
.read()
.contains_key(&peer_id)
);
assert!(
!service_impl
.synced_route_info
.raw_peer_infos
.contains_key(&peer_id)
);
assert!(
!service_impl
.synced_route_info
.group_trust_map
.contains_key(&peer_id)
);
assert!(
!service_impl
.synced_route_info
.group_trust_map_cache
.contains_key(&peer_id)
);
assert!(
!service_impl.synced_route_info.foreign_network.contains_key(
&ForeignNetworkRouteInfoKey {
peer_id,
..Default::default()
}
)
);
}
assert!(
service_impl
.synced_route_info
.peer_infos
.read()
.contains_key(&retained_peer_id)
);
assert!(service_impl.synced_route_info.foreign_network.contains_key(
&ForeignNetworkRouteInfoKey {
peer_id: retained_peer_id,
..Default::default()
}
));
}
#[tokio::test]
async fn sync_route_info_marks_credential_sender_and_filters_entries() {
let peer_mgr = create_mock_pmgr().await;
@@ -4208,6 +4495,7 @@ mod tests {
admin_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof {
credential: Some(TrustedCredentialPubkey {
pubkey: credential_pubkey.clone(),
groups: vec!["guest".to_string()],
expiry_unix: i64::MAX,
..Default::default()
}),
@@ -4237,6 +4525,11 @@ mod tests {
.trusted_credential_pubkeys
.contains_key(&credential_pubkey)
);
assert!(
service_impl
.get_peer_groups(credential_peer_id)
.contains(&"guest".to_string())
);
service_impl.clear_expired_peer().await;
@@ -4260,6 +4553,299 @@ mod tests {
.read()
.contains_key(&credential_peer_id)
);
assert!(
!service_impl
.synced_route_info
.group_trust_map_cache
.contains_key(&credential_peer_id)
);
}
#[tokio::test]
async fn refresh_acl_groups_returns_true_when_untrusted_peers_are_disconnected() {
let service_impl = PeerRouteServiceImpl::new(1, get_mock_global_ctx());
let credential_peer_id: PeerId = 10061;
let credential_pubkey = vec![8u8; 32];
let closed_peers = Arc::new(Mutex::new(Vec::new()));
*service_impl.interface.lock().await = Some(Box::new(TrackingInterface {
my_peer_id: service_impl.my_peer_id,
closed_peers: closed_peers.clone(),
}));
let mut credential_info = RoutePeerInfo::new();
credential_info.peer_id = credential_peer_id;
credential_info.version = 1;
credential_info.noise_static_pubkey = credential_pubkey.clone();
credential_info.feature_flag = Some(PeerFeatureFlag {
is_credential_peer: true,
..Default::default()
});
let self_info = RoutePeerInfo::new_updated_self(
service_impl.my_peer_id,
service_impl.my_peer_route_id,
&service_impl.global_ctx,
);
let mut self_info = self_info;
self_info.version = 1;
self_info.last_update = Some(SystemTime::now().into());
{
let mut guard = service_impl.synced_route_info.peer_infos.write();
guard.insert(service_impl.my_peer_id, self_info);
guard.insert(credential_peer_id, credential_info);
}
service_impl
.synced_route_info
.trusted_credential_pubkeys
.insert(
credential_pubkey.clone(),
TrustedCredentialPubkey {
pubkey: credential_pubkey.clone(),
expiry_unix: i64::MAX,
..Default::default()
},
);
assert!(service_impl.refresh_acl_groups().await);
assert!(closed_peers.lock().contains(&credential_peer_id));
assert!(
!service_impl
.synced_route_info
.peer_infos
.read()
.contains_key(&credential_peer_id)
);
assert!(
!service_impl
.synced_route_info
.trusted_credential_pubkeys
.contains_key(&credential_pubkey)
);
}
#[tokio::test]
async fn refresh_acl_groups_updates_local_membership_immediately() {
let peer_mgr = create_mock_pmgr().await;
let route = create_mock_route(peer_mgr.clone()).await;
let my_peer_id = peer_mgr.my_peer_id();
assert!(route.service_impl.get_peer_groups(my_peer_id).is_empty());
peer_mgr.get_global_ctx().config.set_acl(Some(Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: vec![GroupIdentity {
group_name: "admin".to_string(),
group_secret: "admin-secret".to_string(),
}],
members: vec!["admin".to_string()],
}),
..Default::default()
}),
}));
route.refresh_acl_groups().await;
let groups = route.service_impl.get_peer_groups(my_peer_id);
assert!(groups.contains(&"admin".to_string()));
assert_eq!(groups.len(), 1);
}
#[tokio::test]
async fn refresh_acl_groups_revalidates_cached_remote_groups() {
let peer_mgr = create_mock_pmgr().await;
let route = create_mock_route(peer_mgr.clone()).await;
let remote_peer_id = 200;
let remote_group = PeerGroupInfo::generate_with_proof(
"ops".to_string(),
"secret-v1".to_string(),
remote_peer_id,
);
peer_mgr.get_global_ctx().config.set_acl(Some(Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: vec![GroupIdentity {
group_name: "ops".to_string(),
group_secret: "secret-v1".to_string(),
}],
members: vec![],
}),
..Default::default()
}),
}));
let mut remote_info = RoutePeerInfo::new();
remote_info.peer_id = remote_peer_id;
remote_info.version = 1;
remote_info.groups = vec![remote_group];
route
.service_impl
.synced_route_info
.peer_infos
.write()
.insert(remote_peer_id, remote_info.clone());
route
.service_impl
.synced_route_info
.verify_and_update_group_trusts(
&[remote_info],
&peer_mgr.get_global_ctx().get_acl_group_declarations(),
false,
);
assert!(
route
.service_impl
.get_peer_groups(remote_peer_id)
.contains(&"ops".to_string())
);
peer_mgr.get_global_ctx().config.set_acl(Some(Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: vec![GroupIdentity {
group_name: "ops".to_string(),
group_secret: "secret-v2".to_string(),
}],
members: vec![],
}),
..Default::default()
}),
}));
route.refresh_acl_groups().await;
assert!(
route
.service_impl
.get_peer_groups(remote_peer_id)
.is_empty()
);
}
#[tokio::test]
async fn credential_verifier_trusts_admin_self_groups_from_multiple_admins() {
let service_impl = PeerRouteServiceImpl::new(
1,
get_mock_global_ctx_with_network(Some(
crate::common::config::NetworkIdentity::new_credential("net1".to_string()),
)),
);
let mut admin_a = RoutePeerInfo::new();
admin_a.peer_id = 501;
admin_a.version = 1;
admin_a.groups = vec![
PeerGroupInfo {
group_name: "ops".to_string(),
group_proof: vec![1; 32],
},
PeerGroupInfo {
group_name: "core-admin".to_string(),
group_proof: vec![2; 32],
},
];
let mut admin_b = RoutePeerInfo::new();
admin_b.peer_id = 502;
admin_b.version = 1;
admin_b.groups = vec![PeerGroupInfo {
group_name: "audit".to_string(),
group_proof: vec![3; 32],
}];
service_impl
.synced_route_info
.verify_and_update_group_trusts(&[admin_a.clone(), admin_b.clone()], &[], true);
let admin_a_groups = service_impl.get_peer_groups(admin_a.peer_id);
assert!(admin_a_groups.contains(&"ops".to_string()));
assert!(admin_a_groups.contains(&"core-admin".to_string()));
let admin_b_groups = service_impl.get_peer_groups(admin_b.peer_id);
assert!(admin_b_groups.contains(&"audit".to_string()));
}
#[tokio::test]
async fn credential_verifier_still_checks_credential_self_declared_groups() {
let service_impl = PeerRouteServiceImpl::new(
1,
get_mock_global_ctx_with_network(Some(
crate::common::config::NetworkIdentity::new_credential("net1".to_string()),
)),
);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let credential_peer_id = 601;
let credential_pubkey = vec![9; 32];
let mut admin_info = RoutePeerInfo::new();
admin_info.peer_id = 600;
admin_info.version = 1;
admin_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof {
credential: Some(TrustedCredentialPubkey {
pubkey: credential_pubkey.clone(),
groups: vec!["cred-acl".to_string()],
expiry_unix: now + 600,
..Default::default()
}),
credential_hmac: vec![7; 32],
}];
let mut credential_info = RoutePeerInfo::new();
credential_info.peer_id = credential_peer_id;
credential_info.version = 1;
credential_info.noise_static_pubkey = credential_pubkey.clone();
credential_info.feature_flag = Some(PeerFeatureFlag {
is_credential_peer: true,
..Default::default()
});
credential_info.groups = vec![
PeerGroupInfo::generate_with_proof(
"proof-group".to_string(),
"proof-secret".to_string(),
credential_peer_id,
),
PeerGroupInfo::generate_with_proof(
"invalid-group".to_string(),
"wrong-secret".to_string(),
credential_peer_id,
),
];
{
let mut guard = service_impl.synced_route_info.peer_infos.write();
guard.insert(admin_info.peer_id, admin_info.clone());
guard.insert(credential_info.peer_id, credential_info.clone());
}
service_impl
.synced_route_info
.verify_and_update_group_trusts(
&[admin_info, credential_info],
&[
GroupIdentity {
group_name: "proof-group".to_string(),
group_secret: "proof-secret".to_string(),
},
GroupIdentity {
group_name: "invalid-group".to_string(),
group_secret: "actual-secret".to_string(),
},
],
true,
);
service_impl
.synced_route_info
.verify_and_update_credential_trusts(None);
let groups = service_impl.get_peer_groups(credential_peer_id);
assert!(groups.contains(&"proof-group".to_string()));
assert!(groups.contains(&"cred-acl".to_string()));
assert!(!groups.contains(&"invalid-group".to_string()));
}
#[rstest::rstest]
+46 -788
View File
@@ -1,26 +1,14 @@
use std::{
sync::{
Arc, Mutex, RwLock,
atomic::{AtomicBool, AtomicU32, Ordering},
},
time::{SystemTime, UNIX_EPOCH},
use std::sync::{
Arc, RwLock,
atomic::{AtomicBool, Ordering},
};
use atomic_shim::AtomicU64;
use crate::{
common::PeerId,
peers::encrypt::{Encryptor, create_encryptor},
tunnel::packet_def::{StandardAeadTail, ZCPacket},
};
use anyhow::anyhow;
use dashmap::DashMap;
use hmac::{Hmac, Mac as _};
use rand::RngCore as _;
use sha2::Sha256;
use zerocopy::FromBytes;
type HmacSha256 = Hmac<Sha256>;
use super::secure_datagram::{SecureDatagramDirection, SecureDatagramSession};
use crate::{common::PeerId, tunnel::packet_def::ZCPacket};
pub struct UpsertResponderSessionReturn {
pub session: Arc<PeerSession>,
pub action: PeerSessionAction,
@@ -87,9 +75,6 @@ impl PeerSessionStore {
self.sessions.insert(key, session);
}
/// Remove sessions that are no longer referenced by any PeerConn or RelayPeerMap.
/// A session with strong_count == 1 means only the store holds it — no active
/// connection is using it, so it can be safely cleaned up.
pub fn evict_unused_sessions(&self) {
self.sessions
.retain(|_key, session| Arc::strong_count(session) > 1);
@@ -188,7 +173,6 @@ impl PeerSessionStore {
}
PeerSessionAction::Sync | PeerSessionAction::Create => {
let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?;
// If the existing session is invalidated, remove it so we create a fresh one
if let Some(existing) = self.sessions.get(key)
&& !existing.is_valid()
{
@@ -224,253 +208,25 @@ impl PeerSessionStore {
}
}
#[derive(Clone, Default)]
struct EpochKeySlot {
epoch: u32,
generation: u32,
valid: bool,
send_cipher: Option<Arc<dyn Encryptor>>,
recv_cipher: Option<Arc<dyn Encryptor>>,
}
impl std::fmt::Debug for EpochKeySlot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EpochKeySlot")
.field("epoch", &self.epoch)
.field("generation", &self.generation)
.field("valid", &self.valid)
.finish()
}
}
impl EpochKeySlot {
fn get_encryptor(&self, is_send: bool) -> Arc<dyn Encryptor> {
if is_send {
self.send_cipher.as_ref().unwrap().clone()
} else {
self.recv_cipher.as_ref().unwrap().clone()
}
}
}
#[derive(Debug, Clone, Copy, Default)]
struct ReplayWindow256 {
max_seq: u64,
bitmap: [u8; 32],
valid: bool,
}
impl ReplayWindow256 {
fn clear(&mut self) {
self.max_seq = 0;
self.bitmap.fill(0);
self.valid = false;
}
fn test_bit(&self, idx: usize) -> bool {
let byte = idx / 8;
let bit = idx % 8;
(self.bitmap[byte] >> bit) & 1 == 1
}
fn set_bit(&mut self, idx: usize) {
let byte = idx / 8;
let bit = idx % 8;
self.bitmap[byte] |= 1u8 << bit;
}
fn shift_right(&mut self, shift: usize) {
if shift == 0 {
return;
}
let total_bits = 256usize;
if shift >= total_bits {
self.bitmap.fill(0);
return;
}
let byte_shift = shift / 8;
let bit_shift = shift % 8;
if byte_shift > 0 {
for i in (0..self.bitmap.len()).rev() {
self.bitmap[i] = if i >= byte_shift {
self.bitmap[i - byte_shift]
} else {
0
};
}
}
if bit_shift > 0 {
let mut carry = 0u8;
for b in self.bitmap.iter_mut() {
let new_carry = *b >> (8 - bit_shift);
*b = (*b << bit_shift) | carry;
carry = new_carry;
}
}
}
fn accept(&mut self, seq: u64) -> bool {
if !self.valid {
self.valid = true;
self.max_seq = seq;
self.set_bit(0);
return true;
}
if seq > self.max_seq {
let shift = (seq - self.max_seq) as usize;
self.shift_right(shift);
self.max_seq = seq;
self.set_bit(0);
return true;
}
let delta = (self.max_seq - seq) as usize;
if delta >= 256 {
return false;
}
if self.test_bit(delta) {
return false;
}
self.set_bit(delta);
true
}
}
#[derive(Debug, Clone, Copy, Default)]
struct EpochRxSlot {
epoch: u32,
window: ReplayWindow256,
last_rx_ms: u64,
valid: bool,
}
impl EpochRxSlot {
fn clear(&mut self) {
self.epoch = 0;
self.window.clear();
self.last_rx_ms = 0;
self.valid = false;
}
}
#[derive(Debug, Clone, Copy, Default)]
struct SyncRxGrace {
slots: [[EpochRxSlot; 2]; 2],
expires_at_ms: u64,
valid: bool,
}
impl SyncRxGrace {
fn clear(&mut self) {
self.slots = [[EpochRxSlot::default(), EpochRxSlot::default()]; 2];
self.expires_at_ms = 0;
self.valid = false;
}
fn refresh(&mut self, slots: [[EpochRxSlot; 2]; 2], expires_at_ms: u64) {
self.slots = slots;
self.expires_at_ms = expires_at_ms;
self.valid = true;
}
fn maybe_expire(&mut self, now_ms: u64) {
if self.valid && now_ms >= self.expires_at_ms {
self.clear();
}
}
}
// Per-peer secure session state.
// NOTE(review): recent refactor delegates crypto to `datagram`
// (SecureDatagramSession); several of the per-epoch fields below look like
// legacy state from before that delegation — confirm which are still read.
pub struct PeerSession {
    // Remote peer this session protects traffic with.
    peer_id: PeerId,
    // Shared secret that per-epoch traffic keys are derived from.
    root_key: RwLock<[u8; 32]>,
    session_generation: AtomicU32,
    // Pinned remote static public key, if already learned — TODO confirm
    // against the handshake code that sets it.
    peer_static_pubkey: RwLock<Option<[u8; 32]>>,
    // Epoch currently used for outgoing packets.
    send_epoch: AtomicU32,
    // Per-direction send sequence counters.
    send_seq: [AtomicU64; 2],
    send_epoch_started_ms: AtomicU64,
    send_packets_since_epoch: AtomicU64,
    // Receive replay state, indexed [direction][current/previous epoch].
    rx_slots: Mutex<[[EpochRxSlot; 2]; 2]>,
    // Cached derived traffic keys, same [direction][slot] layout.
    key_cache: Mutex<[[EpochKeySlot; 2]; 2]>,
    // Pre-sync RX windows kept alive briefly after a root-key sync.
    sync_rx_grace: Mutex<SyncRxGrace>,
    sync_rx_grace_expires_at_ms: AtomicU64,
    send_cipher_algorithm: String,
    recv_cipher_algorithm: String,
    // Generic datagram-session backend that encrypt/decrypt delegate to.
    datagram: SecureDatagramSession,
    /// Set to true when the session is detected as corrupted (persistent decrypt failures).
    /// Holders of Arc<PeerSession> can check this to know the session should be discarded.
    invalidated: AtomicBool,
    /// Consecutive decrypt failure counter. Auto-invalidates when threshold is reached.
    decrypt_fail_count: AtomicU32,
}
// Manual Debug impl: prints the crypto/replay bookkeeping fields but
// leaves out `invalidated` and `decrypt_fail_count` — presumably
// intentional since those are transient runtime counters; confirm.
impl std::fmt::Debug for PeerSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerSession")
            .field("peer_id", &self.peer_id)
            .field("root_key", &self.root_key)
            .field("session_generation", &self.session_generation)
            .field("peer_static_pubkey", &self.peer_static_pubkey)
            .field("send_epoch", &self.send_epoch)
            .field("send_seq", &self.send_seq)
            .field("send_epoch_started_ms", &self.send_epoch_started_ms)
            .field("send_packets_since_epoch", &self.send_packets_since_epoch)
            .field("rx_slots", &self.rx_slots)
            .field("key_cache", &self.key_cache)
            .field("sync_rx_grace", &self.sync_rx_grace)
            .field(
                "sync_rx_grace_expires_at_ms",
                &self.sync_rx_grace_expires_at_ms,
            )
            .field("send_cipher_algorithm", &self.send_cipher_algorithm)
            .field("recv_cipher_algorithm", &self.recv_cipher_algorithm)
            .field("datagram", &self.datagram)
            .finish()
    }
}
impl PeerSession {
/// Idle-eviction timeout for receive slots, in milliseconds.
///
/// If no packets are received for this period (~30 seconds), the
/// corresponding RX slot is considered idle and may be cleared/reused.
/// This helps reclaim state for dead peers or paths while still tolerating
/// short network stalls. Environments with very bursty or high-latency
/// traffic may want to increase this value; low-latency or tightly
/// resource-constrained deployments may lower it.
const EVICT_IDLE_AFTER_MS: u64 = 30_000;
/// Keep the pre-sync receive windows alive briefly so in-flight packets
/// from the previous epochs are still accepted after a shared session is
/// synced in place by another connection.
const SYNC_RX_GRACE_AFTER_MS: u64 = 5_000;
/// Maximum number of packets to send in a single epoch before forcing
/// a key/epoch rotation.
///
/// This bounds the amount of traffic protected under a single set of
/// derived keys, which is a common best practice for long-lived secure
/// channels. The current value (~1 million packets) is a conservative
/// default chosen to balance security (more frequent rotation) and
/// performance (avoiding excessive rekeying). Deployments with very high
/// or very low packet rates may tune this threshold accordingly.
const ROTATE_AFTER_PACKETS: u64 = 1_000_000;
/// Maximum wall-clock lifetime of a send epoch, in milliseconds.
///
/// Even if the packet-based limit is not reached, epochs are rotated
/// after this duration (~10 minutes) to avoid long-lived keys and keep
/// replay windows bounded in time. This also limits the impact of a
/// compromised key. Installations that prioritize lower overhead over
/// more aggressive key rotation may increase this value; those with
/// stricter security requirements may decrease it.
const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000;
const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3;
const DECRYPT_FAIL_THRESHOLD: u32 = 10;
const SYNC_RX_GRACE_AFTER_MS: u64 = SecureDatagramSession::SYNC_RX_GRACE_AFTER_MS;
pub fn new(
peer_id: PeerId,
@@ -481,32 +237,17 @@ impl PeerSession {
recv_cipher_algorithm: String,
peer_static_pubkey: Option<[u8; 32]>,
) -> Self {
let rx_slots = [
[EpochRxSlot::default(), EpochRxSlot::default()],
[EpochRxSlot::default(), EpochRxSlot::default()],
];
let key_cache = [
[EpochKeySlot::default(), EpochKeySlot::default()],
[EpochKeySlot::default(), EpochKeySlot::default()],
];
let now_ms = now_ms();
Self {
peer_id,
root_key: RwLock::new(root_key),
session_generation: AtomicU32::new(session_generation),
peer_static_pubkey: RwLock::new(peer_static_pubkey),
send_epoch: AtomicU32::new(initial_epoch),
send_seq: [AtomicU64::new(0), AtomicU64::new(0)],
send_epoch_started_ms: AtomicU64::new(now_ms),
send_packets_since_epoch: AtomicU64::new(0),
rx_slots: Mutex::new(rx_slots),
key_cache: Mutex::new(key_cache),
sync_rx_grace: Mutex::new(SyncRxGrace::default()),
sync_rx_grace_expires_at_ms: AtomicU64::new(0),
send_cipher_algorithm,
recv_cipher_algorithm,
datagram: SecureDatagramSession::new(
root_key,
session_generation,
initial_epoch,
send_cipher_algorithm,
recv_cipher_algorithm,
),
invalidated: AtomicBool::new(false),
decrypt_fail_count: AtomicU32::new(0),
}
}
@@ -514,44 +255,29 @@ impl PeerSession {
self.peer_id
}
/// Mark this session as invalid. All holders of Arc<PeerSession> will see this.
pub fn invalidate(&self) {
self.invalidated.store(true, Ordering::Relaxed);
self.datagram.invalidate();
}
pub fn is_valid(&self) -> bool {
!self.invalidated.load(Ordering::Relaxed)
!self.invalidated.load(Ordering::Relaxed) && self.datagram.is_valid()
}
pub fn session_generation(&self) -> u32 {
self.session_generation.load(Ordering::Relaxed)
self.datagram.session_generation()
}
pub fn root_key(&self) -> [u8; 32] {
*self.root_key.read().unwrap()
self.datagram.root_key()
}
pub fn new_root_key() -> [u8; 32] {
let mut out = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut out);
out
SecureDatagramSession::new_root_key()
}
pub fn next_sync_epoch(&self) -> u32 {
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
let rx = self.rx_slots.lock().unwrap();
let mut max_epoch = send_epoch;
for dir in 0..2 {
let cur = rx[dir][0];
if cur.valid {
max_epoch = max_epoch.max(cur.epoch);
}
let prev = rx[dir][1];
if prev.valid {
max_epoch = max_epoch.max(prev.epoch);
}
}
max_epoch.wrapping_add(1)
self.datagram.next_sync_epoch()
}
pub fn check_encrypt_algo_same(
@@ -559,12 +285,8 @@ impl PeerSession {
send_algorithm: &str,
recv_algorithm: &str,
) -> Result<(), anyhow::Error> {
if self.send_cipher_algorithm != send_algorithm
|| self.recv_cipher_algorithm != recv_algorithm
{
return Err(anyhow!("encrypt algorithm not same"));
}
Ok(())
self.datagram
.check_encrypt_algo_same(send_algorithm, recv_algorithm)
}
pub fn check_or_set_peer_static_pubkey(
@@ -592,277 +314,25 @@ impl PeerSession {
initial_epoch: u32,
preserve_rx_grace: bool,
) {
let old_root_key = self.root_key();
let can_preserve_rx_grace = preserve_rx_grace && old_root_key == root_key;
{
let mut g = self.root_key.write().unwrap();
*g = root_key;
}
self.session_generation
.store(session_generation, Ordering::Relaxed);
self.send_epoch.store(initial_epoch, Ordering::Relaxed);
self.send_seq[0].store(0, Ordering::Relaxed);
self.send_seq[1].store(0, Ordering::Relaxed);
self.send_epoch_started_ms
.store(now_ms(), Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
{
let mut rx = self.rx_slots.lock().unwrap();
let mut sync_rx_grace = self.sync_rx_grace.lock().unwrap();
if can_preserve_rx_grace {
let expires_at_ms = now_ms().saturating_add(Self::SYNC_RX_GRACE_AFTER_MS);
sync_rx_grace.refresh(*rx, expires_at_ms);
self.sync_rx_grace_expires_at_ms
.store(expires_at_ms, Ordering::Relaxed);
} else {
sync_rx_grace.clear();
self.sync_rx_grace_expires_at_ms.store(0, Ordering::Relaxed);
}
for dir in 0..2 {
rx[dir][0].clear();
rx[dir][1].clear();
}
}
self.key_cache
.lock()
.unwrap()
.fill([EpochKeySlot::default(), EpochKeySlot::default()]);
self.datagram.sync_root_key(
root_key,
session_generation,
initial_epoch,
preserve_rx_grace,
);
}
pub fn dir_for_sender(sender_peer_id: PeerId, receiver_peer_id: PeerId) -> usize {
pub fn dir_for_sender(
sender_peer_id: PeerId,
receiver_peer_id: PeerId,
) -> SecureDatagramDirection {
if sender_peer_id < receiver_peer_id {
0
SecureDatagramDirection::AToB
} else {
1
SecureDatagramDirection::BToA
}
}
fn hkdf_traffic_key(&self, epoch: u32, dir: usize) -> [u8; 32] {
let root_key = self.root_key();
let salt = [0u8; 32];
let mut extract = HmacSha256::new_from_slice(&salt).unwrap();
extract.update(&root_key);
let prk = extract.finalize().into_bytes();
let mut info = Vec::with_capacity(9 + 4 + 1);
info.extend_from_slice(b"et-traffic");
info.extend_from_slice(&epoch.to_be_bytes());
info.push(dir as u8);
let mut expand = HmacSha256::new_from_slice(&prk).unwrap();
expand.update(&info);
expand.update(&[1u8]);
let okm = expand.finalize().into_bytes();
let mut key = [0u8; 32];
key.copy_from_slice(&okm[..32]);
key
}
fn get_or_create_encryptor(
&self,
epoch: u32,
dir: usize,
generation: u32,
is_send: bool,
) -> Arc<dyn Encryptor> {
let mut guard = self.key_cache.lock().unwrap();
for slot in guard[dir].iter_mut() {
if slot.valid && slot.epoch == epoch && slot.generation == generation {
return slot.get_encryptor(is_send);
}
}
let key = self.hkdf_traffic_key(epoch, dir);
let mut key_128 = [0u8; 16];
key_128.copy_from_slice(&key[..16]);
let slot = EpochKeySlot {
epoch,
generation,
valid: true,
send_cipher: Some(create_encryptor(&self.send_cipher_algorithm, key_128, key)),
recv_cipher: Some(create_encryptor(&self.recv_cipher_algorithm, key_128, key)),
};
let ret = slot.get_encryptor(is_send);
if !guard[dir][0].valid || guard[dir][0].epoch == epoch {
guard[dir][0] = slot;
} else {
guard[dir][1] = slot;
}
ret
}
fn maybe_rotate_epoch(&self, now_ms: u64) {
let packets = self
.send_packets_since_epoch
.fetch_add(1, Ordering::Relaxed)
+ 1;
let started = self.send_epoch_started_ms.load(Ordering::Relaxed);
if packets < Self::ROTATE_AFTER_PACKETS
&& now_ms.saturating_sub(started) < Self::ROTATE_AFTER_MS
{
return;
}
let cur = self.send_epoch.load(Ordering::Relaxed);
let next = cur.wrapping_add(1);
if self
.send_epoch
.compare_exchange(cur, next, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.send_epoch_started_ms.store(now_ms, Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
}
}
fn next_nonce(&self, dir: usize) -> (u32, u64, [u8; 12]) {
let now_ms = now_ms();
self.maybe_rotate_epoch(now_ms);
let epoch = self.send_epoch.load(Ordering::Relaxed);
let seq = self.send_seq[dir].fetch_add(1, Ordering::Relaxed);
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&epoch.to_be_bytes());
nonce[4..].copy_from_slice(&seq.to_be_bytes());
(epoch, seq, nonce)
}
fn parse_tail(payload: &[u8]) -> Option<[u8; 12]> {
let tail = StandardAeadTail::ref_from_suffix(payload)?;
Some(tail.nonce)
}
fn evict_old_rx_slots(rx: &mut [[EpochRxSlot; 2]; 2], now_ms: u64) {
for dir_slots in rx.iter_mut() {
for slot in dir_slots.iter_mut() {
if !slot.valid {
continue;
}
let last = slot.last_rx_ms;
if last != 0 && now_ms.saturating_sub(last) > Self::EVICT_IDLE_AFTER_MS {
slot.clear();
}
}
}
}
fn epoch_in_slots(slots: &[EpochRxSlot; 2], epoch: u32) -> bool {
slots[0].valid && slots[0].epoch == epoch || slots[1].valid && slots[1].epoch == epoch
}
fn sync_rx_grace_active(&self, now_ms: u64) -> bool {
let expires_at_ms = self.sync_rx_grace_expires_at_ms.load(Ordering::Relaxed);
if expires_at_ms == 0 {
return false;
}
if now_ms < expires_at_ms {
return true;
}
self.sync_rx_grace_expires_at_ms.store(0, Ordering::Relaxed);
false
}
fn check_replay(&self, epoch: u32, seq: u64, dir: usize, now_ms: u64) -> bool {
let mut rx = self.rx_slots.lock().unwrap();
Self::evict_old_rx_slots(&mut rx, now_ms);
let mut sync_rx_grace = if self.sync_rx_grace_active(now_ms) {
let mut sync_rx_grace = self.sync_rx_grace.lock().unwrap();
sync_rx_grace.maybe_expire(now_ms);
if sync_rx_grace.valid {
Self::evict_old_rx_slots(&mut sync_rx_grace.slots, now_ms);
Some(sync_rx_grace)
} else {
self.sync_rx_grace_expires_at_ms.store(0, Ordering::Relaxed);
None
}
} else {
None
};
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
{
let mut key_cache = self.key_cache.lock().unwrap();
for d in 0..2 {
for s in 0..2 {
if !key_cache[d][s].valid {
continue;
}
let e = key_cache[d][s].epoch;
let allowed = e == send_epoch
|| rx[d][0].valid && rx[d][0].epoch == e
|| rx[d][1].valid && rx[d][1].epoch == e
|| sync_rx_grace
.as_ref()
.is_some_and(|g| Self::epoch_in_slots(&g.slots[d], e));
if !allowed {
key_cache[d][s].valid = false;
}
}
}
}
if sync_rx_grace
.as_ref()
.is_some_and(|g| Self::epoch_in_slots(&g.slots[dir], epoch))
{
for slot in sync_rx_grace.as_mut().unwrap().slots[dir].iter_mut() {
if slot.valid && slot.epoch == epoch {
slot.last_rx_ms = now_ms;
return slot.window.accept(seq);
}
}
}
if !rx[dir][0].valid {
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
}
if rx[dir][0].valid && epoch == rx[dir][0].epoch {
rx[dir][0].last_rx_ms = now_ms;
return rx[dir][0].window.accept(seq);
}
if rx[dir][1].valid && epoch == rx[dir][1].epoch {
rx[dir][1].last_rx_ms = now_ms;
return rx[dir][1].window.accept(seq);
}
if rx[dir][0].valid && epoch > rx[dir][0].epoch {
let mut baseline_epoch = send_epoch;
if rx[dir][0].valid {
baseline_epoch = baseline_epoch.max(rx[dir][0].epoch);
}
if rx[dir][1].valid {
baseline_epoch = baseline_epoch.max(rx[dir][1].epoch);
}
let max_allowed_epoch =
baseline_epoch.saturating_add(Self::MAX_ACCEPTED_RX_EPOCH_AHEAD);
if epoch > max_allowed_epoch {
return false;
}
rx[dir][1] = rx[dir][0];
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
return rx[dir][0].window.accept(seq);
}
false
}
pub fn encrypt_payload(
&self,
sender_peer_id: PeerId,
@@ -872,19 +342,8 @@ impl PeerSession {
if !self.is_valid() {
return Err(anyhow!("session invalidated"));
}
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let (epoch, _seq, nonce_bytes) = self.next_nonce(dir);
let encryptor = self.get_or_create_encryptor(epoch, dir, self.session_generation(), true);
if let Err(e) = encryptor.encrypt_with_nonce(pkt, Some(nonce_bytes.as_slice())) {
tracing::warn!(
peer_id = ?self.peer_id,
?e,
"session encrypt failed, invalidating"
);
self.invalidate();
return Err(e.into());
}
Ok(())
self.datagram
.encrypt_payload(Self::dir_for_sender(sender_peer_id, receiver_peer_id), pkt)
}
pub fn decrypt_payload(
@@ -896,47 +355,13 @@ impl PeerSession {
if !self.is_valid() {
return Err(anyhow!("session invalidated"));
}
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let nonce_bytes =
Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?;
let epoch = u32::from_be_bytes(nonce_bytes[..4].try_into().unwrap());
let seq = u64::from_be_bytes(nonce_bytes[4..].try_into().unwrap());
let now_ms = now_ms();
if !self.check_replay(epoch, seq, dir, now_ms) {
return Err(anyhow!(
"replay rejected, sender_peer_id: {:?}, receiver_peer_id: {:?}",
sender_peer_id,
receiver_peer_id
));
}
let encryptor = self.get_or_create_encryptor(epoch, dir, self.session_generation(), false);
if let Err(e) = encryptor.decrypt(ciphertext_with_tail) {
let count = self.decrypt_fail_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= Self::DECRYPT_FAIL_THRESHOLD {
self.invalidate();
tracing::warn!(
peer_id = ?self.peer_id,
count,
"session auto-invalidated after consecutive decrypt failures"
);
}
return Err(e.into());
}
self.decrypt_fail_count.store(0, Ordering::Relaxed);
Ok(())
self.datagram.decrypt_payload(
Self::dir_for_sender(sender_peer_id, receiver_peer_id),
ciphertext_with_tail,
)
}
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `duration_since`
/// fails and we fall back to 0 (the same value `Duration::default()`
/// would yield), so callers never see an error.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
use super::*;
@@ -984,177 +409,10 @@ mod tests {
}
#[test]
fn replay_rejects_far_future_epoch_without_poisoning_window() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let generation = 1u32;
let initial_epoch = 0u32;
let s = PeerSession::new(
peer_id,
root_key,
generation,
initial_epoch,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
let now = now_ms();
assert!(s.check_replay(0, 1, 0, now));
assert!(s.check_replay(0, 2, 0, now));
assert!(!s.check_replay(1000, 1, 0, now));
assert!(s.check_replay(1, 1, 0, now + 1));
assert!(s.check_replay(1, 2, 0, now + 2));
}
#[test]
fn replay_window_shift_preserves_bits() {
let mut w = ReplayWindow256::default();
// Accept seqs 0..10
for i in 0..10u64 {
assert!(w.accept(i), "seq {i} should be accepted");
}
assert_eq!(w.max_seq, 9);
// All seqs 0..10 should be marked as seen (replay)
for i in 0..10u64 {
assert!(!w.accept(i), "seq {i} should be rejected as replay");
}
// Seq 10 should still be accepted
assert!(w.accept(10));
}
#[test]
fn replay_window_out_of_order_within_window() {
let mut w = ReplayWindow256::default();
// Accept even seqs 0,2,4,...,20
for i in (0..=20u64).step_by(2) {
assert!(w.accept(i), "seq {i} should be accepted");
}
// Now accept odd seqs 1,3,5,...,19 (out of order, within window)
for i in (1..=19u64).step_by(2) {
assert!(w.accept(i), "seq {i} should be accepted (out of order)");
}
// All seqs 0..=20 should now be marked as seen
for i in 0..=20u64 {
assert!(!w.accept(i), "seq {i} should be rejected as replay");
}
}
#[test]
fn sync_root_key_allows_any_epoch_from_remote() {
// After sync_root_key, the remote peer may still be sending at an
// old epoch. The receiver should accept those packets.
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let s = PeerSession::new(
peer_id,
root_key,
1,
0,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
// Simulate receiving some packets at epoch 0
let now = now_ms();
assert!(s.check_replay(0, 0, 0, now));
assert!(s.check_replay(0, 1, 0, now));
// Sync with initial_epoch=2 (simulating a Sync action)
s.sync_root_key(root_key, 2, 2, true);
// Remote peer is still sending at epoch 0 — should be accepted
// (rx_slots were cleared, so the first packet establishes the epoch)
assert!(
s.check_replay(0, 10, 0, now + 1),
"packets at old epoch should be accepted after sync"
);
}
#[test]
fn sync_root_key_keeps_previous_epochs_during_grace_window() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let s = PeerSession::new(
peer_id,
root_key,
1,
0,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
let now = now_ms();
assert!(s.check_replay(0, 0, 0, now));
assert!(s.check_replay(1, 0, 0, now + 1));
s.sync_root_key(root_key, 2, 2, true);
// The first packet after sync may already use the new epoch.
assert!(s.check_replay(2, 0, 0, now + 2));
// Older in-flight packets from pre-sync epochs should still be accepted
// during the grace period, regardless of arrival order.
assert!(s.check_replay(1, 1, 0, now + 3));
assert!(s.check_replay(0, 1, 0, now + 4));
}
#[test]
fn sync_root_key_expires_previous_epochs_after_grace_window() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let s = PeerSession::new(
peer_id,
root_key,
1,
0,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
let now = now_ms();
assert!(s.check_replay(0, 0, 0, now));
assert!(s.check_replay(1, 0, 0, now + 1));
s.sync_root_key(root_key, 2, 2, true);
assert!(s.check_replay(2, 0, 0, now + 2));
assert!(
!s.check_replay(0, 1, 0, now + PeerSession::SYNC_RX_GRACE_AFTER_MS + 3),
"old epochs should stop being accepted once the sync grace window expires"
);
}
#[test]
fn sync_root_key_does_not_preserve_previous_epochs_when_root_key_changes() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let s = PeerSession::new(
peer_id,
root_key,
1,
0,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
let now = now_ms();
assert!(s.check_replay(0, 0, 0, now));
assert!(s.check_replay(1, 0, 0, now + 1));
s.sync_root_key(PeerSession::new_root_key(), 2, 2, true);
assert!(s.check_replay(2, 0, 0, now + 2));
assert!(
!s.check_replay(1, 1, 0, now + 3),
"old epochs should not be preserved when sync replaces the root key"
fn sync_root_key_preserves_generic_grace_window_constant() {
assert_eq!(
PeerSession::SYNC_RX_GRACE_AFTER_MS,
SecureDatagramSession::SYNC_RX_GRACE_AFTER_MS
);
}
}
+2
View File
@@ -141,6 +141,8 @@ pub trait Route {
fn get_peer_groups(&self, peer_id: PeerId) -> Arc<Vec<String>>;
async fn refresh_acl_groups(&self) {}
async fn get_peer_groups_by_ip(&self, ip: &std::net::IpAddr) -> Arc<Vec<String>> {
match self.get_peer_id_by_ip(ip).await {
Some(peer_id) => self.get_peer_groups(peer_id),
File diff suppressed because it is too large Load Diff
+102 -1
View File
@@ -8,7 +8,7 @@ use crate::{
PeerId,
error::Error,
global_ctx::{
NetworkIdentity,
NetworkIdentity, TrustedKeySource,
tests::{get_mock_global_ctx, get_mock_global_ctx_with_network},
},
stats_manager::{LabelSet, LabelType, MetricName},
@@ -1315,6 +1315,107 @@ async fn credential_node_group_assignment() {
.await;
}
#[tokio::test]
async fn credential_node_connected_via_admin_b_trusts_admin_a_groups() {
use crate::proto::acl::{Acl, AclV1, GroupIdentity, GroupInfo};
let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await;
let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await;
let group_declares = vec![GroupIdentity {
group_name: "platform-admin".to_string(),
group_secret: "platform-admin-secret".to_string(),
}];
admin_a.get_global_ctx().config.set_acl(Some(Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: group_declares.clone(),
members: vec!["platform-admin".to_string()],
}),
..Default::default()
}),
}));
admin_b.get_global_ctx().config.set_acl(Some(Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: group_declares,
members: vec![],
}),
..Default::default()
}),
}));
connect_peer_manager(admin_a.clone(), admin_b.clone()).await;
wait_route_appear(admin_a.clone(), admin_b.clone())
.await
.unwrap();
let (_cred_id, cred_secret) = admin_a
.get_global_ctx()
.get_credential_manager()
.generate_credential(vec![], false, vec![], std::time::Duration::from_secs(3600));
admin_a
.get_global_ctx()
.issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged);
let privkey_bytes: [u8; 32] = base64::engine::general_purpose::STANDARD
.decode(&cred_secret)
.unwrap()
.try_into()
.unwrap();
let private = x25519_dalek::StaticSecret::from(privkey_bytes);
let credential_pubkey = x25519_dalek::PublicKey::from(&private).as_bytes().to_vec();
wait_for_condition(
|| {
let admin_b = admin_b.clone();
let credential_pubkey = credential_pubkey.clone();
async move {
admin_b.get_global_ctx().is_pubkey_trusted_with_source(
&credential_pubkey,
"net1",
TrustedKeySource::OspfCredential,
)
}
},
Duration::from_secs(10),
)
.await;
let cred_c = create_mock_peer_manager_credential("net1".to_string(), &private).await;
connect_peer_manager(cred_c.clone(), admin_b.clone()).await;
let admin_a_id = admin_a.my_peer_id();
wait_for_condition(
|| {
let cred_c = cred_c.clone();
async move {
cred_c
.list_routes()
.await
.iter()
.any(|r| r.peer_id == admin_a_id)
}
},
Duration::from_secs(10),
)
.await;
wait_for_condition(
|| {
let cred_c = cred_c.clone();
async move {
cred_c
.get_route()
.get_peer_groups(admin_a_id)
.contains(&"platform-admin".to_string())
}
},
Duration::from_secs(10),
)
.await;
}
/// Minimal test: two secure peers connect and discover each other's route.
#[tokio::test]
async fn two_secure_peers_route_appear() {
+3
View File
@@ -220,6 +220,9 @@ impl WsTunnelConnector {
let is_wss = is_wss(&addr)?;
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version).await?;
let stream = tcp_socket.connect(socket_addr).await?;
if let Err(error) = stream.set_nodelay(true) {
tracing::warn!(?error, "set_nodelay fail in ws connect");
}
let info = TunnelInfo {
tunnel_type: addr.scheme().to_owned(),
+18 -1
View File
@@ -126,7 +126,7 @@ impl WebClient {
}
};
if support_encryption {
if support_encryption && security::web_secure_tunnel_supported() {
log::info!("Server supports encryption, reconnecting with secure tunnel");
drop(session);
@@ -159,6 +159,23 @@ impl WebClient {
continue;
}
if support_encryption {
if secure_mode {
connected.store(false, Ordering::Release);
let wait = 1;
log::warn!(
"secure-mode enabled but local build lacks aes-gcm support for web secure tunnel, retrying in {} seconds...",
wait
);
tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
continue;
}
log::warn!(
"Server supports encryption but local build lacks aes-gcm support for web secure tunnel, falling back to legacy tunnel"
);
}
if secure_mode {
connected.store(false, Ordering::Release);
let wait = 1;
+176 -48
View File
@@ -1,11 +1,12 @@
use std::sync::{Arc, Mutex};
use std::time::Duration;
use bytes::BytesMut;
use futures::{SinkExt, StreamExt};
use snow::{Builder, TransportState, params::NoiseParams};
use snow::{Builder, params::NoiseParams};
use crate::{
common::config::EncryptionAlgorithm,
peers::secure_datagram::{SecureDatagramDirection, SecureDatagramSession},
proto::common::TunnelInfo,
tunnel::{
SplitTunnel, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
@@ -17,6 +18,11 @@ use crate::{
const NOISE_MAGIC: &[u8] = b"ET_WEB_NOISE_V1:";
const NOISE_PROLOGUE: &[u8] = b"easytier-webclient-noise-v1";
const NOISE_PATTERN: &str = "Noise_NN_25519_ChaChaPoly_SHA256";
const WEB_SECURE_CIPHER_ALGORITHM: &str = "aes-gcm";
const WEB_SESSION_GENERATION: u32 = 1;
const WEB_INITIAL_EPOCH: u32 = 0;
const WEB_SECURE_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(3);
const WEB_SECURE_ACCEPT_TIMEOUT: Duration = WEB_SECURE_HANDSHAKE_TIMEOUT;
struct RawSplitTunnel {
info: Option<TunnelInfo>,
@@ -50,24 +56,42 @@ impl Tunnel for RawSplitTunnel {
}
}
struct NoiseTunnelFilter {
transport: Arc<Mutex<TransportState>>,
/// Which side of the web secure-tunnel handshake this endpoint played.
/// The role determines which datagram direction is used for sending
/// versus receiving (see `send_dir`/`recv_dir`).
#[derive(Clone, Copy)]
enum SecureTunnelRole {
    // Client side: started the handshake.
    Initiator,
    // Server side: answered the handshake.
    Responder,
}
impl TunnelFilter for NoiseTunnelFilter {
impl SecureTunnelRole {
    /// Datagram direction used for packets this side sends.
    fn send_dir(self) -> SecureDatagramDirection {
        if matches!(self, Self::Initiator) {
            SecureDatagramDirection::AToB
        } else {
            SecureDatagramDirection::BToA
        }
    }

    /// Datagram direction expected on packets this side receives —
    /// always the mirror of `send_dir`.
    fn recv_dir(self) -> SecureDatagramDirection {
        match self {
            Self::Responder => SecureDatagramDirection::AToB,
            Self::Initiator => SecureDatagramDirection::BToA,
        }
    }
}
/// Tunnel filter that transparently encrypts outgoing and decrypts
/// incoming packets with a shared `SecureDatagramSession`, using `role`
/// to pick the send/receive directions.
struct SecureDatagramTunnelFilter {
    // Shared crypto session (root key, epochs, replay windows).
    session: Arc<SecureDatagramSession>,
    // Handshake role; fixes which direction is send vs. receive.
    role: SecureTunnelRole,
}
impl TunnelFilter for SecureDatagramTunnelFilter {
type FilterOutput = ();
fn before_send(&self, data: ZCPacket) -> Option<ZCPacket> {
let plain = data.tunnel_payload();
let mut encrypted = vec![0u8; plain.len() + 64];
let len = self
.transport
.lock()
.unwrap()
.write_message(plain, &mut encrypted)
.ok()?;
let mut packet = ZCPacket::new_with_payload(&encrypted[..len]);
let mut packet = ZCPacket::new_with_payload(data.tunnel_payload());
packet.fill_peer_manager_hdr(0, 0, PacketType::Data as u8);
self.session
.encrypt_payload(self.role.send_dir(), &mut packet)
.ok()?;
Some(packet)
}
@@ -76,23 +100,24 @@ impl TunnelFilter for NoiseTunnelFilter {
Ok(v) => v,
Err(e) => return Some(Err(e)),
};
let cipher = packet.payload();
let mut plain = vec![0u8; cipher.len() + 64];
let len = match self
.transport
.lock()
let mut cipher = ZCPacket::new_with_payload(packet.payload());
cipher.fill_peer_manager_hdr(0, 0, PacketType::Data as u8);
cipher
.mut_peer_manager_header()
.unwrap()
.read_message(cipher, &mut plain)
.set_encrypted(true);
if let Err(e) = self
.session
.decrypt_payload(self.role.recv_dir(), &mut cipher)
{
Ok(v) => v,
Err(e) => {
return Some(Err(TunnelError::InvalidPacket(format!(
"noise decrypt failed: {e}"
))));
}
};
return Some(Err(TunnelError::InvalidPacket(format!(
"secure datagram decrypt failed: {e}"
))));
}
Some(Ok(ZCPacket::new_from_buf(
BytesMut::from(&plain[..len]),
cipher.payload_bytes(),
ZCPacketType::DummyTunnel,
)))
}
@@ -117,24 +142,50 @@ fn decode_noise_payload(payload: &[u8]) -> Option<&[u8]> {
payload.strip_prefix(NOISE_MAGIC)
}
/// Whether this build supports the web secure tunnel, i.e. whether the
/// pinned cipher name parses into a known `EncryptionAlgorithm` (the
/// required AEAD support is compiled in).
pub fn web_secure_tunnel_supported() -> bool {
    matches!(
        WEB_SECURE_CIPHER_ALGORITHM.parse::<EncryptionAlgorithm>(),
        Ok(_)
    )
}
/// Return the pinned web-secure cipher name, or an `InternalError` when
/// this build lacks the required cipher support.
fn web_secure_cipher_algorithm() -> Result<&'static str, TunnelError> {
    if web_secure_tunnel_supported() {
        Ok(WEB_SECURE_CIPHER_ALGORITHM)
    } else {
        Err(TunnelError::InternalError(format!(
            "web secure tunnel requires {WEB_SECURE_CIPHER_ALGORITHM} support"
        )))
    }
}
/// Build a shared datagram session for the web secure tunnel, using the
/// same cipher in both directions and the fixed web generation/epoch.
fn new_web_secure_session(root_key: [u8; 32], algorithm: &str) -> Arc<SecureDatagramSession> {
    Arc::new(SecureDatagramSession::new(
        root_key,
        WEB_SESSION_GENERATION,
        WEB_INITIAL_EPOCH,
        algorithm.to_string(),
        algorithm.to_string(),
    ))
}
fn wrap_secure_tunnel(
info: Option<TunnelInfo>,
stream: std::pin::Pin<Box<dyn ZCPacketStream>>,
sink: std::pin::Pin<Box<dyn ZCPacketSink>>,
transport: TransportState,
session: Arc<SecureDatagramSession>,
role: SecureTunnelRole,
) -> Box<dyn Tunnel> {
let raw = RawSplitTunnel::new(info, stream, sink);
Box::new(TunnelWithFilter::new(
raw,
NoiseTunnelFilter {
transport: Arc::new(Mutex::new(transport)),
},
SecureDatagramTunnelFilter { session, role },
))
}
pub async fn upgrade_client_tunnel(
tunnel: Box<dyn Tunnel>,
) -> Result<Box<dyn Tunnel>, TunnelError> {
let web_cipher_algorithm = web_secure_cipher_algorithm()?;
let info = tunnel.info();
let (mut stream, mut sink) = tunnel.split();
@@ -156,19 +207,32 @@ pub async fn upgrade_client_tunnel(
)))
.await?;
let msg2_packet = stream.next().await.ok_or(TunnelError::Shutdown)??;
let msg2_packet = match tokio::time::timeout(WEB_SECURE_HANDSHAKE_TIMEOUT, stream.next()).await
{
Ok(Some(Ok(packet))) => packet,
Ok(Some(Err(error))) => return Err(error),
Ok(None) => return Err(TunnelError::Shutdown),
Err(error) => return Err(error.into()),
};
let msg2_cipher = decode_noise_payload(msg2_packet.payload())
.ok_or_else(|| TunnelError::InvalidPacket("invalid noise msg2 magic".to_string()))?;
let mut msg2 = vec![0u8; 1024];
state
.read_message(msg2_cipher, &mut msg2)
let mut root_key_buf = [0u8; 32];
let root_key_len = state
.read_message(msg2_cipher, &mut root_key_buf)
.map_err(|e| TunnelError::InvalidPacket(format!("read noise msg2 failed: {e}")))?;
if root_key_len != root_key_buf.len() {
return Err(TunnelError::InvalidPacket(format!(
"invalid web secure root key len: {root_key_len}"
)));
}
let transport = state
.into_transport_mode()
.map_err(|e| TunnelError::InternalError(format!("switch transport mode failed: {e}")))?;
Ok(wrap_secure_tunnel(info, stream, sink, transport))
Ok(wrap_secure_tunnel(
info,
stream,
sink,
new_web_secure_session(root_key_buf, web_cipher_algorithm),
SecureTunnelRole::Initiator,
))
}
pub async fn accept_or_upgrade_server_tunnel(
@@ -179,7 +243,7 @@ pub async fn accept_or_upgrade_server_tunnel(
let mut stream = stream;
let mut sink = sink;
let first_packet = match tokio::time::timeout(Duration::from_secs(1), stream.next()).await {
let first_packet = match tokio::time::timeout(WEB_SECURE_ACCEPT_TIMEOUT, stream.next()).await {
Ok(Some(Ok(packet))) => packet,
Ok(Some(Err(error))) => return Err(error),
Ok(None) => return Err(TunnelError::Shutdown),
@@ -197,6 +261,7 @@ pub async fn accept_or_upgrade_server_tunnel(
false,
));
};
let web_cipher_algorithm = web_secure_cipher_algorithm()?;
let params: NoiseParams = NOISE_PATTERN
.parse()
@@ -212,18 +277,81 @@ pub async fn accept_or_upgrade_server_tunnel(
.read_message(msg1_cipher, &mut msg1)
.map_err(|e| TunnelError::InvalidPacket(format!("read noise msg1 failed: {e}")))?;
let root_key = SecureDatagramSession::new_root_key();
let mut msg2 = vec![0u8; 1024];
let msg2_len = state
.write_message(&[], &mut msg2)
.write_message(&root_key, &mut msg2)
.map_err(|e| TunnelError::InternalError(format!("write noise msg2 failed: {e}")))?;
sink.send(pack_control_packet(&encode_noise_payload(
&msg2[..msg2_len],
)))
.await?;
let transport = state
.into_transport_mode()
.map_err(|e| TunnelError::InternalError(format!("switch transport mode failed: {e}")))?;
Ok((wrap_secure_tunnel(info, stream, sink, transport), true))
Ok((
wrap_secure_tunnel(
info,
stream,
sink,
new_web_secure_session(root_key, web_cipher_algorithm),
SecureTunnelRole::Responder,
),
true,
))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tunnel::ring::create_ring_tunnel_pair;
#[test]
// The helper and the support flag must agree: when the pinned cipher is
// available the helper returns its name, otherwise it surfaces an
// InternalError.
fn web_secure_cipher_algorithm_matches_support_flag() {
    let result = web_secure_cipher_algorithm();
    if web_secure_tunnel_supported() {
        assert_eq!(result.unwrap(), WEB_SECURE_CIPHER_ALGORITHM);
    } else {
        assert!(matches!(result, Err(TunnelError::InternalError(_))));
    }
}
#[test]
// A freshly built web secure session must be pinned to the web cipher
// algorithm in both directions; skipped when the build lacks support.
fn web_secure_session_uses_pinned_cipher_algorithm() {
    if !web_secure_tunnel_supported() {
        return;
    }
    let session = new_web_secure_session(
        SecureDatagramSession::new_root_key(),
        web_secure_cipher_algorithm().unwrap(),
    );
    session
        .check_encrypt_algo_same(WEB_SECURE_CIPHER_ALGORITHM, WEB_SECURE_CIPHER_ALGORITHM)
        .unwrap();
}
#[tokio::test]
// The client-side handshake must not hang forever when the server never
// answers noise msg1; it should fail with a Timeout error.
async fn upgrade_client_tunnel_times_out_when_server_never_replies() {
    let (server_tunnel, client_tunnel) = create_ring_tunnel_pair();
    // Keep the server end alive but silent so the client only times out.
    let _server_tunnel = server_tunnel;
    let err = upgrade_client_tunnel(client_tunnel).await.unwrap_err();
    assert!(matches!(err, TunnelError::Timeout(_)));
}
#[tokio::test]
// The server accept path must tolerate a client that starts its handshake
// after a short delay (1.5s here, within the 3s accept timeout) and still
// complete the secure upgrade on both ends.
async fn accept_secure_tunnel_after_short_client_delay() {
    let (server_tunnel, client_tunnel) = create_ring_tunnel_pair();
    let server_task =
        tokio::spawn(async move { accept_or_upgrade_server_tunnel(server_tunnel).await });
    tokio::time::sleep(Duration::from_millis(1500)).await;
    let client_task = tokio::spawn(async move { upgrade_client_tunnel(client_tunnel).await });
    let (server_res, client_res) = tokio::join!(server_task, client_task);
    let (_, secure) = server_res.unwrap().unwrap();
    assert!(secure);
    assert!(client_res.unwrap().is_ok());
}
}