mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-04-22 16:17:23 +08:00
refactor: remove NoGroAsyncUdpSocket (#1867)
This commit is contained in:
@@ -105,7 +105,9 @@ pub async fn create_connector_by_url(
|
||||
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
|
||||
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
|
||||
#[cfg(feature = "quic")]
|
||||
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(),
|
||||
IpScheme::Quic => {
|
||||
tunnel::quic::QuicTunnelConnector::new(url, global_ctx.clone()).boxed()
|
||||
}
|
||||
#[cfg(feature = "wireguard")]
|
||||
IpScheme::Wg => {
|
||||
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
|
||||
|
||||
@@ -26,7 +26,9 @@ use derivative::Derivative;
|
||||
use derive_more::{Constructor, Deref, DerefMut, From, Into};
|
||||
use prost::Message;
|
||||
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
|
||||
use quinn::{AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, TokioRuntime, UdpPoller};
|
||||
use quinn::{
|
||||
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
|
||||
};
|
||||
use std::cmp::min;
|
||||
use std::future::Future;
|
||||
use std::io::IoSliceMut;
|
||||
@@ -806,7 +808,7 @@ impl QuicProxy {
|
||||
endpoint_config(),
|
||||
Some(server_config()),
|
||||
Arc::new(socket),
|
||||
Arc::new(TokioRuntime),
|
||||
default_runtime().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
endpoint.set_default_client_config(client_config());
|
||||
@@ -1020,7 +1022,7 @@ mod tests {
|
||||
endpoint_config.clone(),
|
||||
Some(server_config.clone()),
|
||||
socket_client.clone(),
|
||||
Arc::new(TokioRuntime),
|
||||
default_runtime().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
client_endpoint.set_default_client_config(client_config.clone());
|
||||
@@ -1030,7 +1032,7 @@ mod tests {
|
||||
endpoint_config.clone(),
|
||||
Some(server_config.clone()),
|
||||
socket_server.clone(),
|
||||
Arc::new(TokioRuntime),
|
||||
default_runtime().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
server_endpoint.set_default_client_config(client_config.clone());
|
||||
|
||||
@@ -31,7 +31,7 @@ use crate::{
|
||||
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
|
||||
},
|
||||
tunnel::{
|
||||
common::setup_sokcet2,
|
||||
common::setup_socket2,
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
},
|
||||
};
|
||||
@@ -336,7 +336,7 @@ fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
@@ -355,7 +355,7 @@ fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error>
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
|
||||
Ok(UdpSocket::from_std(socket2_socket.into())?)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::{
|
||||
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
|
||||
peers::{PeerPacketFilter, peer_manager::PeerManager},
|
||||
tunnel::{
|
||||
common::{reserve_buf, setup_sokcet2},
|
||||
common::{reserve_buf, setup_socket2},
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
},
|
||||
};
|
||||
@@ -72,7 +72,7 @@ impl UdpNatEntry {
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
|
||||
setup_sokcet2(&socket2_socket, &dst_socket_addr)?;
|
||||
setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
|
||||
Some(UdpSocket::from_std(socket2_socket.into())?)
|
||||
};
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::{
|
||||
|
||||
pub fn create_listener_by_url(
|
||||
l: &url::Url,
|
||||
#[allow(unused_variables)] ctx: ArcGlobalCtx,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) -> Result<Box<dyn TunnelListener>, Error> {
|
||||
Ok(match l.try_into()? {
|
||||
TunnelScheme::Ip(scheme) => match scheme {
|
||||
@@ -34,7 +34,7 @@ pub fn create_listener_by_url(
|
||||
#[cfg(feature = "wireguard")]
|
||||
IpScheme::Wg => {
|
||||
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
|
||||
let nid = ctx.get_network_identity();
|
||||
let nid = global_ctx.get_network_identity();
|
||||
let wg_config = WgConfig::new_from_network_identity(
|
||||
&nid.network_name,
|
||||
&nid.network_secret.unwrap_or_default(),
|
||||
@@ -42,7 +42,9 @@ pub fn create_listener_by_url(
|
||||
WgTunnelListener::new(l.clone(), wg_config).boxed()
|
||||
}
|
||||
#[cfg(feature = "quic")]
|
||||
IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(),
|
||||
IpScheme::Quic => {
|
||||
tunnel::quic::QuicTunnelListener::new(l.clone(), global_ctx.clone()).boxed()
|
||||
}
|
||||
#[cfg(feature = "websocket")]
|
||||
IpScheme::Ws | IpScheme::Wss => {
|
||||
tunnel::websocket::WsTunnelListener::new(l.clone()).boxed()
|
||||
|
||||
@@ -344,10 +344,11 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2_ext(
|
||||
pub(crate) fn setup_socket2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
#[allow(unused_variables)] bind_dev: Option<String>,
|
||||
only_v6: bool,
|
||||
) -> Result<(), TunnelError> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
@@ -356,7 +357,7 @@ pub(crate) fn setup_sokcet2_ext(
|
||||
}
|
||||
|
||||
if bind_addr.is_ipv6() {
|
||||
socket2_socket.set_only_v6(true)?;
|
||||
socket2_socket.set_only_v6(only_v6)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
@@ -428,14 +429,16 @@ where
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2(
|
||||
pub(crate) fn setup_socket2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
only_v6: bool,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_sokcet2_ext(
|
||||
setup_socket2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
only_v6,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ pub mod unix;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("io error")]
|
||||
#[error("io error: {0}")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("invalid packet. msg: {0}")]
|
||||
InvalidPacket(String),
|
||||
|
||||
+358
-138
@@ -2,22 +2,25 @@
|
||||
//!
|
||||
//! Checkout the `README.md` for guidance.
|
||||
|
||||
use std::{
|
||||
error::Error, io::IoSliceMut, net::SocketAddr, pin::Pin, sync::Arc, task::Poll, time::Duration,
|
||||
};
|
||||
|
||||
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
use crate::tunnel::{
|
||||
FromUrl, TunnelInfo,
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper, setup_sokcet2},
|
||||
TunnelInfo,
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
|
||||
};
|
||||
use anyhow::Context;
|
||||
|
||||
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
|
||||
use derivative::Derivative;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use parking_lot::RwLock;
|
||||
use quinn::{
|
||||
AsyncUdpSocket, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig,
|
||||
TransportConfig, UdpPoller, congestion::BbrConfig, udp::RecvMeta,
|
||||
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
|
||||
congestion::BbrConfig, default_runtime,
|
||||
};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::OnceLock;
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
// region config
|
||||
pub fn transport_config() -> Arc<TransportConfig> {
|
||||
let mut config = TransportConfig::default();
|
||||
|
||||
@@ -50,86 +53,287 @@ pub fn endpoint_config() -> EndpointConfig {
|
||||
config.max_udp_payload_size(65527).unwrap();
|
||||
config
|
||||
}
|
||||
//endregion
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct NoGroAsyncUdpSocket {
|
||||
inner: Arc<dyn AsyncUdpSocket>,
|
||||
//region rw pool
|
||||
#[derive(Derivative)]
|
||||
#[derivative(Default(bound = ""))]
|
||||
#[derive(Debug, Deref, DerefMut)]
|
||||
struct RwPoolInner<Item> {
|
||||
#[deref]
|
||||
#[deref_mut]
|
||||
pool: Vec<Item>,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl AsyncUdpSocket for NoGroAsyncUdpSocket {
|
||||
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
|
||||
self.inner.clone().create_io_poller()
|
||||
#[derive(Debug)]
|
||||
struct RwPool<Item> {
|
||||
ephemeral: RwLock<RwPoolInner<Item>>,
|
||||
persistent: RwLock<RwPoolInner<Item>>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl<Item> RwPool<Item> {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
ephemeral: RwLock::new(RwPoolInner::default()),
|
||||
persistent: RwLock::new(RwPoolInner::default()),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> {
|
||||
self.inner.try_send(transmit)
|
||||
}
|
||||
|
||||
/// Receive UDP datagrams, or register to be woken if receiving may succeed in the future
|
||||
fn poll_recv(
|
||||
/// return the capacity of the ephemeral pool;
|
||||
/// if `ephemeral` or `persistent` is None, read lock `self`'s pool
|
||||
fn capacity(
|
||||
&self,
|
||||
cx: &mut std::task::Context,
|
||||
bufs: &mut [IoSliceMut<'_>],
|
||||
meta: &mut [RecvMeta],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.inner.poll_recv(cx, bufs, meta)
|
||||
ephemeral: Option<&RwPoolInner<Item>>,
|
||||
persistent: Option<&RwPoolInner<Item>>,
|
||||
) -> usize {
|
||||
let guard;
|
||||
let ephemeral = if let Some(ephemeral) = ephemeral {
|
||||
ephemeral
|
||||
} else {
|
||||
guard = self.ephemeral.read();
|
||||
&guard
|
||||
};
|
||||
|
||||
let guard;
|
||||
let persistent = if let Some(persistent) = persistent {
|
||||
persistent
|
||||
} else {
|
||||
guard = self.persistent.read();
|
||||
&guard
|
||||
};
|
||||
|
||||
(self.capacity * ephemeral.enabled as usize).saturating_sub(persistent.len())
|
||||
}
|
||||
|
||||
/// Look up the local IP address and port used by this socket
|
||||
fn local_addr(&self) -> std::io::Result<SocketAddr> {
|
||||
self.inner.local_addr()
|
||||
fn is_full(&self) -> bool {
|
||||
let pool = self.ephemeral.read();
|
||||
pool.len() >= self.capacity(Some(&pool), None)
|
||||
}
|
||||
|
||||
fn may_fragment(&self) -> bool {
|
||||
self.inner.may_fragment()
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.ephemeral.read().enabled
|
||||
}
|
||||
|
||||
fn max_transmit_segments(&self) -> usize {
|
||||
self.inner.max_transmit_segments()
|
||||
fn enable(&self) {
|
||||
self.ephemeral.write().enabled = true;
|
||||
self.resize();
|
||||
}
|
||||
|
||||
fn max_receive_segments(&self) -> usize {
|
||||
1
|
||||
fn disable(&self) {
|
||||
self.ephemeral.write().enabled = false;
|
||||
self.resize();
|
||||
}
|
||||
|
||||
/// push an item to the persistent pool
|
||||
fn push(&self, item: Item) {
|
||||
self.persistent.write().push(item);
|
||||
self.resize();
|
||||
}
|
||||
|
||||
/// try to push an item to the ephemeral pool, return the item if full
|
||||
fn try_push(&self, item: Item) -> Option<Item> {
|
||||
let mut pool = self.ephemeral.write();
|
||||
if pool.len() < self.capacity(Some(&pool), None) {
|
||||
pool.push(item);
|
||||
return None;
|
||||
}
|
||||
Some(item)
|
||||
}
|
||||
|
||||
fn resize(&self) {
|
||||
let resize = {
|
||||
let pool = self.ephemeral.read();
|
||||
pool.capacity() != self.capacity(Some(&pool), None)
|
||||
};
|
||||
if resize {
|
||||
let mut pool = self.ephemeral.write();
|
||||
let capacity = self.capacity(Some(&pool), None);
|
||||
pool.reserve_exact(capacity);
|
||||
pool.truncate(capacity);
|
||||
pool.shrink_to(capacity);
|
||||
}
|
||||
}
|
||||
|
||||
fn with_iter<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut dyn Iterator<Item = &Item>) -> R,
|
||||
{
|
||||
let ephemeral = self.ephemeral.read();
|
||||
let persistent = self.persistent.read();
|
||||
f(&mut persistent.iter().chain(ephemeral.iter()))
|
||||
}
|
||||
}
|
||||
//endregion
|
||||
|
||||
//region endpoint manager
|
||||
#[derive(Debug)]
|
||||
pub struct QuicEndpointManager {
|
||||
ipv4: RwPool<Endpoint>,
|
||||
ipv6: RwPool<Endpoint>,
|
||||
both: RwPool<Endpoint>,
|
||||
}
|
||||
|
||||
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
|
||||
|
||||
impl QuicEndpointManager {
|
||||
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
|
||||
.map_err(std::io::Error::other)?;
|
||||
let socket = std::net::UdpSocket::from(socket);
|
||||
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
|
||||
let mut endpoint = Endpoint::new_with_abstract_socket(
|
||||
endpoint_config(),
|
||||
None,
|
||||
runtime.wrap_udp_socket(socket)?,
|
||||
runtime,
|
||||
)?;
|
||||
endpoint.set_default_client_config(client_config());
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
|
||||
where
|
||||
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
|
||||
{
|
||||
loop {
|
||||
let (pool, r) = selector(self);
|
||||
let Some((addr, dual_stack)) = r else {
|
||||
return Ok((pool, None));
|
||||
};
|
||||
|
||||
let endpoint = Self::try_create(addr, dual_stack);
|
||||
if let Err(e) = endpoint.as_ref()
|
||||
&& dual_stack
|
||||
{
|
||||
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
|
||||
self.both.disable();
|
||||
self.ipv4.enable();
|
||||
self.ipv6.enable();
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok((pool, Some(endpoint?)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
|
||||
/// and port.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// - an [`Endpoint`] configured to accept incoming QUIC connections
|
||||
#[allow(unused)]
|
||||
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<Endpoint, Box<dyn Error>> {
|
||||
let server_config = server_config();
|
||||
let client_config = client_config();
|
||||
let endpoint_config = endpoint_config();
|
||||
impl QuicEndpointManager {
|
||||
fn new(capacity: usize) -> Self {
|
||||
let ipv4 = RwPool::new(capacity.div_ceil(2));
|
||||
let ipv6 = RwPool::new(capacity.div_ceil(2));
|
||||
let both = RwPool::new(capacity);
|
||||
both.enable();
|
||||
Self { ipv4, ipv6, both }
|
||||
}
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = std::net::UdpSocket::from(socket2_socket);
|
||||
fn load(global_ctx: &ArcGlobalCtx) -> &Self {
|
||||
let capacity = global_ctx
|
||||
.config
|
||||
.get_flags()
|
||||
.multi_thread
|
||||
.then(std::thread::available_parallelism)
|
||||
.and_then(|r| r.ok())
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
|
||||
let runtime =
|
||||
quinn::default_runtime().ok_or_else(|| std::io::Error::other("no async runtime found"))?;
|
||||
let socket: NoGroAsyncUdpSocket = NoGroAsyncUdpSocket {
|
||||
inner: runtime.wrap_udp_socket(socket)?,
|
||||
};
|
||||
let mut endpoint = Endpoint::new_with_abstract_socket(
|
||||
endpoint_config,
|
||||
Some(server_config),
|
||||
Arc::new(socket),
|
||||
runtime,
|
||||
)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
Ok(endpoint)
|
||||
let mgr = QUIC_ENDPOINT_MANAGER.get();
|
||||
match mgr {
|
||||
Some(mgr) => {
|
||||
for pool in [&mgr.ipv4, &mgr.ipv6, &mgr.both] {
|
||||
pool.resize();
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let _ = QUIC_ENDPOINT_MANAGER.set(Self::new(capacity));
|
||||
}
|
||||
}
|
||||
|
||||
QUIC_ENDPOINT_MANAGER.get().unwrap()
|
||||
}
|
||||
|
||||
/// Get a QUIC endpoint to be used as a server
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `addr`: listen address
|
||||
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
let dual_stack = addr.ip() == Ipv6Addr::UNSPECIFIED && mgr.both.is_enabled();
|
||||
let pool = if addr.is_ipv4() {
|
||||
&mgr.ipv4
|
||||
} else if dual_stack {
|
||||
&mgr.both
|
||||
} else {
|
||||
&mgr.ipv6
|
||||
};
|
||||
(pool, Some((addr, dual_stack)))
|
||||
})?;
|
||||
|
||||
let endpoint = endpoint.expect("server endpoint creation should not return None");
|
||||
endpoint.set_server_config(Some(server_config()));
|
||||
pool.push(endpoint.clone());
|
||||
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
/// Get a quic endpoint to be used as a client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `ip_version`: the IP version of the remote address
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
let dual_stack = mgr.both.is_enabled();
|
||||
let (pool, addr) = match ip_version {
|
||||
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
|
||||
_ => {
|
||||
let pool = if dual_stack { &mgr.both } else { &mgr.ipv6 };
|
||||
(pool, (Ipv6Addr::UNSPECIFIED, 0).into())
|
||||
}
|
||||
};
|
||||
if pool.is_full() {
|
||||
(pool, None)
|
||||
} else {
|
||||
(pool, Some((addr, dual_stack)))
|
||||
}
|
||||
})?;
|
||||
|
||||
if let Some(endpoint) = endpoint {
|
||||
pool.try_push(endpoint);
|
||||
}
|
||||
|
||||
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
addr: SocketAddr,
|
||||
) -> std::io::Result<(Endpoint, Connection)> {
|
||||
let ip_version = if addr.ip().is_ipv4() {
|
||||
IpVersion::V4
|
||||
} else {
|
||||
IpVersion::V6
|
||||
};
|
||||
let endpoint = Self::client(global_ctx, ip_version)?;
|
||||
let connection = endpoint
|
||||
.connect(addr, "localhost")
|
||||
.map_err(std::io::Error::other)?
|
||||
.await?;
|
||||
|
||||
Ok((endpoint, connection))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
|
||||
//endregion
|
||||
|
||||
struct ConnWrapper {
|
||||
conn: Connection,
|
||||
@@ -143,13 +347,15 @@ impl Drop for ConnWrapper {
|
||||
|
||||
pub struct QuicTunnelListener {
|
||||
addr: url::Url,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
endpoint: Option<Endpoint>,
|
||||
}
|
||||
|
||||
impl QuicTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
|
||||
QuicTunnelListener {
|
||||
addr,
|
||||
global_ctx,
|
||||
endpoint: None,
|
||||
}
|
||||
}
|
||||
@@ -192,13 +398,11 @@ impl QuicTunnelListener {
|
||||
impl TunnelListener for QuicTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||
let endpoint = make_server_endpoint(addr)
|
||||
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
|
||||
self.endpoint = Some(endpoint);
|
||||
|
||||
let endpoint = QuicEndpointManager::server(&self.global_ctx, addr)?;
|
||||
self.addr
|
||||
.set_port(Some(self.endpoint.as_ref().unwrap().local_addr()?.port()))
|
||||
.set_port(Some(endpoint.local_addr()?.port()))
|
||||
.unwrap();
|
||||
self.endpoint = Some(endpoint);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -222,15 +426,15 @@ impl TunnelListener for QuicTunnelListener {
|
||||
|
||||
pub struct QuicTunnelConnector {
|
||||
addr: url::Url,
|
||||
endpoint: Option<Endpoint>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
ip_version: IpVersion,
|
||||
}
|
||||
|
||||
impl QuicTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
|
||||
QuicTunnelConnector {
|
||||
addr,
|
||||
endpoint: None,
|
||||
global_ctx,
|
||||
ip_version: IpVersion::Both,
|
||||
}
|
||||
}
|
||||
@@ -240,38 +444,10 @@ impl QuicTunnelConnector {
|
||||
impl TunnelConnector for QuicTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
|
||||
if addr.port() == 0 {
|
||||
return Err(TunnelError::InvalidAddr(format!(
|
||||
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
|
||||
self.addr
|
||||
)));
|
||||
}
|
||||
let local_addr = if addr.is_ipv4() {
|
||||
"0.0.0.0:0"
|
||||
} else {
|
||||
"[::]:0"
|
||||
};
|
||||
|
||||
let mut endpoint = Endpoint::client(local_addr.parse().unwrap())?;
|
||||
endpoint.set_default_client_config(client_config());
|
||||
|
||||
// connect to server
|
||||
let connection = endpoint
|
||||
.connect(addr, "localhost")
|
||||
.map_err(|e| {
|
||||
TunnelError::InvalidAddr(format!(
|
||||
"failed to create QUIC connection, url: {}, error: {}",
|
||||
self.addr, e
|
||||
))
|
||||
})?
|
||||
.await
|
||||
.with_context(|| "connect failed")?;
|
||||
tracing::info!("[client] connected: addr={}", connection.remote_address());
|
||||
let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?;
|
||||
|
||||
let local_addr = endpoint.local_addr()?;
|
||||
|
||||
self.endpoint = Some(endpoint);
|
||||
|
||||
let (w, r) = connection
|
||||
.open_bi()
|
||||
.await
|
||||
@@ -308,68 +484,112 @@ impl TunnelConnector for QuicTunnelConnector {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx_with_network;
|
||||
use crate::tunnel::{
|
||||
IpVersion, TunnelConnector,
|
||||
TunnelConnector,
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
};
|
||||
use std::sync::LazyLock;
|
||||
use tokio::runtime::{Builder, Runtime};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn quic_pingpong() {
|
||||
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
|
||||
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
|
||||
// Shared runtime for all tests to avoid endpoint invalidation across runtimes
|
||||
static RUNTIME: LazyLock<Runtime> =
|
||||
LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap());
|
||||
|
||||
fn global_ctx() -> ArcGlobalCtx {
|
||||
let identity = crate::common::config::NetworkIdentity::default();
|
||||
get_mock_global_ctx_with_network(Some(identity))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quic_pingpong() {
|
||||
RUNTIME.block_on(quic_pingpong_impl())
|
||||
}
|
||||
async fn quic_pingpong_impl() {
|
||||
let listener = QuicTunnelListener::new("quic://[::]:21011".parse().unwrap(), global_ctx());
|
||||
let connector =
|
||||
QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap(), global_ctx());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quic_bench() {
|
||||
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
|
||||
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
|
||||
#[test]
|
||||
fn quic_bench() {
|
||||
RUNTIME.block_on(quic_bench_impl())
|
||||
}
|
||||
async fn quic_bench_impl() {
|
||||
let listener = QuicTunnelListener::new("quic://[::]:21012".parse().unwrap(), global_ctx());
|
||||
let connector =
|
||||
QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap(), global_ctx());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_pingpong() {
|
||||
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
|
||||
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
|
||||
#[test]
|
||||
fn ipv6_pingpong() {
|
||||
RUNTIME.block_on(ipv6_pingpong_impl())
|
||||
}
|
||||
async fn ipv6_pingpong_impl() {
|
||||
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
|
||||
let connector =
|
||||
QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ipv6_domain_pingpong() {
|
||||
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
|
||||
let mut connector =
|
||||
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
|
||||
#[test]
|
||||
fn ipv6_domain_pingpong() {
|
||||
RUNTIME.block_on(ipv6_domain_pingpong_impl())
|
||||
}
|
||||
async fn ipv6_domain_pingpong_impl() {
|
||||
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap(), global_ctx());
|
||||
let mut connector = QuicTunnelConnector::new(
|
||||
"quic://test.easytier.top:31016".parse().unwrap(),
|
||||
global_ctx(),
|
||||
);
|
||||
connector.set_ip_version(IpVersion::V6);
|
||||
_tunnel_pingpong(listener, connector).await;
|
||||
|
||||
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
|
||||
let mut connector =
|
||||
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
|
||||
let listener =
|
||||
QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap(), global_ctx());
|
||||
let mut connector = QuicTunnelConnector::new(
|
||||
"quic://test.easytier.top:31016".parse().unwrap(),
|
||||
global_ctx(),
|
||||
);
|
||||
connector.set_ip_version(IpVersion::V4);
|
||||
_tunnel_pingpong(listener, connector).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alloc_port() {
|
||||
#[test]
|
||||
fn alloc_port() {
|
||||
RUNTIME.block_on(alloc_port_impl())
|
||||
}
|
||||
async fn alloc_port_impl() {
|
||||
// v4
|
||||
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
|
||||
let mut listener =
|
||||
QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap(), global_ctx());
|
||||
listener.listen().await.unwrap();
|
||||
let port = listener.local_url().port().unwrap();
|
||||
assert!(port > 0);
|
||||
|
||||
// v6
|
||||
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
|
||||
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap(), global_ctx());
|
||||
listener.listen().await.unwrap();
|
||||
let port = listener.local_url().port().unwrap();
|
||||
assert!(port > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quic_connector_reject_port_zero() {
|
||||
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
|
||||
let err = connector.connect().await.unwrap_err().to_string();
|
||||
assert!(err.contains("port 0"), "unexpected error: {}", err);
|
||||
#[test]
|
||||
fn invalid_peer_addr() {
|
||||
RUNTIME.block_on(invalid_peer_addr_impl())
|
||||
}
|
||||
async fn invalid_peer_addr_impl() {
|
||||
let mut connector =
|
||||
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
|
||||
let err = connector.connect().await.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("invalid remote address"),
|
||||
"unexpected error: {:?}",
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use super::{FromUrl, TunnelInfo};
|
||||
use crate::tunnel::common::setup_sokcet2;
|
||||
use crate::tunnel::common::setup_socket2;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
@@ -66,7 +66,7 @@ impl TunnelListener for TcpTunnelListener {
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
@@ -175,7 +175,7 @@ impl TcpTunnelConnector {
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ use tracing::{Instrument, instrument};
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
|
||||
TunnelUrl,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
|
||||
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
|
||||
ring::{RingSink, RingStream},
|
||||
};
|
||||
@@ -545,9 +545,9 @@ impl TunnelListener for UdpTunnelListener {
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
}
|
||||
|
||||
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
@@ -838,7 +838,7 @@ impl UdpTunnelConnector {
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
@@ -1040,7 +1040,7 @@ mod tests {
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.unwrap();
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
|
||||
setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||
common::{TunnelWrapper, setup_sokcet2, wait_for_connect_futures},
|
||||
common::{TunnelWrapper, setup_socket2, wait_for_connect_futures},
|
||||
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
|
||||
packet_def::{ZCPacket, ZCPacketType},
|
||||
};
|
||||
@@ -166,7 +166,7 @@ impl TunnelListener for WsTunnelListener {
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
self.addr
|
||||
@@ -291,7 +291,7 @@ impl WsTunnelConnector {
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
|
||||
ZCPacketStream,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
|
||||
generate_digest_from_str,
|
||||
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
|
||||
ring::create_ring_tunnel_pair,
|
||||
@@ -563,9 +563,9 @@ impl TunnelListener for WgTunnelListener {
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
}
|
||||
|
||||
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
@@ -700,7 +700,7 @@ impl WgTunnelConnector {
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None)?;
|
||||
setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
|
||||
}
|
||||
@@ -728,7 +728,7 @@ impl super::TunnelConnector for WgTunnelConnector {
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
if let Err(e) = setup_sokcet2(&socket2_socket, &bind_addr) {
|
||||
if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user