Merge pull request #108 from EasyTier/latency_first

This commit is contained in:
Sijie.Sun
2024-05-13 22:30:50 +08:00
committed by GitHub
17 changed files with 722 additions and 304 deletions
Generated
+3 -39
View File
@@ -1112,18 +1112,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "deprecate-until"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a3767f826efbbe5a5ae093920b58b43b01734202be697e1354914e862e8e704"
dependencies = [
"proc-macro2",
"quote",
"semver",
"syn 2.0.48",
]
[[package]]
name = "deranged"
version = "0.3.11"
@@ -1300,8 +1288,8 @@ dependencies = [
"network-interface",
"nix 0.27.1",
"once_cell",
"pathfinding",
"percent-encoding",
"petgraph",
"pin-project-lite",
"pnet",
"postcard",
@@ -2342,15 +2330,6 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "integer-sqrt"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "276ec31bcb4a9ee45f58bec6f9ec700ae4cf4f4f8f2fa7e06cb406bd5ffdd770"
dependencies = [
"num-traits",
]
[[package]]
name = "ioctl-sys"
version = "0.8.0"
@@ -3225,21 +3204,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
[[package]]
name = "pathfinding"
version = "4.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0a21c30f03223ae4a4c892f077b3189133689b8a659a84372f8422384ce94c9"
dependencies = [
"deprecate-until",
"fixedbitset",
"indexmap 2.2.6",
"integer-sqrt",
"num-traits",
"rustc-hash",
"thiserror",
]
[[package]]
name = "pbkdf2"
version = "0.11.0"
@@ -3270,9 +3234,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "petgraph"
version = "0.6.4"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.2.6",
+1 -1
View File
@@ -137,7 +137,7 @@ async-recursion = "1.0.5"
network-interface = "1.1.1"
# for ospf route
pathfinding = "4.9.1"
petgraph = "0.6.5"
# for encryption
boringtun = { git = "https://github.com/EasyTier/boringtun.git", optional = true, rev = "449204c" }
+1 -9
View File
@@ -117,16 +117,8 @@ service ConnectorManageRpc {
rpc ManageConnector (ManageConnectorRequest) returns (ManageConnectorResponse);
}
enum LatencyLevel {
VeryLow = 0;
Low = 1;
Normal = 2;
High = 3;
VeryHigh = 4;
}
message DirectConnectedPeerInfo {
LatencyLevel latency_level = 2;
int32 latency_ms = 1;
}
message PeerInfoForGlobalMap {
+2
View File
@@ -150,6 +150,8 @@ pub struct Flags {
pub enable_ipv6: bool,
#[derivative(Default(value = "1420"))]
pub mtu: u16,
#[derivative(Default(value = "true"))]
pub latency_first: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+1 -7
View File
@@ -331,13 +331,7 @@ async fn main() -> Result<(), Error> {
let direct_peers = v
.direct_peers
.iter()
.map(|(k, v)| {
format!(
"{}:{:?}",
k,
LatencyLevel::try_from(v.latency_level).unwrap()
)
})
.map(|(k, v)| format!("{}: {:?}ms", k, v.latency_ms,))
.collect::<Vec<_>>();
table_rows.push(PeerCenterTableItem {
node_id: node_id.to_string(),
+8 -12
View File
@@ -97,10 +97,6 @@ struct Cli {
"wg://0.0.0.0:11011".to_string()])]
listeners: Vec<String>,
/// specify the linux network namespace, default is the root namespace
#[arg(long)]
net_ns: Option<String>,
#[arg(long, help = "console log level",
value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))]
console_log_level: Option<String>,
@@ -122,13 +118,6 @@ struct Cli {
)]
instance_name: String,
#[arg(
short = 'd',
long,
help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000"
)]
instance_id: Option<String>,
#[arg(
long,
help = "url that defines the vpn portal, allow other vpn clients to connect.
@@ -163,6 +152,13 @@ and the vpn client is in network of 10.14.14.0/24"
help = "mtu of the TUN device, default is 1420 for non-encryption, 1400 for encryption"
)]
mtu: Option<u16>,
#[arg(
long,
help = "path to the log file, if not set, will print to stdout",
default_value = "false"
)]
latency_first: bool,
}
impl From<Cli> for TomlConfigLoader {
@@ -188,7 +184,6 @@ impl From<Cli> for TomlConfigLoader {
cli.network_secret.clone(),
));
cfg.set_netns(cli.net_ns.clone());
if let Some(ipv4) = &cli.ipv4 {
cfg.set_ipv4(
ipv4.parse()
@@ -307,6 +302,7 @@ impl From<Cli> for TomlConfigLoader {
}
f.enable_encryption = !cli.disable_encryption;
f.enable_ipv6 = !cli.disable_ipv6;
f.latency_first = cli.latency_first;
if let Some(mtu) = cli.mtu {
f.mtu = mtu;
}
+24 -1
View File
@@ -1,6 +1,7 @@
use std::borrow::BorrowMut;
use std::net::Ipv4Addr;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use anyhow::Context;
@@ -44,6 +45,8 @@ struct IpProxy {
tcp_proxy: Arc<TcpProxy>,
icmp_proxy: Arc<IcmpProxy>,
udp_proxy: Arc<UdpProxy>,
global_ctx: ArcGlobalCtx,
started: Arc<AtomicBool>,
}
impl IpProxy {
@@ -57,10 +60,17 @@ impl IpProxy {
tcp_proxy,
icmp_proxy,
udp_proxy,
global_ctx,
started: Arc::new(AtomicBool::new(false)),
})
}
async fn start(&self) -> Result<(), Error> {
if self.global_ctx.get_proxy_cidrs().is_empty() || self.started.load(Ordering::Relaxed) {
return Ok(());
}
self.started.store(true, Ordering::Relaxed);
self.tcp_proxy.start().await?;
self.icmp_proxy.start().await?;
self.udp_proxy.start().await?;
@@ -297,11 +307,16 @@ impl Instance {
self.get_global_ctx(),
self.get_peer_manager(),
)?);
self.ip_proxy.as_ref().unwrap().start().await?;
self.run_ip_proxy().await?;
self.udp_hole_puncher.lock().await.run().await?;
self.peer_center.init().await;
let route_calc = self.peer_center.get_cost_calculator();
self.peer_manager
.get_route()
.set_route_cost_fn(route_calc)
.await;
self.add_initial_peers().await?;
@@ -312,6 +327,14 @@ impl Instance {
Ok(())
}
pub async fn run_ip_proxy(&mut self) -> Result<(), Error> {
if self.ip_proxy.is_none() {
return Err(anyhow::anyhow!("ip proxy not enabled.").into());
}
self.ip_proxy.as_ref().unwrap().start().await?;
Ok(())
}
pub async fn run_vpn_portal(&mut self) -> Result<(), Error> {
if self.global_ctx.get_vpn_portal_cidr().is_none() {
return Err(anyhow::anyhow!("vpn portal cidr not set.").into());
+150 -93
View File
@@ -1,24 +1,23 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, SystemTime},
collections::BTreeSet,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use crossbeam::atomic::AtomicCell;
use futures::Future;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use std::sync::RwLock;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tracing::Instrument;
use crate::{
common::PeerId,
peers::{peer_manager::PeerManager, rpc_service::PeerManagerRpcService},
peers::{
peer_manager::PeerManager,
route_trait::{RouteCostCalculator, RouteCostCalculatorInterface},
rpc_service::PeerManagerRpcService,
},
rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse},
};
@@ -34,7 +33,8 @@ struct PeerCenterBase {
lock: Arc<Mutex<()>>,
}
static SERVICE_ID: u32 = 5;
// static SERVICE_ID: u32 = 5; for compatibility with the original code
static SERVICE_ID: u32 = 50;
struct PeridicJobCtx<T> {
peer_mgr: Arc<PeerManager>,
@@ -132,7 +132,7 @@ impl PeerCenterBase {
pub struct PeerCenterInstanceService {
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_digest: Arc<RwLock<Digest>>,
global_peer_map_digest: Arc<AtomicCell<Digest>>,
}
#[tonic::async_trait]
@@ -141,7 +141,7 @@ impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstan
&self,
_request: tonic::Request<GetGlobalPeerMapRequest>,
) -> Result<tonic::Response<GetGlobalPeerMapResponse>, tonic::Status> {
let global_peer_map = self.global_peer_map.read().await.clone();
let global_peer_map = self.global_peer_map.read().unwrap().clone();
Ok(tonic::Response::new(GetGlobalPeerMapResponse {
global_peer_map: global_peer_map
.map
@@ -157,7 +157,8 @@ pub struct PeerCenterInstance {
client: Arc<PeerCenterBase>,
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_digest: Arc<RwLock<Digest>>,
global_peer_map_digest: Arc<AtomicCell<Digest>>,
global_peer_map_update_time: Arc<AtomicCell<Instant>>,
}
impl PeerCenterInstance {
@@ -166,7 +167,8 @@ impl PeerCenterInstance {
peer_mgr: peer_mgr.clone(),
client: Arc::new(PeerCenterBase::new(peer_mgr.clone())),
global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())),
global_peer_map_digest: Arc::new(RwLock::new(Digest::default())),
global_peer_map_digest: Arc::new(AtomicCell::new(Digest::default())),
global_peer_map_update_time: Arc::new(AtomicCell::new(Instant::now())),
}
}
@@ -179,12 +181,14 @@ impl PeerCenterInstance {
async fn init_get_global_info_job(&self) {
struct Ctx {
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_digest: Arc<RwLock<Digest>>,
global_peer_map_digest: Arc<AtomicCell<Digest>>,
global_peer_map_update_time: Arc<AtomicCell<Instant>>,
}
let ctx = Arc::new(Ctx {
global_peer_map: self.global_peer_map.clone(),
global_peer_map_digest: self.global_peer_map_digest.clone(),
global_peer_map_update_time: self.global_peer_map_update_time.clone(),
});
self.client
@@ -193,10 +197,7 @@ impl PeerCenterInstance {
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
let ret = client
.get_global_peer_map(
rpc_ctx,
ctx.job_ctx.global_peer_map_digest.read().await.clone(),
)
.get_global_peer_map(rpc_ctx, ctx.job_ctx.global_peer_map_digest.load())
.await?;
let Ok(resp) = ret else {
@@ -217,10 +218,13 @@ impl PeerCenterInstance {
resp.digest
);
*ctx.job_ctx.global_peer_map.write().await = resp.global_peer_map;
*ctx.job_ctx.global_peer_map_digest.write().await = resp.digest;
*ctx.job_ctx.global_peer_map.write().unwrap() = resp.global_peer_map;
ctx.job_ctx.global_peer_map_digest.store(resp.digest);
ctx.job_ctx
.global_peer_map_update_time
.store(Instant::now());
Ok(10000)
Ok(5000)
})
.await;
}
@@ -228,67 +232,53 @@ impl PeerCenterInstance {
async fn init_report_peers_job(&self) {
struct Ctx {
service: PeerManagerRpcService,
need_send_peers: AtomicBool,
last_report_peers: Mutex<PeerInfoForGlobalMap>,
last_report_peers: Mutex<BTreeSet<PeerId>>,
last_center_peer: AtomicCell<PeerId>,
last_report_time: AtomicCell<Instant>,
}
let ctx = Arc::new(Ctx {
service: PeerManagerRpcService::new(self.peer_mgr.clone()),
need_send_peers: AtomicBool::new(true),
last_report_peers: Mutex::new(PeerInfoForGlobalMap::default()),
last_report_peers: Mutex::new(BTreeSet::new()),
last_center_peer: AtomicCell::new(PeerId::default()),
last_report_time: AtomicCell::new(Instant::now()),
});
self.client
.init_periodic_job(ctx, |client, ctx| async move {
let my_node_id = ctx.peer_mgr.my_peer_id();
let peers: PeerInfoForGlobalMap = ctx.job_ctx.service.list_peers().await.into();
let peer_list = peers.direct_peers.keys().map(|k| *k).collect();
let job_ctx = &ctx.job_ctx;
// if peers are not same in next 10 seconds, report peers to center server
let mut peers = PeerInfoForGlobalMap::default();
for _ in 1..10 {
peers = ctx.job_ctx.service.list_peers().await.into();
if ctx.center_peer.load() != ctx.job_ctx.last_center_peer.load() {
// if center peer changed, report peers immediately
break;
}
if peers == *ctx.job_ctx.last_report_peers.lock().await {
return Ok(3000);
}
tokio::time::sleep(Duration::from_secs(2)).await;
// only report when:
// 1. center peer changed
// 2. last report time is more than 60 seconds
// 3. peers changed
if ctx.center_peer.load() == ctx.job_ctx.last_center_peer.load()
&& job_ctx.last_report_time.load().elapsed().as_secs() < 60
&& *job_ctx.last_report_peers.lock().await == peer_list
{
return Ok(5000);
}
*ctx.job_ctx.last_report_peers.lock().await = peers.clone();
let mut hasher = DefaultHasher::new();
peers.hash(&mut hasher);
let peers = if ctx.job_ctx.need_send_peers.load(Ordering::Relaxed) {
Some(peers)
} else {
None
};
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
let ret = client
.report_peers(
rpc_ctx,
my_node_id.clone(),
peers,
hasher.finish() as Digest,
)
.report_peers(rpc_ctx, my_node_id.clone(), peers)
.await?;
if matches!(ret.as_ref().err(), Some(Error::DigestMismatch)) {
ctx.job_ctx.need_send_peers.store(true, Ordering::Relaxed);
return Ok(0);
} else if ret.is_err() {
if ret.is_ok() {
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
*ctx.job_ctx.last_report_peers.lock().await = peer_list;
ctx.job_ctx.last_report_time.store(Instant::now());
} else {
tracing::error!("report peers to center server got error result: {:?}", ret);
return Ok(500);
}
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
ctx.job_ctx.need_send_peers.store(false, Ordering::Relaxed);
Ok(3000)
Ok(5000)
})
.await;
}
@@ -299,15 +289,61 @@ impl PeerCenterInstance {
global_peer_map_digest: self.global_peer_map_digest.clone(),
}
}
pub fn get_cost_calculator(&self) -> RouteCostCalculator {
struct RouteCostCalculatorImpl {
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_clone: GlobalPeerMap,
last_update_time: AtomicCell<Instant>,
global_peer_map_update_time: Arc<AtomicCell<Instant>>,
}
impl RouteCostCalculatorInterface for RouteCostCalculatorImpl {
fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 {
let ret = self
.global_peer_map_clone
.map
.get(&src)
.and_then(|src_peer_info| src_peer_info.direct_peers.get(&dst))
.and_then(|info| Some(info.latency_ms));
ret.unwrap_or(80)
}
fn begin_update(&mut self) {
let global_peer_map = self.global_peer_map.read().unwrap();
self.global_peer_map_clone = global_peer_map.clone();
}
fn end_update(&mut self) {
self.last_update_time
.store(self.global_peer_map_update_time.load());
}
fn need_update(&self) -> bool {
self.last_update_time.load() < self.global_peer_map_update_time.load()
}
}
Box::new(RouteCostCalculatorImpl {
global_peer_map: self.global_peer_map.clone(),
global_peer_map_clone: GlobalPeerMap::new(),
last_update_time: AtomicCell::new(
self.global_peer_map_update_time.load() - Duration::from_secs(1),
),
global_peer_map_update_time: self.global_peer_map_update_time.clone(),
})
}
}
#[cfg(test)]
mod tests {
use std::ops::Deref;
use crate::{
peer_center::server::get_global_data,
peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
peers::tests::{
connect_peer_manager, create_mock_peer_manager, wait_for_condition, wait_route_appear,
},
};
use super::*;
@@ -340,43 +376,64 @@ mod tests {
let center_data = get_global_data(center_peer);
// wait center_data has 3 records for 10 seconds
let now = std::time::Instant::now();
loop {
if center_data.read().await.global_peer_map.map.len() == 3 {
println!(
"center data ready, {:#?}",
center_data.read().await.global_peer_map
);
break;
}
if now.elapsed().as_secs() > 60 {
panic!("center data not ready");
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
wait_for_condition(
|| async {
if center_data.global_peer_map.len() == 4 {
println!("center data {:#?}", center_data.global_peer_map);
true
} else {
false
}
},
Duration::from_secs(10),
)
.await;
let mut digest = None;
for pc in peer_centers.iter() {
let rpc_service = pc.get_rpc_service();
let now = std::time::Instant::now();
while now.elapsed().as_secs() < 10 {
if rpc_service.global_peer_map.read().await.map.len() == 3 {
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
assert_eq!(rpc_service.global_peer_map.read().await.map.len(), 3);
wait_for_condition(
|| async { rpc_service.global_peer_map.read().unwrap().map.len() == 3 },
Duration::from_secs(10),
)
.await;
println!("rpc service ready, {:#?}", rpc_service.global_peer_map);
if digest.is_none() {
digest = Some(rpc_service.global_peer_map_digest.read().await.clone());
digest = Some(rpc_service.global_peer_map_digest.load());
} else {
let v = rpc_service.global_peer_map_digest.read().await;
assert_eq!(digest.as_ref().unwrap(), v.deref());
let v = rpc_service.global_peer_map_digest.load();
assert_eq!(digest.unwrap(), v);
}
let mut route_cost = pc.get_cost_calculator();
assert!(route_cost.need_update());
route_cost.begin_update();
assert!(
route_cost.calculate_cost(peer_mgr_a.my_peer_id(), peer_mgr_b.my_peer_id()) < 30
);
assert!(
route_cost.calculate_cost(peer_mgr_b.my_peer_id(), peer_mgr_a.my_peer_id()) < 30
);
assert!(
route_cost.calculate_cost(peer_mgr_b.my_peer_id(), peer_mgr_c.my_peer_id()) < 30
);
assert!(
route_cost.calculate_cost(peer_mgr_c.my_peer_id(), peer_mgr_b.my_peer_id()) < 30
);
assert!(
route_cost.calculate_cost(peer_mgr_c.my_peer_id(), peer_mgr_a.my_peer_id()) > 50
);
assert!(
route_cost.calculate_cost(peer_mgr_a.my_peer_id(), peer_mgr_c.my_peer_id()) > 50
);
route_cost.end_update();
assert!(!route_cost.need_update());
}
let global_digest = get_global_data(center_peer).read().await.digest.clone();
let global_digest = get_global_data(center_peer).digest.load();
assert_eq!(digest.as_ref().unwrap(), &global_digest);
}
}
+73 -66
View File
@@ -1,45 +1,48 @@
use std::{
collections::BinaryHeap,
hash::{Hash, Hasher},
sync::Arc,
};
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use once_cell::sync::Lazy;
use tokio::{sync::RwLock, task::JoinSet};
use tokio::{task::JoinSet};
use crate::common::PeerId;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{
service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap},
Digest, Error,
};
pub(crate) struct PeerCenterServerGlobalData {
pub global_peer_map: GlobalPeerMap,
pub digest: Digest,
pub update_time: std::time::Instant,
pub peer_update_time: DashMap<PeerId, std::time::Instant>,
#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
pub(crate) struct SrcDstPeerPair {
src: PeerId,
dst: PeerId,
}
impl PeerCenterServerGlobalData {
fn new() -> Self {
PeerCenterServerGlobalData {
global_peer_map: GlobalPeerMap::new(),
digest: Digest::default(),
update_time: std::time::Instant::now(),
peer_update_time: DashMap::new(),
}
}
#[derive(Debug, Clone)]
pub(crate) struct PeerCenterInfoEntry {
info: DirectConnectedPeerInfo,
update_time: std::time::Instant,
}
#[derive(Default)]
pub(crate) struct PeerCenterServerGlobalData {
pub(crate) global_peer_map: DashMap<SrcDstPeerPair, PeerCenterInfoEntry>,
pub(crate) peer_report_time: DashMap<PeerId, std::time::Instant>,
pub(crate) digest: AtomicCell<Digest>,
}
// a global unique instance for PeerCenterServer
pub(crate) static GLOBAL_DATA: Lazy<DashMap<PeerId, Arc<RwLock<PeerCenterServerGlobalData>>>> =
pub(crate) static GLOBAL_DATA: Lazy<DashMap<PeerId, Arc<PeerCenterServerGlobalData>>> =
Lazy::new(DashMap::new);
pub(crate) fn get_global_data(node_id: PeerId) -> Arc<RwLock<PeerCenterServerGlobalData>> {
pub(crate) fn get_global_data(node_id: PeerId) -> Arc<PeerCenterServerGlobalData> {
GLOBAL_DATA
.entry(node_id)
.or_insert_with(|| Arc::new(RwLock::new(PeerCenterServerGlobalData::new())))
.or_insert_with(|| Arc::new(PeerCenterServerGlobalData::default()))
.value()
.clone()
}
@@ -48,8 +51,6 @@ pub(crate) fn get_global_data(node_id: PeerId) -> Arc<RwLock<PeerCenterServerGlo
pub struct PeerCenterServer {
// every peer has its own server, so use per-struct dash map is ok.
my_node_id: PeerId,
digest_map: DashMap<PeerId, Digest>,
tasks: Arc<JoinSet<()>>,
}
@@ -65,26 +66,32 @@ impl PeerCenterServer {
PeerCenterServer {
my_node_id,
digest_map: DashMap::new(),
tasks: Arc::new(tasks),
}
}
async fn clean_outdated_peer(my_node_id: PeerId) {
let data = get_global_data(my_node_id);
let mut locked_data = data.write().await;
let now = std::time::Instant::now();
let mut to_remove = Vec::new();
for kv in locked_data.peer_update_time.iter() {
if now.duration_since(*kv.value()).as_secs() > 20 {
to_remove.push(*kv.key());
}
}
for peer_id in to_remove {
locked_data.global_peer_map.map.remove(&peer_id);
locked_data.peer_update_time.remove(&peer_id);
}
data.peer_report_time.retain(|_, v| {
std::time::Instant::now().duration_since(*v) < std::time::Duration::from_secs(180)
});
data.global_peer_map.retain(|_, v| {
std::time::Instant::now().duration_since(v.update_time)
< std::time::Duration::from_secs(180)
});
}
fn calc_global_digest(my_node_id: PeerId) -> Digest {
let data = get_global_data(my_node_id);
let mut hasher = std::collections::hash_map::DefaultHasher::new();
data.global_peer_map
.iter()
.map(|v| v.key().clone())
.collect::<BinaryHeap<_>>()
.into_sorted_vec()
.into_iter()
.for_each(|v| v.hash(&mut hasher));
hasher.finish()
}
}
@@ -95,39 +102,28 @@ impl PeerCenterService for PeerCenterServer {
self,
_: tarpc::context::Context,
my_peer_id: PeerId,
peers: Option<PeerInfoForGlobalMap>,
digest: Digest,
peers: PeerInfoForGlobalMap,
) -> Result<(), Error> {
tracing::trace!("receive report_peers");
tracing::debug!("receive report_peers");
let data = get_global_data(self.my_node_id);
let mut locked_data = data.write().await;
locked_data
.peer_update_time
data.peer_report_time
.insert(my_peer_id, std::time::Instant::now());
let old_digest = self.digest_map.get(&my_peer_id);
// if digest match, no need to update
if let Some(old_digest) = old_digest {
if *old_digest == digest {
return Ok(());
}
for (peer_id, peer_info) in peers.direct_peers {
let pair = SrcDstPeerPair {
src: my_peer_id,
dst: peer_id,
};
let entry = PeerCenterInfoEntry {
info: peer_info,
update_time: std::time::Instant::now(),
};
data.global_peer_map.insert(pair, entry);
}
if peers.is_none() {
return Err(Error::DigestMismatch);
}
self.digest_map.insert(my_peer_id, digest);
locked_data
.global_peer_map
.map
.insert(my_peer_id, peers.unwrap());
let mut hasher = std::collections::hash_map::DefaultHasher::new();
locked_data.global_peer_map.map.hash(&mut hasher);
locked_data.digest = hasher.finish() as Digest;
locked_data.update_time = std::time::Instant::now();
data.digest
.store(PeerCenterServer::calc_global_digest(self.my_node_id));
Ok(())
}
@@ -138,15 +134,26 @@ impl PeerCenterService for PeerCenterServer {
digest: Digest,
) -> Result<Option<GetGlobalPeerMapResponse>, Error> {
let data = get_global_data(self.my_node_id);
if digest == data.read().await.digest {
if digest == data.digest.load() && digest != 0 {
return Ok(None);
}
let data = get_global_data(self.my_node_id);
let locked_data = data.read().await;
let mut global_peer_map = GlobalPeerMap::new();
for item in data.global_peer_map.iter() {
let (pair, entry) = item.pair();
global_peer_map
.map
.entry(pair.src)
.or_insert_with(|| PeerInfoForGlobalMap {
direct_peers: Default::default(),
})
.direct_peers
.insert(pair.dst, entry.info.clone());
}
Ok(Some(GetGlobalPeerMapResponse {
global_peer_map: locked_data.global_peer_map.clone(),
digest: locked_data.digest,
global_peer_map,
digest: data.digest.load(),
}))
}
}
+6 -26
View File
@@ -5,39 +5,23 @@ use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{Digest, Error};
use crate::rpc::PeerInfo;
pub type LatencyLevel = crate::rpc::cli::LatencyLevel;
impl LatencyLevel {
pub const fn from_latency_ms(lat_ms: u32) -> Self {
if lat_ms < 10 {
LatencyLevel::VeryLow
} else if lat_ms < 50 {
LatencyLevel::Low
} else if lat_ms < 100 {
LatencyLevel::Normal
} else if lat_ms < 200 {
LatencyLevel::High
} else {
LatencyLevel::VeryHigh
}
}
}
pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
fn from(peers: Vec<PeerInfo>) -> Self {
let mut peer_map = BTreeMap::new();
for peer in peers {
let min_lat = peer
let Some(min_lat) = peer
.conns
.iter()
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
.min()
.unwrap_or(0);
else {
continue;
};
let dp_info = DirectConnectedPeerInfo {
latency_level: LatencyLevel::from_latency_ms(min_lat as u32 / 1000) as i32,
latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32),
};
// sort conn info so hash result is stable
@@ -73,11 +57,7 @@ pub struct GetGlobalPeerMapResponse {
pub trait PeerCenterService {
// report center server which peer is directly connected to me
// digest is a hash of current peer map, if digest not match, we need to transfer the whole map
async fn report_peers(
my_peer_id: PeerId,
peers: Option<PeerInfoForGlobalMap>,
digest: Digest,
) -> Result<(), Error>;
async fn report_peers(my_peer_id: PeerId, peers: PeerInfoForGlobalMap) -> Result<(), Error>;
async fn get_global_peer_map(digest: Digest)
-> Result<Option<GetGlobalPeerMapResponse>, Error>;
@@ -29,6 +29,7 @@ use super::{
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
route_trait::NextHopPolicy,
PacketRecvChan, PacketRecvChanReceiver,
};
@@ -66,7 +67,10 @@ impl ForeignNetworkManagerData {
.get(&network_name)
.ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))?
.clone();
entry.peer_map.send_msg(msg, dst_peer_id).await
entry
.peer_map
.send_msg(msg, dst_peer_id, NextHopPolicy::LeastHop)
.await
}
fn get_peer_network(&self, peer_id: PeerId) -> Option<String> {
@@ -275,7 +279,10 @@ impl ForeignNetworkManager {
}
if let Some(entry) = data.get_network_entry(&from_network) {
let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await;
let ret = entry
.peer_map
.send_msg(packet_bytes, to_peer_id, NextHopPolicy::LeastHop)
.await;
if ret.is_err() {
tracing::error!("forward packet to peer failed: {:?}", ret.err());
}
+58 -10
View File
@@ -22,7 +22,9 @@ use tokio_util::bytes::Bytes;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{
peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface,
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
route_trait::{NextHopPolicy, RouteInterface},
PeerPacketFilter,
},
tunnel::{
@@ -73,7 +75,10 @@ impl PeerRpcManagerTransport for RpcTransport {
.ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
if let Some(gateway_id) = peers.get_gateway_peer_id(dst_peer_id).await {
if let Some(gateway_id) = peers
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
.await
{
tracing::trace!(
?dst_peer_id,
?gateway_id,
@@ -320,20 +325,33 @@ impl PeerManager {
let my_peer_id = self.my_peer_id;
let peers = self.peers.clone();
let pipe_line = self.peer_packet_process_pipeline.clone();
let foreign_client = self.foreign_network_client.clone();
let encryptor = self.encryptor.clone();
self.tasks.lock().await.spawn(async move {
log::trace!("start_peer_recv");
while let Some(mut ret) = recv.next().await {
let Some(hdr) = ret.peer_manager_header() else {
let Some(hdr) = ret.mut_peer_manager_header() else {
tracing::warn!(?ret, "invalid packet, skip");
continue;
};
tracing::trace!(?hdr, ?ret, "peer recv a packet...");
tracing::trace!(?hdr, "peer recv a packet...");
let from_peer_id = hdr.from_peer_id.get();
let to_peer_id = hdr.to_peer_id.get();
if to_peer_id != my_peer_id {
if hdr.forward_counter > 7 {
tracing::warn!(?hdr, "forward counter exceed, drop packet");
continue;
}
if hdr.forward_counter > 2 && hdr.is_latency_first() {
tracing::trace!(?hdr, "set_latency_first false because too many hop");
hdr.set_latency_first(false);
}
hdr.forward_counter += 1;
tracing::trace!(?to_peer_id, ?my_peer_id, "need forward");
let ret = peers.send_msg(ret, to_peer_id).await;
let ret =
Self::send_msg_internal(&peers, &foreign_client, ret, to_peer_id).await;
if ret.is_err() {
tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error");
}
@@ -518,11 +536,31 @@ impl PeerManager {
}
}
fn get_next_hop_policy(is_first_latency: bool) -> NextHopPolicy {
if is_first_latency {
NextHopPolicy::LeastCost
} else {
NextHopPolicy::LeastHop
}
}
pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
if let Some(gateway) = self.peers.get_gateway_peer_id(dst_peer_id).await {
self.peers.send_msg_directly(msg, gateway).await
} else if self.foreign_network_client.has_next_hop(dst_peer_id) {
self.foreign_network_client.send_msg(msg, dst_peer_id).await
Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, dst_peer_id).await
}
async fn send_msg_internal(
peers: &Arc<PeerMap>,
foreign_network_client: &Arc<ForeignNetworkClient>,
msg: ZCPacket,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let policy =
Self::get_next_hop_policy(msg.peer_manager_header().unwrap().is_latency_first());
if let Some(gateway) = peers.get_gateway_peer_id(dst_peer_id, policy).await {
peers.send_msg_directly(msg, gateway).await
} else if foreign_network_client.has_next_hop(dst_peer_id) {
foreign_network_client.send_msg(msg, dst_peer_id).await
} else {
Err(Error::RouteError(None))
}
@@ -564,6 +602,12 @@ impl PeerManager {
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
let is_latency_first = self.global_ctx.get_flags().latency_first;
msg.mut_peer_manager_header()
.unwrap()
.set_latency_first(is_latency_first);
let next_hop_policy = Self::get_next_hop_policy(is_latency_first);
let mut errs: Vec<Error> = vec![];
let mut msg = Some(msg);
@@ -581,7 +625,11 @@ impl PeerManager {
.to_peer_id
.set(*peer_id);
if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await {
if let Some(gateway) = self
.peers
.get_gateway_peer_id(*peer_id, next_hop_policy.clone())
.await
{
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
errs.push(e);
}
+18 -6
View File
@@ -18,7 +18,7 @@ use crate::{
use super::{
peer::Peer,
peer_conn::{PeerConn, PeerConnId},
route_trait::ArcRoute,
route_trait::{ArcRoute, NextHopPolicy},
PacketRecvChan,
};
@@ -94,18 +94,25 @@ impl PeerMap {
Ok(())
}
pub async fn get_gateway_peer_id(&self, dst_peer_id: PeerId) -> Option<PeerId> {
pub async fn get_gateway_peer_id(
&self,
dst_peer_id: PeerId,
policy: NextHopPolicy,
) -> Option<PeerId> {
if dst_peer_id == self.my_peer_id {
return Some(dst_peer_id);
}
if self.has_peer(dst_peer_id) {
if self.has_peer(dst_peer_id) && matches!(policy, NextHopPolicy::LeastHop) {
return Some(dst_peer_id);
}
// get route info
for route in self.routes.read().await.iter() {
if let Some(gateway_peer_id) = route.get_next_hop(dst_peer_id).await {
if let Some(gateway_peer_id) = route
.get_next_hop_with_policy(dst_peer_id, policy.clone())
.await
{
// for foreign network, gateway_peer_id may not connect to me
if self.has_peer(gateway_peer_id) {
return Some(gateway_peer_id);
@@ -116,8 +123,13 @@ impl PeerMap {
None
}
pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id).await else {
pub async fn send_msg(
&self,
msg: ZCPacket,
dst_peer_id: PeerId,
policy: NextHopPolicy,
) -> Result<(), Error> {
let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id, policy).await else {
return Err(Error::RouteError(Some(format!(
"peer map sengmsg no gateway for dst_peer_id: {}",
dst_peer_id
+295 -28
View File
@@ -10,6 +10,11 @@ use std::{
};
use dashmap::DashMap;
use petgraph::{
algo::{all_simple_paths, astar, dijkstra},
graph::NodeIndex,
Directed, Graph,
};
use serde::{Deserialize, Serialize};
use tokio::{select, sync::Mutex, task::JoinSet};
@@ -19,7 +24,14 @@ use crate::{
rpc::{NatType, StunInfo},
};
use super::{peer_rpc::PeerRpcManager, PeerPacketFilter};
use super::{
peer_rpc::PeerRpcManager,
route_trait::{
DefaultRouteCostCalculator, NextHopPolicy, RouteCostCalculator,
RouteCostCalculatorInterface,
},
PeerPacketFilter,
};
static SERVICE_ID: u32 = 7;
static UPDATE_PEER_INFO_PERIOD: Duration = Duration::from_secs(3600);
@@ -360,11 +372,15 @@ impl SyncedRouteInfo {
}
}
type PeerGraph = Graph<PeerId, i32, Directed>;
type PeerIdToNodexIdxMap = DashMap<PeerId, NodeIndex>;
type NextHopMap = DashMap<PeerId, (PeerId, i32)>;
// computed with SyncedRouteInfo. used to get next hop.
#[derive(Debug)]
struct RouteTable {
peer_infos: DashMap<PeerId, RoutePeerInfo>,
next_hop_map: DashMap<PeerId, (PeerId, i32)>,
next_hop_map: NextHopMap,
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
@@ -393,7 +409,121 @@ impl RouteTable {
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
}
fn build_from_synced_info(&self, my_peer_id: PeerId, synced_info: &SyncedRouteInfo) {
fn build_peer_graph_from_synced_info<T: RouteCostCalculatorInterface>(
peers: Vec<PeerId>,
synced_info: &SyncedRouteInfo,
cost_calc: &mut T,
) -> (PeerGraph, PeerIdToNodexIdxMap) {
let mut graph: PeerGraph = Graph::new();
let peer_id_to_node_index = PeerIdToNodexIdxMap::new();
for peer_id in peers.iter() {
peer_id_to_node_index.insert(*peer_id, graph.add_node(*peer_id));
}
for peer_id in peers.iter() {
let connected_peers = synced_info
.get_connected_peers(*peer_id)
.unwrap_or(BTreeSet::new());
for dst_peer_id in connected_peers.iter() {
let Some(dst_idx) = peer_id_to_node_index.get(dst_peer_id) else {
continue;
};
graph.add_edge(
*peer_id_to_node_index.get(&peer_id).unwrap(),
*dst_idx,
cost_calc.calculate_cost(*peer_id, *dst_peer_id),
);
}
}
(graph, peer_id_to_node_index)
}
fn gen_next_hop_map_with_least_hop<T: RouteCostCalculatorInterface>(
my_peer_id: PeerId,
graph: &PeerGraph,
idx_map: &PeerIdToNodexIdxMap,
cost_calc: &mut T,
) -> NextHopMap {
let res = dijkstra(&graph, *idx_map.get(&my_peer_id).unwrap(), None, |_| 1);
let next_hop_map = NextHopMap::new();
for (node_idx, cost) in res.iter() {
if *cost == 0 {
continue;
}
let all_paths = all_simple_paths::<Vec<_>, _>(
graph,
*idx_map.get(&my_peer_id).unwrap(),
*node_idx,
*cost - 1,
Some(*cost - 1),
)
.collect::<Vec<_>>();
assert!(!all_paths.is_empty());
// find a path with least cost.
let mut min_cost = i32::MAX;
let mut min_path = Vec::new();
for path in all_paths.iter() {
let mut cost = 0;
for i in 0..path.len() - 1 {
let src_peer_id = *graph.node_weight(path[i]).unwrap();
let dst_peer_id = *graph.node_weight(path[i + 1]).unwrap();
cost += cost_calc.calculate_cost(src_peer_id, dst_peer_id);
}
if cost <= min_cost {
min_cost = cost;
min_path = path.clone();
}
}
next_hop_map.insert(
*graph.node_weight(*node_idx).unwrap(),
(*graph.node_weight(min_path[1]).unwrap(), *cost as i32),
);
}
next_hop_map
}
fn gen_next_hop_map_with_least_cost(
my_peer_id: PeerId,
graph: &PeerGraph,
idx_map: &PeerIdToNodexIdxMap,
) -> NextHopMap {
let next_hop_map = NextHopMap::new();
for item in idx_map.iter() {
if *item.key() == my_peer_id {
continue;
}
let dst_peer_node_idx = *item.value();
let Some((cost, path)) = astar::astar(
graph,
*idx_map.get(&my_peer_id).unwrap(),
|node_idx| node_idx == dst_peer_node_idx,
|e| *e.weight(),
|_| 0,
) else {
continue;
};
next_hop_map.insert(*item.key(), (*graph.node_weight(path[1]).unwrap(), cost));
}
next_hop_map
}
fn build_from_synced_info<T: RouteCostCalculatorInterface>(
&self,
my_peer_id: PeerId,
synced_info: &SyncedRouteInfo,
policy: NextHopPolicy,
mut cost_calc: T,
) {
// build peer_infos
self.peer_infos.clear();
for item in synced_info.peer_infos.iter() {
@@ -410,28 +540,20 @@ impl RouteTable {
// build next hop map
self.next_hop_map.clear();
self.next_hop_map.insert(my_peer_id, (my_peer_id, 0));
for item in self.peer_infos.iter() {
let peer_id = *item.key();
if peer_id == my_peer_id {
continue;
}
let Some(path) = pathfinding::prelude::bfs(
&my_peer_id,
|p| {
synced_info
.get_connected_peers(*p)
.unwrap_or_else(|| BTreeSet::new())
},
|x| *x == peer_id,
) else {
continue;
};
if !path.is_empty() {
assert!(path.len() >= 2);
self.next_hop_map
.insert(peer_id, (path[1], (path.len() - 1) as i32));
}
let (graph, idx_map) = Self::build_peer_graph_from_synced_info(
self.peer_infos.iter().map(|x| *x.key()).collect(),
&synced_info,
&mut cost_calc,
);
let next_hop_map = if matches!(policy, NextHopPolicy::LeastHop) {
Self::gen_next_hop_map_with_least_hop(my_peer_id, &graph, &idx_map, &mut cost_calc)
} else {
Self::gen_next_hop_map_with_least_cost(my_peer_id, &graph, &idx_map)
};
for item in next_hop_map.iter() {
self.next_hop_map.insert(*item.key(), *item.value());
}
// build graph
// build ipv4_peer_id_map, cidr_peer_id_map
self.ipv4_peer_id_map.clear();
@@ -563,7 +685,9 @@ struct PeerRouteServiceImpl {
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
cost_calculator: Arc<std::sync::Mutex<Option<RouteCostCalculator>>>,
route_table: RouteTable,
route_table_with_cost: RouteTable,
synced_route_info: Arc<SyncedRouteInfo>,
cached_local_conn_map: std::sync::Mutex<RouteConnBitmap>,
}
@@ -585,9 +709,17 @@ impl PeerRouteServiceImpl {
PeerRouteServiceImpl {
my_peer_id,
global_ctx,
interface: Arc::new(Mutex::new(None)),
sessions: DashMap::new(),
interface: Arc::new(Mutex::new(None)),
cost_calculator: Arc::new(std::sync::Mutex::new(Some(Box::new(
DefaultRouteCostCalculator,
)))),
route_table: RouteTable::new(),
route_table_with_cost: RouteTable::new(),
synced_route_info: Arc::new(SyncedRouteInfo {
peer_infos: DashMap::new(),
conn_map: DashMap::new(),
@@ -649,8 +781,32 @@ impl PeerRouteServiceImpl {
}
fn update_route_table(&self) {
self.route_table
.build_from_synced_info(self.my_peer_id, &self.synced_route_info);
let mut calc_locked = self.cost_calculator.lock().unwrap();
calc_locked.as_mut().unwrap().begin_update();
self.route_table.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
NextHopPolicy::LeastHop,
calc_locked.as_mut().unwrap(),
);
self.route_table_with_cost.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
NextHopPolicy::LeastCost,
calc_locked.as_mut().unwrap(),
);
calc_locked.as_mut().unwrap().end_update();
}
fn cost_calculator_need_update(&self) -> bool {
self.cost_calculator
.lock()
.unwrap()
.as_ref()
.map(|x| x.need_update())
.unwrap_or(false)
}
fn update_route_table_and_cached_local_conn_bitmap(&self) {
@@ -1173,6 +1329,7 @@ impl PeerRoute {
session_mgr.maintain_sessions(service_impl).await;
}
#[tracing::instrument(skip(session_mgr))]
async fn update_my_peer_info_routine(
service_impl: Arc<PeerRouteServiceImpl>,
session_mgr: RouteSessionManager,
@@ -1183,6 +1340,11 @@ impl PeerRoute {
session_mgr.sync_now("update_my_infos");
}
if service_impl.cost_calculator_need_update() {
tracing::debug!("cost_calculator_need_update");
service_impl.update_route_table();
}
select! {
ev = global_event_receiver.recv() => {
tracing::info!(?ev, "global event received in update_my_peer_info_routine");
@@ -1234,6 +1396,19 @@ impl Route for PeerRoute {
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
}
async fn get_next_hop_with_policy(
&self,
dst_peer_id: PeerId,
policy: NextHopPolicy,
) -> Option<PeerId> {
let route_table = if matches!(policy, NextHopPolicy::LeastCost) {
&self.service_impl.route_table_with_cost
} else {
&self.service_impl.route_table
};
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
}
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
let route_table = &self.service_impl.route_table;
let mut routes = Vec::new();
@@ -1265,6 +1440,11 @@ impl Route for PeerRoute {
tracing::info!(?ipv4_addr, "no peer id for ipv4");
None
}
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {
*self.service_impl.cost_calculator.lock().unwrap() = Some(_cost_fn);
self.service_impl.update_route_table();
}
}
impl PeerPacketFilter for Arc<PeerRoute> {}
@@ -1282,7 +1462,7 @@ mod tests {
connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{
peer_manager::{PeerManager, RouteAlgoType},
route_trait::Route,
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
tests::{connect_peer_manager, wait_for_condition},
},
rpc::NatType,
@@ -1609,4 +1789,91 @@ mod tests {
println!("session: {:?}", r_a.session_mgr.dump_sessions());
check_rpc_counter(&r_a, p_b.my_peer_id(), 2, 2);
}
#[tokio::test]
async fn test_cost_calculator() {
let p_a = create_mock_pmgr().await;
let p_b = create_mock_pmgr().await;
let p_c = create_mock_pmgr().await;
let p_d = create_mock_pmgr().await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_a.clone(), p_c.clone()).await;
connect_peer_manager(p_d.clone(), p_b.clone()).await;
connect_peer_manager(p_d.clone(), p_c.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
let _r_a = create_mock_route(p_a.clone()).await;
let _r_b = create_mock_route(p_b.clone()).await;
let _r_c = create_mock_route(p_c.clone()).await;
let r_d = create_mock_route(p_d.clone()).await;
// in normal mode, packet from p_c should directly forward to p_a
wait_for_condition(
|| async { r_d.get_next_hop(p_a.my_peer_id()).await != None },
Duration::from_secs(5),
)
.await;
struct TestCostCalculator {
p_a_peer_id: PeerId,
p_b_peer_id: PeerId,
p_c_peer_id: PeerId,
p_d_peer_id: PeerId,
}
impl RouteCostCalculatorInterface for TestCostCalculator {
fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 {
if src == self.p_d_peer_id && dst == self.p_b_peer_id {
return 100;
}
if src == self.p_d_peer_id && dst == self.p_c_peer_id {
return 1;
}
if src == self.p_c_peer_id && dst == self.p_a_peer_id {
return 101;
}
if src == self.p_b_peer_id && dst == self.p_a_peer_id {
return 1;
}
if src == self.p_c_peer_id && dst == self.p_b_peer_id {
return 2;
}
1
}
}
r_d.set_route_cost_fn(Box::new(TestCostCalculator {
p_a_peer_id: p_a.my_peer_id(),
p_b_peer_id: p_b.my_peer_id(),
p_c_peer_id: p_c.my_peer_id(),
p_d_peer_id: p_d.my_peer_id(),
}))
.await;
// after set cost, packet from p_c should forward to p_b first
wait_for_condition(
|| async {
r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost)
.await
== Some(p_c.my_peer_id())
},
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async {
r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastHop)
.await
== Some(p_b.my_peer_id())
},
Duration::from_secs(5),
)
.await;
}
}
+47
View File
@@ -5,6 +5,18 @@ use tokio_util::bytes::Bytes;
use crate::common::{error::Error, PeerId};
#[derive(Clone, Debug)]
pub enum NextHopPolicy {
LeastHop,
LeastCost,
}
impl Default for NextHopPolicy {
fn default() -> Self {
NextHopPolicy::LeastHop
}
}
#[async_trait]
pub trait RouteInterface {
async fn list_peers(&self) -> Vec<PeerId>;
@@ -19,6 +31,31 @@ pub trait RouteInterface {
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
#[auto_impl::auto_impl(Box , &mut)]
pub trait RouteCostCalculatorInterface: Send + Sync {
fn begin_update(&mut self) {}
fn end_update(&mut self) {}
fn calculate_cost(&self, _src: PeerId, _dst: PeerId) -> i32 {
1
}
fn need_update(&self) -> bool {
false
}
fn dump(&self) -> String {
"All routes have cost 1".to_string()
}
}
#[derive(Clone, Debug, Default)]
pub struct DefaultRouteCostCalculator;
impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {}
pub type RouteCostCalculator = Box<dyn RouteCostCalculatorInterface>;
#[async_trait]
#[auto_impl::auto_impl(Box, Arc)]
pub trait Route {
@@ -26,11 +63,21 @@ pub trait Route {
async fn close(&self);
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
async fn get_next_hop_with_policy(
&self,
peer_id: PeerId,
_policy: NextHopPolicy,
) -> Option<PeerId> {
self.get_next_hop(peer_id).await
}
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
None
}
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {}
}
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
+6 -3
View File
@@ -187,12 +187,13 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener};
let insts = init_three_node(proto).await;
let mut insts = init_three_node(proto).await;
insts[2]
.get_global_ctx()
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
.unwrap();
insts[2].run_ip_proxy().await.unwrap();
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
wait_proxy_route_appear(
@@ -222,12 +223,13 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str
#[tokio::test]
#[serial_test::serial]
pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
let insts = init_three_node(proto).await;
let mut insts = init_three_node(proto).await;
insts[2]
.get_global_ctx()
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
.unwrap();
insts[2].run_ip_proxy().await.unwrap();
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
wait_proxy_route_appear(
@@ -318,12 +320,13 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener};
let insts = init_three_node(proto).await;
let mut insts = init_three_node(proto).await;
insts[2]
.get_global_ctx()
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
.unwrap();
insts[2].run_ip_proxy().await.unwrap();
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
wait_proxy_route_appear(
+20 -1
View File
@@ -59,6 +59,7 @@ pub enum PacketType {
bitflags::bitflags! {
struct PeerManagerHeaderFlags: u8 {
const ENCRYPTED = 0b0000_0001;
const LATENCY_FIRST = 0b0000_0010;
}
}
@@ -69,7 +70,8 @@ pub struct PeerManagerHeader {
pub to_peer_id: U32<DefaultEndian>,
pub packet_type: u8,
pub flags: u8,
reserved: U16<DefaultEndian>,
pub forward_counter: u8,
reserved: u8,
pub len: U32<DefaultEndian>,
}
pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::<PeerManagerHeader>();
@@ -90,6 +92,22 @@ impl PeerManagerHeader {
}
self.flags = flags.bits();
}
pub fn is_latency_first(&self) -> bool {
PeerManagerHeaderFlags::from_bits(self.flags)
.unwrap()
.contains(PeerManagerHeaderFlags::LATENCY_FIRST)
}
pub fn set_latency_first(&mut self, latency_first: bool) {
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
if latency_first {
flags.insert(PeerManagerHeaderFlags::LATENCY_FIRST);
} else {
flags.remove(PeerManagerHeaderFlags::LATENCY_FIRST);
}
self.flags = flags.bits();
}
}
// reserve the space for aes tag and nonce
@@ -362,6 +380,7 @@ impl ZCPacket {
hdr.to_peer_id.set(to_peer_id);
hdr.packet_type = packet_type;
hdr.flags = 0;
hdr.forward_counter = 1;
hdr.len.set(payload_len as u32);
}