refactor: remove NoGroAsyncUdpSocket (#1867)

This commit is contained in:
Luna Yao
2026-04-10 17:22:08 +02:00
committed by GitHub
parent 19c80c7b9c
commit 8311b11713
12 changed files with 401 additions and 172 deletions
+3 -1
View File
@@ -105,7 +105,9 @@ pub async fn create_connector_by_url(
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(),
IpScheme::Quic => {
tunnel::quic::QuicTunnelConnector::new(url, global_ctx.clone()).boxed()
}
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
+6 -4
View File
@@ -26,7 +26,9 @@ use derivative::Derivative;
use derive_more::{Constructor, Deref, DerefMut, From, Into};
use prost::Message;
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
use quinn::{AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, TokioRuntime, UdpPoller};
use quinn::{
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
};
use std::cmp::min;
use std::future::Future;
use std::io::IoSliceMut;
@@ -806,7 +808,7 @@ impl QuicProxy {
endpoint_config(),
Some(server_config()),
Arc::new(socket),
Arc::new(TokioRuntime),
default_runtime().unwrap(),
)
.unwrap();
endpoint.set_default_client_config(client_config());
@@ -1020,7 +1022,7 @@ mod tests {
endpoint_config.clone(),
Some(server_config.clone()),
socket_client.clone(),
Arc::new(TokioRuntime),
default_runtime().unwrap(),
)
.unwrap();
client_endpoint.set_default_client_config(client_config.clone());
@@ -1030,7 +1032,7 @@ mod tests {
endpoint_config.clone(),
Some(server_config.clone()),
socket_server.clone(),
Arc::new(TokioRuntime),
default_runtime().unwrap(),
)
.unwrap();
server_endpoint.set_default_client_config(client_config.clone());
+3 -3
View File
@@ -31,7 +31,7 @@ use crate::{
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
},
tunnel::{
common::setup_sokcet2,
common::setup_socket2,
packet_def::{PacketType, ZCPacket},
},
};
@@ -336,7 +336,7 @@ fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error
Some(socket2::Protocol::TCP),
)?;
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
@@ -355,7 +355,7 @@ fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error>
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
Ok(UdpSocket::from_std(socket2_socket.into())?)
}
+2 -2
View File
@@ -29,7 +29,7 @@ use crate::{
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{
common::{reserve_buf, setup_sokcet2},
common::{reserve_buf, setup_socket2},
packet_def::{PacketType, ZCPacket},
},
};
@@ -72,7 +72,7 @@ impl UdpNatEntry {
Some(socket2::Protocol::UDP),
)?;
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
setup_sokcet2(&socket2_socket, &dst_socket_addr)?;
setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
Some(UdpSocket::from_std(socket2_socket.into())?)
};
+5 -3
View File
@@ -25,7 +25,7 @@ use crate::{
pub fn create_listener_by_url(
l: &url::Url,
#[allow(unused_variables)] ctx: ArcGlobalCtx,
global_ctx: ArcGlobalCtx,
) -> Result<Box<dyn TunnelListener>, Error> {
Ok(match l.try_into()? {
TunnelScheme::Ip(scheme) => match scheme {
@@ -34,7 +34,7 @@ pub fn create_listener_by_url(
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
let nid = ctx.get_network_identity();
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
@@ -42,7 +42,9 @@ pub fn create_listener_by_url(
WgTunnelListener::new(l.clone(), wg_config).boxed()
}
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(),
IpScheme::Quic => {
tunnel::quic::QuicTunnelListener::new(l.clone(), global_ctx.clone()).boxed()
}
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelListener::new(l.clone()).boxed()
+7 -4
View File
@@ -344,10 +344,11 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None
}
pub(crate) fn setup_sokcet2_ext(
pub(crate) fn setup_socket2_ext(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
#[allow(unused_variables)] bind_dev: Option<String>,
only_v6: bool,
) -> Result<(), TunnelError> {
#[cfg(target_os = "windows")]
{
@@ -356,7 +357,7 @@ pub(crate) fn setup_sokcet2_ext(
}
if bind_addr.is_ipv6() {
socket2_socket.set_only_v6(true)?;
socket2_socket.set_only_v6(only_v6)?;
}
socket2_socket.set_nonblocking(true)?;
@@ -428,14 +429,16 @@ where
Err(last_err.unwrap_or(TunnelError::Shutdown))
}
pub(crate) fn setup_sokcet2(
pub(crate) fn setup_socket2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
only_v6: bool,
) -> Result<(), TunnelError> {
setup_sokcet2_ext(
setup_socket2_ext(
socket2_socket,
bind_addr,
super::common::get_interface_name_by_ip(&bind_addr.ip()),
only_v6,
)
}
+1 -1
View File
@@ -46,7 +46,7 @@ pub mod unix;
#[derive(thiserror::Error, Debug)]
pub enum TunnelError {
#[error("io error")]
#[error("io error: {0}")]
IOError(#[from] std::io::Error),
#[error("invalid packet. msg: {0}")]
InvalidPacket(String),
+358 -138
View File
@@ -2,22 +2,25 @@
//!
//! Checkout the `README.md` for guidance.
use std::{
error::Error, io::IoSliceMut, net::SocketAddr, pin::Pin, sync::Arc, task::Poll, time::Duration,
};
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use crate::common::global_ctx::ArcGlobalCtx;
use crate::tunnel::{
FromUrl, TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_sokcet2},
TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
};
use anyhow::Context;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use derivative::Derivative;
use derive_more::{Deref, DerefMut};
use parking_lot::RwLock;
use quinn::{
AsyncUdpSocket, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig,
TransportConfig, UdpPoller, congestion::BbrConfig, udp::RecvMeta,
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
congestion::BbrConfig, default_runtime,
};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
// region config
pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default();
@@ -50,86 +53,287 @@ pub fn endpoint_config() -> EndpointConfig {
config.max_udp_payload_size(65527).unwrap();
config
}
//endregion
#[derive(Clone, Debug)]
struct NoGroAsyncUdpSocket {
inner: Arc<dyn AsyncUdpSocket>,
//region rw pool
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
#[derive(Debug, Deref, DerefMut)]
struct RwPoolInner<Item> {
#[deref]
#[deref_mut]
pool: Vec<Item>,
enabled: bool,
}
impl AsyncUdpSocket for NoGroAsyncUdpSocket {
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
self.inner.clone().create_io_poller()
#[derive(Debug)]
struct RwPool<Item> {
ephemeral: RwLock<RwPoolInner<Item>>,
persistent: RwLock<RwPoolInner<Item>>,
capacity: usize,
}
impl<Item> RwPool<Item> {
fn new(capacity: usize) -> Self {
Self {
ephemeral: RwLock::new(RwPoolInner::default()),
persistent: RwLock::new(RwPoolInner::default()),
capacity,
}
}
fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> {
self.inner.try_send(transmit)
}
/// Receive UDP datagrams, or register to be woken if receiving may succeed in the future
fn poll_recv(
/// return the capacity of the ephemeral pool;
/// if `ephemeral` or `persistent` is None, read lock `self`'s pool
fn capacity(
&self,
cx: &mut std::task::Context,
bufs: &mut [IoSliceMut<'_>],
meta: &mut [RecvMeta],
) -> Poll<std::io::Result<usize>> {
self.inner.poll_recv(cx, bufs, meta)
ephemeral: Option<&RwPoolInner<Item>>,
persistent: Option<&RwPoolInner<Item>>,
) -> usize {
let guard;
let ephemeral = if let Some(ephemeral) = ephemeral {
ephemeral
} else {
guard = self.ephemeral.read();
&guard
};
let guard;
let persistent = if let Some(persistent) = persistent {
persistent
} else {
guard = self.persistent.read();
&guard
};
(self.capacity * ephemeral.enabled as usize).saturating_sub(persistent.len())
}
/// Look up the local IP address and port used by this socket
fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.inner.local_addr()
fn is_full(&self) -> bool {
let pool = self.ephemeral.read();
pool.len() >= self.capacity(Some(&pool), None)
}
fn may_fragment(&self) -> bool {
self.inner.may_fragment()
fn is_enabled(&self) -> bool {
self.ephemeral.read().enabled
}
fn max_transmit_segments(&self) -> usize {
self.inner.max_transmit_segments()
fn enable(&self) {
self.ephemeral.write().enabled = true;
self.resize();
}
fn max_receive_segments(&self) -> usize {
1
fn disable(&self) {
self.ephemeral.write().enabled = false;
self.resize();
}
/// push an item to the persistent pool
fn push(&self, item: Item) {
self.persistent.write().push(item);
self.resize();
}
/// try to push an item to the ephemeral pool, return the item if full
fn try_push(&self, item: Item) -> Option<Item> {
let mut pool = self.ephemeral.write();
if pool.len() < self.capacity(Some(&pool), None) {
pool.push(item);
return None;
}
Some(item)
}
fn resize(&self) {
let resize = {
let pool = self.ephemeral.read();
pool.capacity() != self.capacity(Some(&pool), None)
};
if resize {
let mut pool = self.ephemeral.write();
let capacity = self.capacity(Some(&pool), None);
pool.reserve_exact(capacity);
pool.truncate(capacity);
pool.shrink_to(capacity);
}
}
fn with_iter<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut dyn Iterator<Item = &Item>) -> R,
{
let ephemeral = self.ephemeral.read();
let persistent = self.persistent.read();
f(&mut persistent.iter().chain(ephemeral.iter()))
}
}
//endregion
//region endpoint manager
#[derive(Debug)]
pub struct QuicEndpointManager {
ipv4: RwPool<Endpoint>,
ipv6: RwPool<Endpoint>,
both: RwPool<Endpoint>,
}
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
let socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
.map_err(std::io::Error::other)?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(),
None,
runtime.wrap_udp_socket(socket)?,
runtime,
)?;
endpoint.set_default_client_config(client_config());
Ok(endpoint)
}
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{
loop {
let (pool, r) = selector(self);
let Some((addr, dual_stack)) = r else {
return Ok((pool, None));
};
let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref()
&& dual_stack
{
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
self.both.disable();
self.ipv4.enable();
self.ipv6.enable();
continue;
}
return Ok((pool, Some(endpoint?)));
}
}
}
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
/// and port.
///
/// ## Returns
///
/// - an [`Endpoint`] configured to accept incoming QUIC connections
#[allow(unused)]
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<Endpoint, Box<dyn Error>> {
let server_config = server_config();
let client_config = client_config();
let endpoint_config = endpoint_config();
impl QuicEndpointManager {
fn new(capacity: usize) -> Self {
let ipv4 = RwPool::new(capacity.div_ceil(2));
let ipv6 = RwPool::new(capacity.div_ceil(2));
let both = RwPool::new(capacity);
both.enable();
Self { ipv4, ipv6, both }
}
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &bind_addr)?;
let socket = std::net::UdpSocket::from(socket2_socket);
fn load(global_ctx: &ArcGlobalCtx) -> &Self {
let capacity = global_ctx
.config
.get_flags()
.multi_thread
.then(std::thread::available_parallelism)
.and_then(|r| r.ok())
.map(|n| n.get())
.unwrap_or(1);
let runtime =
quinn::default_runtime().ok_or_else(|| std::io::Error::other("no async runtime found"))?;
let socket: NoGroAsyncUdpSocket = NoGroAsyncUdpSocket {
inner: runtime.wrap_udp_socket(socket)?,
};
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config,
Some(server_config),
Arc::new(socket),
runtime,
)?;
endpoint.set_default_client_config(client_config);
Ok(endpoint)
let mgr = QUIC_ENDPOINT_MANAGER.get();
match mgr {
Some(mgr) => {
for pool in [&mgr.ipv4, &mgr.ipv6, &mgr.both] {
pool.resize();
}
}
None => {
let _ = QUIC_ENDPOINT_MANAGER.set(Self::new(capacity));
}
}
QUIC_ENDPOINT_MANAGER.get().unwrap()
}
/// Get a QUIC endpoint to be used as a server
///
/// # Arguments
/// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = addr.ip() == Ipv6Addr::UNSPECIFIED && mgr.both.is_enabled();
let pool = if addr.is_ipv4() {
&mgr.ipv4
} else if dual_stack {
&mgr.both
} else {
&mgr.ipv6
};
(pool, Some((addr, dual_stack)))
})?;
let endpoint = endpoint.expect("server endpoint creation should not return None");
endpoint.set_server_config(Some(server_config()));
pool.push(endpoint.clone());
Ok(endpoint)
}
/// Get a quic endpoint to be used as a client
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = mgr.both.is_enabled();
let (pool, addr) = match ip_version {
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
_ => {
let pool = if dual_stack { &mgr.both } else { &mgr.ipv6 };
(pool, (Ipv6Addr::UNSPECIFIED, 0).into())
}
};
if pool.is_full() {
(pool, None)
} else {
(pool, Some((addr, dual_stack)))
}
})?;
if let Some(endpoint) = endpoint {
pool.try_push(endpoint);
}
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
}
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> {
let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4
} else {
IpVersion::V6
};
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.map_err(std::io::Error::other)?
.await?;
Ok((endpoint, connection))
}
}
#[allow(unused)]
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
//endregion
struct ConnWrapper {
conn: Connection,
@@ -143,13 +347,15 @@ impl Drop for ConnWrapper {
pub struct QuicTunnelListener {
addr: url::Url,
global_ctx: ArcGlobalCtx,
endpoint: Option<Endpoint>,
}
impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelListener {
addr,
global_ctx,
endpoint: None,
}
}
@@ -192,13 +398,11 @@ impl QuicTunnelListener {
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let endpoint = make_server_endpoint(addr)
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint);
let endpoint = QuicEndpointManager::server(&self.global_ctx, addr)?;
self.addr
.set_port(Some(self.endpoint.as_ref().unwrap().local_addr()?.port()))
.set_port(Some(endpoint.local_addr()?.port()))
.unwrap();
self.endpoint = Some(endpoint);
Ok(())
}
@@ -222,15 +426,15 @@ impl TunnelListener for QuicTunnelListener {
pub struct QuicTunnelConnector {
addr: url::Url,
endpoint: Option<Endpoint>,
global_ctx: ArcGlobalCtx,
ip_version: IpVersion,
}
impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelConnector {
addr,
endpoint: None,
global_ctx,
ip_version: IpVersion::Both,
}
}
@@ -240,38 +444,10 @@ impl QuicTunnelConnector {
impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.port() == 0 {
return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
self.addr
)));
}
let local_addr = if addr.is_ipv4() {
"0.0.0.0:0"
} else {
"[::]:0"
};
let mut endpoint = Endpoint::client(local_addr.parse().unwrap())?;
endpoint.set_default_client_config(client_config());
// connect to server
let connection = endpoint
.connect(addr, "localhost")
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to create QUIC connection, url: {}, error: {}",
self.addr, e
))
})?
.await
.with_context(|| "connect failed")?;
tracing::info!("[client] connected: addr={}", connection.remote_address());
let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?;
let local_addr = endpoint.local_addr()?;
self.endpoint = Some(endpoint);
let (w, r) = connection
.open_bi()
.await
@@ -308,68 +484,112 @@ impl TunnelConnector for QuicTunnelConnector {
#[cfg(test)]
mod tests {
use crate::common::global_ctx::tests::get_mock_global_ctx_with_network;
use crate::tunnel::{
IpVersion, TunnelConnector,
TunnelConnector,
common::tests::{_tunnel_bench, _tunnel_pingpong},
};
use std::sync::LazyLock;
use tokio::runtime::{Builder, Runtime};
use super::*;
#[tokio::test]
async fn quic_pingpong() {
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
// Shared runtime for all tests to avoid endpoint invalidation across runtimes
static RUNTIME: LazyLock<Runtime> =
LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap());
fn global_ctx() -> ArcGlobalCtx {
let identity = crate::common::config::NetworkIdentity::default();
get_mock_global_ctx_with_network(Some(identity))
}
#[test]
fn quic_pingpong() {
RUNTIME.block_on(quic_pingpong_impl())
}
async fn quic_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21011".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn quic_bench() {
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
#[test]
fn quic_bench() {
RUNTIME.block_on(quic_bench_impl())
}
async fn quic_bench_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21012".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap(), global_ctx());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn ipv6_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
#[test]
fn ipv6_pingpong() {
RUNTIME.block_on(ipv6_pingpong_impl())
}
async fn ipv6_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ipv6_domain_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let mut connector =
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
#[test]
fn ipv6_domain_pingpong() {
RUNTIME.block_on(ipv6_domain_pingpong_impl())
}
async fn ipv6_domain_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap(), global_ctx());
let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await;
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let mut connector =
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
let listener =
QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap(), global_ctx());
let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await;
}
#[tokio::test]
async fn test_alloc_port() {
#[test]
fn alloc_port() {
RUNTIME.block_on(alloc_port_impl())
}
async fn alloc_port_impl() {
// v4
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
let mut listener =
QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
// v6
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
}
#[tokio::test]
async fn quic_connector_reject_port_zero() {
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let err = connector.connect().await.unwrap_err().to_string();
assert!(err.contains("port 0"), "unexpected error: {}", err);
#[test]
fn invalid_peer_addr() {
RUNTIME.block_on(invalid_peer_addr_impl())
}
async fn invalid_peer_addr_impl() {
let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err();
assert!(
err.to_string().contains("invalid remote address"),
"unexpected error: {:?}",
err
);
}
}
+3 -3
View File
@@ -1,7 +1,7 @@
use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_sokcet2;
use crate::tunnel::common::setup_socket2;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
@@ -66,7 +66,7 @@ impl TunnelListener for TcpTunnelListener {
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
@@ -175,7 +175,7 @@ impl TcpTunnelConnector {
Some(socket2::Protocol::TCP),
)?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
}
+5 -5
View File
@@ -24,7 +24,7 @@ use tracing::{Instrument, instrument};
use super::{
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
ring::{RingSink, RingStream},
};
@@ -545,9 +545,9 @@ impl TunnelListener for UdpTunnelListener {
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else {
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
}
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
@@ -838,7 +838,7 @@ impl UdpTunnelConnector {
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
}
@@ -1040,7 +1040,7 @@ mod tests {
Some(socket2::Protocol::UDP),
)
.unwrap();
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
}
}
+3 -3
View File
@@ -1,6 +1,6 @@
use super::{
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
common::{TunnelWrapper, setup_sokcet2, wait_for_connect_futures},
common::{TunnelWrapper, setup_socket2, wait_for_connect_futures},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
};
@@ -166,7 +166,7 @@ impl TunnelListener for WsTunnelListener {
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
self.addr
@@ -291,7 +291,7 @@ impl WsTunnelConnector {
Some(socket2::Protocol::TCP),
)?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) {
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
}
+5 -5
View File
@@ -23,7 +23,7 @@ use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair,
@@ -563,9 +563,9 @@ impl TunnelListener for WgTunnelListener {
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else {
setup_sokcet2(&socket2_socket, &addr)?;
setup_socket2(&socket2_socket, &addr, true)?;
}
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
@@ -700,7 +700,7 @@ impl WgTunnelConnector {
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None)?;
setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
}
@@ -728,7 +728,7 @@ impl super::TunnelConnector for WgTunnelConnector {
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
if let Err(e) = setup_sokcet2(&socket2_socket, &bind_addr) {
if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
}