From 8311b11713ddc6bdb3bc7bdbeeb6cf052567805b Mon Sep 17 00:00:00 2001 From: Luna Yao <40349250+ZnqbuZ@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:22:08 +0200 Subject: [PATCH] refactor: remove NoGroAsyncUdpSocket (#1867) --- easytier/src/connector/mod.rs | 4 +- easytier/src/gateway/quic_proxy.rs | 10 +- easytier/src/gateway/socks5.rs | 6 +- easytier/src/gateway/udp_proxy.rs | 4 +- easytier/src/instance/listeners.rs | 8 +- easytier/src/tunnel/common.rs | 11 +- easytier/src/tunnel/mod.rs | 2 +- easytier/src/tunnel/quic.rs | 496 +++++++++++++++++++++-------- easytier/src/tunnel/tcp.rs | 6 +- easytier/src/tunnel/udp.rs | 10 +- easytier/src/tunnel/websocket.rs | 6 +- easytier/src/tunnel/wireguard.rs | 10 +- 12 files changed, 401 insertions(+), 172 deletions(-) diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index bfe8740c..832f0173 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -105,7 +105,9 @@ pub async fn create_connector_by_url( IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(), IpScheme::Udp => UdpTunnelConnector::new(url).boxed(), #[cfg(feature = "quic")] - IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(), + IpScheme::Quic => { + tunnel::quic::QuicTunnelConnector::new(url, global_ctx.clone()).boxed() + } #[cfg(feature = "wireguard")] IpScheme::Wg => { use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector}; diff --git a/easytier/src/gateway/quic_proxy.rs b/easytier/src/gateway/quic_proxy.rs index 743a9b88..c959efc8 100644 --- a/easytier/src/gateway/quic_proxy.rs +++ b/easytier/src/gateway/quic_proxy.rs @@ -26,7 +26,9 @@ use derivative::Derivative; use derive_more::{Constructor, Deref, DerefMut, From, Into}; use prost::Message; use quinn::udp::{EcnCodepoint, RecvMeta, Transmit}; -use quinn::{AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, TokioRuntime, UdpPoller}; +use quinn::{ + AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime, +}; use std::cmp::min; use std::future::Future; use std::io::IoSliceMut; @@ -806,7 +808,7 @@ impl QuicProxy { endpoint_config(), Some(server_config()), Arc::new(socket), - Arc::new(TokioRuntime), + default_runtime().unwrap(), ) .unwrap(); endpoint.set_default_client_config(client_config()); @@ -1020,7 +1022,7 @@ mod tests { endpoint_config.clone(), Some(server_config.clone()), socket_client.clone(), - Arc::new(TokioRuntime), + default_runtime().unwrap(), ) .unwrap(); client_endpoint.set_default_client_config(client_config.clone()); @@ -1030,7 +1032,7 @@ mod tests { endpoint_config.clone(), Some(server_config.clone()), socket_server.clone(), - Arc::new(TokioRuntime), + default_runtime().unwrap(), ) .unwrap(); server_endpoint.set_default_client_config(client_config.clone()); diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index 7b19ffff..728bd131 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -31,7 +31,7 @@ use crate::{ tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device}, }, tunnel::{ - common::setup_sokcet2, + common::setup_socket2, packet_def::{PacketType, ZCPacket}, }, }; @@ -336,7 +336,7 @@ fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result Result Some(socket2::Protocol::UDP), )?; - setup_sokcet2(&socket2_socket, &addr)?; + setup_socket2(&socket2_socket, &addr, true)?; Ok(UdpSocket::from_std(socket2_socket.into())?) } diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index 5a9a98d6..ab6ecc95 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -29,7 +29,7 @@ use crate::{ gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet}, peers::{PeerPacketFilter, peer_manager::PeerManager}, tunnel::{ - common::{reserve_buf, setup_sokcet2}, + common::{reserve_buf, setup_socket2}, packet_def::{PacketType, ZCPacket}, }, }; @@ -72,7 +72,7 @@ impl UdpNatEntry { Some(socket2::Protocol::UDP), )?; let dst_socket_addr = "0.0.0.0:0".parse().unwrap(); - setup_sokcet2(&socket2_socket, &dst_socket_addr)?; + setup_socket2(&socket2_socket, &dst_socket_addr, true)?; Some(UdpSocket::from_std(socket2_socket.into())?) }; diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index 6f7782dd..72a1b9af 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -25,7 +25,7 @@ use crate::{ pub fn create_listener_by_url( l: &url::Url, - #[allow(unused_variables)] ctx: ArcGlobalCtx, + global_ctx: ArcGlobalCtx, ) -> Result, Error> { Ok(match l.try_into()? { TunnelScheme::Ip(scheme) => match scheme { @@ -34,7 +34,7 @@ pub fn create_listener_by_url( #[cfg(feature = "wireguard")] IpScheme::Wg => { use crate::tunnel::wireguard::{WgConfig, WgTunnelListener}; - let nid = ctx.get_network_identity(); + let nid = global_ctx.get_network_identity(); let wg_config = WgConfig::new_from_network_identity( &nid.network_name, &nid.network_secret.unwrap_or_default(), @@ -42,7 +42,9 @@ pub fn create_listener_by_url( WgTunnelListener::new(l.clone(), wg_config).boxed() } #[cfg(feature = "quic")] - IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(), + IpScheme::Quic => { + tunnel::quic::QuicTunnelListener::new(l.clone(), global_ctx.clone()).boxed() + } #[cfg(feature = "websocket")] IpScheme::Ws | IpScheme::Wss => { tunnel::websocket::WsTunnelListener::new(l.clone()).boxed() diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index f6814f6e..92766a29 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -344,10 +344,11 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option { None } -pub(crate) fn setup_sokcet2_ext( +pub(crate) fn setup_socket2_ext( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, #[allow(unused_variables)] bind_dev: Option, + only_v6: bool, ) -> Result<(), TunnelError> { #[cfg(target_os = "windows")] { @@ -356,7 +357,7 @@ pub(crate) fn setup_sokcet2_ext( } if bind_addr.is_ipv6() { - socket2_socket.set_only_v6(true)?; + socket2_socket.set_only_v6(only_v6)?; } socket2_socket.set_nonblocking(true)?; @@ -428,14 +429,16 @@ where Err(last_err.unwrap_or(TunnelError::Shutdown)) } -pub(crate) fn setup_sokcet2( +pub(crate) fn setup_socket2( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, + only_v6: bool, ) -> Result<(), TunnelError> { - setup_sokcet2_ext( + setup_socket2_ext( socket2_socket, bind_addr, super::common::get_interface_name_by_ip(&bind_addr.ip()), + only_v6, ) } diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index ca169327..36feb753 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -46,7 +46,7 @@ pub mod unix; #[derive(thiserror::Error, Debug)] pub enum TunnelError { - #[error("io error")] + #[error("io error: {0}")] IOError(#[from] std::io::Error), #[error("invalid packet. msg: {0}")] InvalidPacket(String), diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index 96ca5a6f..5263ce3c 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -2,22 +2,25 @@ //! //! Checkout the `README.md` for guidance. -use std::{ - error::Error, io::IoSliceMut, net::SocketAddr, pin::Pin, sync::Arc, task::Poll, time::Duration, -}; - +use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener}; +use crate::common::global_ctx::ArcGlobalCtx; use crate::tunnel::{ - FromUrl, TunnelInfo, - common::{FramedReader, FramedWriter, TunnelWrapper, setup_sokcet2}, + TunnelInfo, + common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2}, }; use anyhow::Context; - -use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener}; +use derivative::Derivative; +use derive_more::{Deref, DerefMut}; +use parking_lot::RwLock; use quinn::{ - AsyncUdpSocket, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, - TransportConfig, UdpPoller, congestion::BbrConfig, udp::RecvMeta, + ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig, + congestion::BbrConfig, default_runtime, }; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::OnceLock; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +// region config pub fn transport_config() -> Arc { let mut config = TransportConfig::default(); @@ -50,86 +53,287 @@ pub fn endpoint_config() -> EndpointConfig { config.max_udp_payload_size(65527).unwrap(); config } +//endregion -#[derive(Clone, Debug)] -struct NoGroAsyncUdpSocket { - inner: Arc, +//region rw pool +#[derive(Derivative)] +#[derivative(Default(bound = ""))] +#[derive(Debug, Deref, DerefMut)] +struct RwPoolInner { + #[deref] + #[deref_mut] + pool: Vec, + enabled: bool, } -impl AsyncUdpSocket for NoGroAsyncUdpSocket { - fn create_io_poller(self: Arc) -> Pin> { - self.inner.clone().create_io_poller() +#[derive(Debug)] +struct RwPool { + ephemeral: RwLock>, + persistent: RwLock>, + capacity: usize, +} + +impl RwPool { + fn new(capacity: usize) -> Self { + Self { + ephemeral: RwLock::new(RwPoolInner::default()), + persistent: RwLock::new(RwPoolInner::default()), + capacity, + } } - fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> { - self.inner.try_send(transmit) - } - - /// Receive UDP datagrams, or register to be woken if receiving may succeed in the future - fn poll_recv( + /// return the capacity of the ephemeral pool; + /// if `ephemeral` or `persistent` is None, read lock `self`'s pool + fn capacity( &self, - cx: &mut std::task::Context, - bufs: &mut [IoSliceMut<'_>], - meta: &mut [RecvMeta], - ) -> Poll> { - self.inner.poll_recv(cx, bufs, meta) + ephemeral: Option<&RwPoolInner>, + persistent: Option<&RwPoolInner>, + ) -> usize { + let guard; + let ephemeral = if let Some(ephemeral) = ephemeral { + ephemeral + } else { + guard = self.ephemeral.read(); + &guard + }; + + let guard; + let persistent = if let Some(persistent) = persistent { + persistent + } else { + guard = self.persistent.read(); + &guard + }; + + (self.capacity * ephemeral.enabled as usize).saturating_sub(persistent.len()) } - /// Look up the local IP address and port used by this socket - fn local_addr(&self) -> std::io::Result { - self.inner.local_addr() + fn is_full(&self) -> bool { + let pool = self.ephemeral.read(); + pool.len() >= self.capacity(Some(&pool), None) } - fn may_fragment(&self) -> bool { - self.inner.may_fragment() + fn is_enabled(&self) -> bool { + self.ephemeral.read().enabled } - fn max_transmit_segments(&self) -> usize { - self.inner.max_transmit_segments() + fn enable(&self) { + self.ephemeral.write().enabled = true; + self.resize(); } - fn max_receive_segments(&self) -> usize { - 1 + fn disable(&self) { + self.ephemeral.write().enabled = false; + self.resize(); + } + + /// push an item to the persistent pool + fn push(&self, item: Item) { + self.persistent.write().push(item); + self.resize(); + } + + /// try to push an item to the ephemeral pool, return the item if full + fn try_push(&self, item: Item) -> Option { + let mut pool = self.ephemeral.write(); + if pool.len() < self.capacity(Some(&pool), None) { + pool.push(item); + return None; + } + Some(item) + } + + fn resize(&self) { + let resize = { + let pool = self.ephemeral.read(); + pool.capacity() != self.capacity(Some(&pool), None) + }; + if resize { + let mut pool = self.ephemeral.write(); + let capacity = self.capacity(Some(&pool), None); + pool.reserve_exact(capacity); + pool.truncate(capacity); + pool.shrink_to(capacity); + } + } + + fn with_iter(&self, f: F) -> R + where + F: FnOnce(&mut dyn Iterator) -> R, + { + let ephemeral = self.ephemeral.read(); + let persistent = self.persistent.read(); + f(&mut persistent.iter().chain(ephemeral.iter())) + } +} +//endregion + +//region endpoint manager +#[derive(Debug)] +pub struct QuicEndpointManager { + ipv4: RwPool, + ipv6: RwPool, + both: RwPool, +} + +static QUIC_ENDPOINT_MANAGER: OnceLock = OnceLock::new(); + +impl QuicEndpointManager { + fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result { + let socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack) + .map_err(std::io::Error::other)?; + let socket = std::net::UdpSocket::from(socket); + let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?; + let mut endpoint = Endpoint::new_with_abstract_socket( + endpoint_config(), + None, + runtime.wrap_udp_socket(socket)?, + runtime, + )?; + endpoint.set_default_client_config(client_config()); + Ok(endpoint) + } + + fn create(&self, mut selector: F) -> std::io::Result<(&RwPool, Option)> + where + F: FnMut(&QuicEndpointManager) -> (&RwPool, Option<(SocketAddr, bool)>), + { + loop { + let (pool, r) = selector(self); + let Some((addr, dual_stack)) = r else { + return Ok((pool, None)); + }; + + let endpoint = Self::try_create(addr, dual_stack); + if let Err(e) = endpoint.as_ref() + && dual_stack + { + tracing::warn!("create dual stack quic endpoint failed: {:?}", e); + self.both.disable(); + self.ipv4.enable(); + self.ipv6.enable(); + continue; + } + + return Ok((pool, Some(endpoint?))); + } } } -/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address -/// and port. -/// -/// ## Returns -/// -/// - an [`Endpoint`] configured to accept incoming QUIC connections -#[allow(unused)] -pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result> { - let server_config = server_config(); - let client_config = client_config(); - let endpoint_config = endpoint_config(); +impl QuicEndpointManager { + fn new(capacity: usize) -> Self { + let ipv4 = RwPool::new(capacity.div_ceil(2)); + let ipv6 = RwPool::new(capacity.div_ceil(2)); + let both = RwPool::new(capacity); + both.enable(); + Self { ipv4, ipv6, both } + } - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(bind_addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - setup_sokcet2(&socket2_socket, &bind_addr)?; - let socket = std::net::UdpSocket::from(socket2_socket); + fn load(global_ctx: &ArcGlobalCtx) -> &Self { + let capacity = global_ctx + .config + .get_flags() + .multi_thread + .then(std::thread::available_parallelism) + .and_then(|r| r.ok()) + .map(|n| n.get()) + .unwrap_or(1); - let runtime = - quinn::default_runtime().ok_or_else(|| std::io::Error::other("no async runtime found"))?; - let socket: NoGroAsyncUdpSocket = NoGroAsyncUdpSocket { - inner: runtime.wrap_udp_socket(socket)?, - }; - let mut endpoint = Endpoint::new_with_abstract_socket( - endpoint_config, - Some(server_config), - Arc::new(socket), - runtime, - )?; - endpoint.set_default_client_config(client_config); - Ok(endpoint) + let mgr = QUIC_ENDPOINT_MANAGER.get(); + match mgr { + Some(mgr) => { + for pool in [&mgr.ipv4, &mgr.ipv6, &mgr.both] { + pool.resize(); + } + } + None => { + let _ = QUIC_ENDPOINT_MANAGER.set(Self::new(capacity)); + } + } + + QUIC_ENDPOINT_MANAGER.get().unwrap() + } + + /// Get a QUIC endpoint to be used as a server + /// + /// # Arguments + /// * `addr`: listen address + fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result { + let mgr = Self::load(global_ctx); + + let (pool, endpoint) = mgr.create(|mgr| { + let dual_stack = addr.ip() == Ipv6Addr::UNSPECIFIED && mgr.both.is_enabled(); + let pool = if addr.is_ipv4() { + &mgr.ipv4 + } else if dual_stack { + &mgr.both + } else { + &mgr.ipv6 + }; + (pool, Some((addr, dual_stack))) + })?; + + let endpoint = endpoint.expect("server endpoint creation should not return None"); + endpoint.set_server_config(Some(server_config())); + pool.push(endpoint.clone()); + + Ok(endpoint) + } + + /// Get a quic endpoint to be used as a client + /// + /// # Arguments + /// * `ip_version`: the IP version of the remote address + fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result { + let mgr = Self::load(global_ctx); + + let (pool, endpoint) = mgr.create(|mgr| { + let dual_stack = mgr.both.is_enabled(); + let (pool, addr) = match ip_version { + IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()), + _ => { + let pool = if dual_stack { &mgr.both } else { &mgr.ipv6 }; + (pool, (Ipv6Addr::UNSPECIFIED, 0).into()) + } + }; + if pool.is_full() { + (pool, None) + } else { + (pool, Some((addr, dual_stack))) + } + })?; + + if let Some(endpoint) = endpoint { + pool.try_push(endpoint); + } + + Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone())) + } + + async fn connect( + global_ctx: &ArcGlobalCtx, + addr: SocketAddr, + ) -> std::io::Result<(Endpoint, Connection)> { + let ip_version = if addr.ip().is_ipv4() { + IpVersion::V4 + } else { + IpVersion::V6 + }; + let endpoint = Self::client(global_ctx, ip_version)?; + let connection = endpoint + .connect(addr, "localhost") + .map_err(std::io::Error::other)? + .await?; + + Ok((endpoint, connection)) + } } - -#[allow(unused)] -pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; +//endregion struct ConnWrapper { conn: Connection, @@ -143,13 +347,15 @@ impl Drop for ConnWrapper { pub struct QuicTunnelListener { addr: url::Url, + global_ctx: ArcGlobalCtx, endpoint: Option, } impl QuicTunnelListener { - pub fn new(addr: url::Url) -> Self { + pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self { QuicTunnelListener { addr, + global_ctx, endpoint: None, } } @@ -192,13 +398,11 @@ impl QuicTunnelListener { impl TunnelListener for QuicTunnelListener { async fn listen(&mut self) -> Result<(), TunnelError> { let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; - let endpoint = make_server_endpoint(addr) - .map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?; - self.endpoint = Some(endpoint); - + let endpoint = QuicEndpointManager::server(&self.global_ctx, addr)?; self.addr - .set_port(Some(self.endpoint.as_ref().unwrap().local_addr()?.port())) + .set_port(Some(endpoint.local_addr()?.port())) .unwrap(); + self.endpoint = Some(endpoint); Ok(()) } @@ -222,15 +426,15 @@ impl TunnelListener for QuicTunnelListener { pub struct QuicTunnelConnector { addr: url::Url, - endpoint: Option, + global_ctx: ArcGlobalCtx, ip_version: IpVersion, } impl QuicTunnelConnector { - pub fn new(addr: url::Url) -> Self { + pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self { QuicTunnelConnector { addr, - endpoint: None, + global_ctx, ip_version: IpVersion::Both, } } @@ -240,38 +444,10 @@ impl QuicTunnelConnector { impl TunnelConnector for QuicTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; - if addr.port() == 0 { - return Err(TunnelError::InvalidAddr(format!( - "invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)", - self.addr - ))); - } - let local_addr = if addr.is_ipv4() { - "0.0.0.0:0" - } else { - "[::]:0" - }; - - let mut endpoint = Endpoint::client(local_addr.parse().unwrap())?; - endpoint.set_default_client_config(client_config()); - - // connect to server - let connection = endpoint - .connect(addr, "localhost") - .map_err(|e| { - TunnelError::InvalidAddr(format!( - "failed to create QUIC connection, url: {}, error: {}", - self.addr, e - )) - })? - .await - .with_context(|| "connect failed")?; - tracing::info!("[client] connected: addr={}", connection.remote_address()); + let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?; let local_addr = endpoint.local_addr()?; - self.endpoint = Some(endpoint); - let (w, r) = connection .open_bi() .await @@ -308,68 +484,112 @@ impl TunnelConnector for QuicTunnelConnector { #[cfg(test)] mod tests { + use crate::common::global_ctx::tests::get_mock_global_ctx_with_network; use crate::tunnel::{ - IpVersion, TunnelConnector, + TunnelConnector, common::tests::{_tunnel_bench, _tunnel_pingpong}, }; + use std::sync::LazyLock; + use tokio::runtime::{Builder, Runtime}; use super::*; - #[tokio::test] - async fn quic_pingpong() { - let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap()); - let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap()); + // Shared runtime for all tests to avoid endpoint invalidation across runtimes + static RUNTIME: LazyLock = + LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap()); + + fn global_ctx() -> ArcGlobalCtx { + let identity = crate::common::config::NetworkIdentity::default(); + get_mock_global_ctx_with_network(Some(identity)) + } + + #[test] + fn quic_pingpong() { + RUNTIME.block_on(quic_pingpong_impl()) + } + async fn quic_pingpong_impl() { + let listener = QuicTunnelListener::new("quic://[::]:21011".parse().unwrap(), global_ctx()); + let connector = + QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap(), global_ctx()); _tunnel_pingpong(listener, connector).await } - #[tokio::test] - async fn quic_bench() { - let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap()); - let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap()); + #[test] + fn quic_bench() { + RUNTIME.block_on(quic_bench_impl()) + } + async fn quic_bench_impl() { + let listener = QuicTunnelListener::new("quic://[::]:21012".parse().unwrap(), global_ctx()); + let connector = + QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap(), global_ctx()); _tunnel_bench(listener, connector).await } - #[tokio::test] - async fn ipv6_pingpong() { - let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap()); - let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap()); + #[test] + fn ipv6_pingpong() { + RUNTIME.block_on(ipv6_pingpong_impl()) + } + async fn ipv6_pingpong_impl() { + let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap(), global_ctx()); + let connector = + QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap(), global_ctx()); _tunnel_pingpong(listener, connector).await } - #[tokio::test] - async fn ipv6_domain_pingpong() { - let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap()); - let mut connector = - QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); + #[test] + fn ipv6_domain_pingpong() { + RUNTIME.block_on(ipv6_domain_pingpong_impl()) + } + async fn ipv6_domain_pingpong_impl() { + let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap(), global_ctx()); + let mut connector = QuicTunnelConnector::new( + "quic://test.easytier.top:31016".parse().unwrap(), + global_ctx(), + ); connector.set_ip_version(IpVersion::V6); _tunnel_pingpong(listener, connector).await; - let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap()); - let mut connector = - QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); + let listener = + QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap(), global_ctx()); + let mut connector = QuicTunnelConnector::new( + "quic://test.easytier.top:31016".parse().unwrap(), + global_ctx(), + ); connector.set_ip_version(IpVersion::V4); _tunnel_pingpong(listener, connector).await; } - #[tokio::test] - async fn test_alloc_port() { + #[test] + fn alloc_port() { + RUNTIME.block_on(alloc_port_impl()) + } + async fn alloc_port_impl() { // v4 - let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap()); + let mut listener = + QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap(), global_ctx()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); assert!(port > 0); // v6 - let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap()); + let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap(), global_ctx()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); assert!(port > 0); } - #[tokio::test] - async fn quic_connector_reject_port_zero() { - let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap()); - let err = connector.connect().await.unwrap_err().to_string(); - assert!(err.contains("port 0"), "unexpected error: {}", err); + #[test] + fn invalid_peer_addr() { + RUNTIME.block_on(invalid_peer_addr_impl()) + } + async fn invalid_peer_addr_impl() { + let mut connector = + QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx()); + let err = connector.connect().await.unwrap_err(); + assert!( + err.to_string().contains("invalid remote address"), + "unexpected error: {:?}", + err + ); } } diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs index abcb73f6..2061c2fe 100644 --- a/easytier/src/tunnel/tcp.rs +++ b/easytier/src/tunnel/tcp.rs @@ -1,7 +1,7 @@ use std::net::SocketAddr; use super::{FromUrl, TunnelInfo}; -use crate::tunnel::common::setup_sokcet2; +use crate::tunnel::common::setup_socket2; use async_trait::async_trait; use futures::stream::FuturesUnordered; use tokio::net::{TcpListener, TcpSocket, TcpStream}; @@ -66,7 +66,7 @@ impl TunnelListener for TcpTunnelListener { socket2::Type::STREAM, Some(socket2::Protocol::TCP), )?; - setup_sokcet2(&socket2_socket, &addr)?; + setup_socket2(&socket2_socket, &addr, true)?; let socket = TcpSocket::from_std_stream(socket2_socket.into()); if let Err(e) = socket.set_nodelay(true) { @@ -175,7 +175,7 @@ impl TcpTunnelConnector { Some(socket2::Protocol::TCP), )?; - if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { + if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) { tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); continue; } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 6432b64f..7119d143 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -24,7 +24,7 @@ use tracing::{Instrument, instrument}; use super::{ FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, - common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, + common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures}, packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket}, ring::{RingSink, RingStream}, }; @@ -545,9 +545,9 @@ impl TunnelListener for UdpTunnelListener { let tunnel_url: TunnelUrl = self.addr.clone().into(); if let Some(bind_dev) = tunnel_url.bind_dev() { - setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?; } else { - setup_sokcet2(&socket2_socket, &addr)?; + setup_socket2(&socket2_socket, &addr, true)?; } self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); @@ -838,7 +838,7 @@ impl UdpTunnelConnector { socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; - if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { + if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) { tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); continue; } @@ -1040,7 +1040,7 @@ mod tests { Some(socket2::Protocol::UDP), ) .unwrap(); - setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap(); + setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap(); } } diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs index eed5f111..2772dad7 100644 --- a/easytier/src/tunnel/websocket.rs +++ b/easytier/src/tunnel/websocket.rs @@ -1,6 +1,6 @@ use super::{ FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, - common::{TunnelWrapper, setup_sokcet2, wait_for_connect_futures}, + common::{TunnelWrapper, setup_socket2, wait_for_connect_futures}, insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, packet_def::{ZCPacket, ZCPacketType}, }; @@ -166,7 +166,7 @@ impl TunnelListener for WsTunnelListener { socket2::Type::STREAM, Some(socket2::Protocol::TCP), )?; - setup_sokcet2(&socket2_socket, &addr)?; + setup_socket2(&socket2_socket, &addr, true)?; let socket = TcpSocket::from_std_stream(socket2_socket.into()); self.addr @@ -291,7 +291,7 @@ impl WsTunnelConnector { Some(socket2::Protocol::TCP), )?; - if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { + if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) { tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); continue; } diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index d1968c60..844f384b 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -23,7 +23,7 @@ use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use super::{ FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream, - common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, + common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures}, generate_digest_from_str, packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType}, ring::create_ring_tunnel_pair, @@ -563,9 +563,9 @@ impl TunnelListener for WgTunnelListener { let tunnel_url: TunnelUrl = self.addr.clone().into(); if let Some(bind_dev) = tunnel_url.bind_dev() { - setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?; } else { - setup_sokcet2(&socket2_socket, &addr)?; + setup_socket2(&socket2_socket, &addr, true)?; } self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); @@ -700,7 +700,7 @@ impl WgTunnelConnector { socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; - setup_sokcet2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None)?; + setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?; let socket = UdpSocket::from_std(socket2_socket.into())?; Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await } @@ -728,7 +728,7 @@ impl super::TunnelConnector for WgTunnelConnector { socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; - if let Err(e) = setup_sokcet2(&socket2_socket, &bind_addr) { + if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) { tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); continue; }