tunnel(bind): gather all bind logic to a single function (#2070)

* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Luna Yao
2026-04-12 16:16:58 +02:00
committed by GitHub
parent 869e1b89f5
commit 6f3e708679
10 changed files with 370 additions and 5846 deletions
Generated
+82 -5565
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -242,6 +242,7 @@ hickory-server = { version = "0.25.2", features = [
"resolver", "resolver",
], optional = true } ], optional = true }
bon = "3.9.1"
derive_builder = "0.20.2" derive_builder = "0.20.2"
humantime-serde = "1.1.1" humantime-serde = "1.1.1"
multimap = "0.10.1" multimap = "0.10.1"
+20 -47
View File
@@ -18,7 +18,7 @@ use crate::gateway::kcp_proxy::NatDstKcpConnector;
use crate::{ use crate::{
common::{ common::{
config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background, config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background,
netns::NetNS, scoped_task::ScopedTask, scoped_task::ScopedTask,
}, },
gateway::{ gateway::{
fast_socks5::{ fast_socks5::{
@@ -30,10 +30,7 @@ use crate::{
ip_reassembler::IpReassembler, ip_reassembler::IpReassembler,
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device}, tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
}, },
tunnel::{ tunnel::packet_def::{PacketType, ZCPacket},
common::setup_socket2,
packet_def::{PacketType, ZCPacket},
},
}; };
use anyhow::Context; use anyhow::Context;
use dashmap::DashMap; use dashmap::DashMap;
@@ -42,21 +39,21 @@ use pnet::packet::{
}; };
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite}, io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpSocket, UdpSocket}, net::{TcpListener, UdpSocket},
select, select,
sync::{Mutex, Notify, mpsc}, sync::{Mutex, Notify, mpsc},
task::JoinSet, task::JoinSet,
time::timeout, time::timeout,
}; };
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
use crate::tunnel::common::bind;
use crate::{ use crate::{
common::{error::Error, global_ctx::GlobalCtx}, common::{error::Error, global_ctx::GlobalCtx},
peers::{PeerPacketFilter, peer_manager::PeerManager}, peers::{PeerPacketFilter, peer_manager::PeerManager},
}; };
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
enum SocksUdpSocket { enum SocksUdpSocket {
UdpSocket(Arc<tokio::net::UdpSocket>), UdpSocket(Arc<tokio::net::UdpSocket>),
SmolUdpSocket(super::tokio_smoltcp::UdpSocket), SmolUdpSocket(super::tokio_smoltcp::UdpSocket),
@@ -328,38 +325,6 @@ impl AsyncTcpConnector for Socks5AutoConnector {
} }
} }
fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
Ok(socket.listen(1024)?)
}
fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
Ok(UdpSocket::from_std(socket2_socket.into())?)
}
struct Socks5ServerNet { struct Socks5ServerNet {
ipv4_addr: cidr::Ipv4Inet, ipv4_addr: cidr::Ipv4Inet,
auth: Option<SimpleUserPassword>, auth: Option<SimpleUserPassword>,
@@ -702,10 +667,10 @@ impl Socks5Server {
proxy_url.port().unwrap() proxy_url.port().unwrap()
); );
let listener = bind_tcp_socket( let listener = bind::<TcpListener>()
bind_addr.parse::<SocketAddr>().unwrap(), .addr(bind_addr.parse::<SocketAddr>().unwrap())
self.global_ctx.net_ns.clone(), .net_ns(self.global_ctx.net_ns.clone())
)?; .call()?;
let entries = self.entries.clone(); let entries = self.entries.clone();
let entry_count = self.entry_count.clone(); let entry_count = self.entry_count.clone();
@@ -838,7 +803,10 @@ impl Socks5Server {
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> { pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr); let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?; let listener = bind::<TcpListener>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?;
let net = self.net.clone(); let net = self.net.clone();
let entries = self.entries.clone(); let entries = self.entries.clone();
@@ -906,7 +874,12 @@ impl Socks5Server {
#[tracing::instrument(name = "add_udp_port_forward", skip(self))] #[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> { pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr); let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?); let socket = Arc::new(
bind::<UdpSocket>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?,
);
let entries = self.entries.clone(); let entries = self.entries.clone();
let entry_count = self.entry_count.clone(); let entry_count = self.entry_count.clone();
+10 -20
View File
@@ -24,18 +24,18 @@ use tokio::{
use tracing::Level; use tracing::Level;
use super::{CidrSet, ip_reassembler::IpReassembler};
use crate::tunnel::common::bind;
use crate::{ use crate::{
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask}, common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet}, gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{PeerPacketFilter, peer_manager::PeerManager}, peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{ tunnel::{
common::{reserve_buf, setup_socket2}, common::reserve_buf,
packet_def::{PacketType, ZCPacket}, packet_def::{PacketType, ZCPacket},
}, },
}; };
use super::{CidrSet, ip_reassembler::IpReassembler};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey { struct UdpNatKey {
src_socket: SocketAddr, src_socket: SocketAddr,
@@ -63,18 +63,9 @@ impl UdpNatEntry {
denied: bool, denied: bool,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// TODO: try use src port, so we will be ip restricted nat type // TODO: try use src port, so we will be ip restricted nat type
let socket = if denied { let socket = (!denied)
None .then(|| bind().addr("0.0.0.0:0".parse().unwrap()).call())
} else { .transpose()?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
Some(UdpSocket::from_std(socket2_socket.into())?)
};
Ok(Self { Ok(Self {
src_peer_id, src_peer_id,
@@ -403,11 +394,10 @@ impl UdpProxy {
#[async_trait::async_trait] #[async_trait::async_trait]
impl PeerPacketFilter for UdpProxy { impl PeerPacketFilter for UdpProxy {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
if self.try_handle_packet(&packet).await.is_some() { self.try_handle_packet(&packet)
return None; .await
} else { .is_none()
return Some(packet); .then_some(packet)
}
} }
} }
+138 -43
View File
@@ -1,3 +1,7 @@
use bon::builder;
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use std::{ use std::{
any::Any, any::Any,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
@@ -5,26 +9,21 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
task::{Poll, ready}, task::{Poll, ready},
}; };
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
use super::TunnelInfo; use super::TunnelInfo;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use super::{ use super::{
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
buf::BufList, buf::BufList,
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType}, packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
}; };
use crate::common::netns::NetNS;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
pub struct TunnelWrapper<R, W> { pub struct TunnelWrapper<R, W> {
reader: Arc<Mutex<Option<R>>>, reader: Arc<Mutex<Option<R>>>,
@@ -344,7 +343,70 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None None
} }
pub(crate) fn setup_socket2_ext( pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
}
// region bind
pub trait Bindable: Sized {
const TYPE: socket2::Type;
const PROTOCOL: Option<socket2::Protocol>;
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError>;
}
impl Bindable for TcpSocket {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
let socket = TcpSocket::from_std_stream(socket.into());
if let Err(error) = socket.set_nodelay(true) {
tracing::warn!(?error, "set_nodelay failed for tcp socket");
}
Ok(socket)
}
}
impl Bindable for TcpListener {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(TcpSocket::finalize(socket)?.listen(1024)?)
}
}
impl Bindable for UdpSocket {
const TYPE: socket2::Type = socket2::Type::DGRAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::UDP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(UdpSocket::from_std(socket.into())?)
}
}
fn setup_socket2_ext(
socket2_socket: &socket2::Socket, socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr, bind_addr: &SocketAddr,
#[allow(unused_variables)] bind_dev: Option<String>, #[allow(unused_variables)] bind_dev: Option<String>,
@@ -408,38 +470,69 @@ pub(crate) fn setup_socket2_ext(
Ok(()) Ok(())
} }
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>( #[derive(Debug, Default, Clone)]
mut futures: FuturesUnordered<Fut>, pub enum BindDev {
) -> Result<Ret, TunnelError> #[default]
where Auto,
Fut: Future<Output = Result<Ret, E>> + Send, Disabled,
E: std::error::Error + Into<TunnelError> + Send + 'static, Custom(String),
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
} }
pub(crate) fn setup_socket2( impl From<String> for BindDev {
socket2_socket: &socket2::Socket, fn from(value: String) -> Self {
bind_addr: &SocketAddr, if value.is_empty() {
only_v6: bool, Self::Disabled
) -> Result<(), TunnelError> { } else {
setup_socket2_ext( Self::Custom(value)
socket2_socket, }
bind_addr, }
super::common::get_interface_name_by_ip(&bind_addr.ip()), }
only_v6,
) impl From<&str> for BindDev {
fn from(value: &str) -> Self {
value.to_string().into()
}
}
/// Binds a socket to a specific address and optionally a network interface.
///
/// This function creates a new socket, applies specific configurations (such as
/// binding to a device or setting IPv6-only flags), and finalizes it into the
/// requested [`Bindable`] type.
///
/// # Arguments
///
/// * `addr` - The `SocketAddr` to bind the socket to.
/// * `dev` - The name of the network interface to bind to:
/// * **(default) `BindDev::Auto`**: Enables **auto-discovery**. The function will attempt to automatically
/// resolve the interface name associated with the provided `addr.ip()`.
/// * **empty string or `BindDev::Disabled`**: **Disables** auto-discovery and
/// explicitly chooses **not** to bind to any specific device. The routing will be
/// left entirely to the OS.
/// * **non-empty string or `BindDev::Custom(..)`**: Skips auto-discovery and explicitly binds to
/// the specified interface.
/// * `net_ns` - An optional network namespace to switch into before creating the socket.
/// * `only_v6` - If `true`, sets the `IPV6_V6ONLY` flag on the socket.
///
/// # Errors
///
/// Returns a [`TunnelError`] if socket creation, configuration, or finalization fails.
#[builder]
pub fn bind<B: Bindable>(
addr: SocketAddr,
#[builder(default, into)] dev: BindDev,
net_ns: Option<NetNS>,
#[builder(default)] only_v6: bool,
) -> Result<B, TunnelError> {
let _g = net_ns.map(|n| n.guard());
let dev = match dev {
BindDev::Auto => get_interface_name_by_ip(&addr.ip()),
BindDev::Disabled => None,
BindDev::Custom(s) => Some(s),
};
let socket = socket2::Socket::new(socket2::Domain::for_address(addr), B::TYPE, B::PROTOCOL)?;
setup_socket2_ext(&socket, &addr, dev, only_v6)?;
B::finalize(socket)
} }
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) { pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
@@ -448,6 +541,8 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
} }
} }
// endregion
pub mod tests { pub mod tests {
use atomic_shim::AtomicU64; use atomic_shim::AtomicU64;
use std::{sync::Arc, time::Instant}; use std::{sync::Arc, time::Instant};
+27 -23
View File
@@ -4,9 +4,10 @@
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener}; use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use crate::common::global_ctx::ArcGlobalCtx; use crate::common::global_ctx::ArcGlobalCtx;
use crate::tunnel::common::bind;
use crate::tunnel::{ use crate::tunnel::{
TunnelInfo, TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2}, common::{FramedReader, FramedWriter, TunnelWrapper},
}; };
use anyhow::Context; use anyhow::Context;
use derivative::Derivative; use derivative::Derivative;
@@ -19,6 +20,7 @@ use quinn::{
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock; use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration}; use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::net::UdpSocket;
// region config // region config
pub fn transport_config() -> Arc<TransportConfig> { pub fn transport_config() -> Arc<TransportConfig> {
@@ -179,27 +181,28 @@ pub struct QuicEndpointManager {
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new(); static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager { impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> { fn try_create(addr: SocketAddr, dual_stack: bool) -> Result<Endpoint, TunnelError> {
let socket = socket2::Socket::new( let socket = bind::<UdpSocket>()
socket2::Domain::for_address(addr), .addr(addr)
socket2::Type::DGRAM, .only_v6(addr.is_ipv6() && !dual_stack)
Some(socket2::Protocol::UDP), .call()?;
)?; let runtime = default_runtime().ok_or(TunnelError::InternalError(
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack) "no async runtime found".to_owned(),
.map_err(std::io::Error::other)?; ))?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
let mut endpoint = Endpoint::new_with_abstract_socket( let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(), endpoint_config(),
None, None,
runtime.wrap_udp_socket(socket)?, runtime.wrap_udp_socket(socket.into_std()?)?,
runtime, runtime,
)?; )?;
endpoint.set_default_client_config(client_config()); endpoint.set_default_client_config(client_config());
Ok(endpoint) Ok(endpoint)
} }
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)> fn create<F>(
&self,
mut selector: F,
) -> Result<(&RwPool<Endpoint>, Option<Endpoint>), TunnelError>
where where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>), F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{ {
@@ -210,10 +213,10 @@ impl QuicEndpointManager {
}; };
let endpoint = Self::try_create(addr, dual_stack); let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref() if let Err(error) = endpoint.as_ref()
&& dual_stack && dual_stack
{ {
tracing::warn!("create dual stack quic endpoint failed: {:?}", e); tracing::warn!(?error, "create dual stack quic endpoint failed");
self.both.disable(); self.both.disable();
self.ipv4.enable(); self.ipv4.enable();
self.ipv6.enable(); self.ipv6.enable();
@@ -263,7 +266,7 @@ impl QuicEndpointManager {
/// ///
/// # Arguments /// # Arguments
/// * `addr`: listen address /// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> { fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx); let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| { let (pool, endpoint) = mgr.create(|mgr| {
@@ -289,7 +292,7 @@ impl QuicEndpointManager {
/// ///
/// # Arguments /// # Arguments
/// * `ip_version`: the IP version of the remote address /// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> { fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx); let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| { let (pool, endpoint) = mgr.create(|mgr| {
@@ -318,7 +321,7 @@ impl QuicEndpointManager {
async fn connect( async fn connect(
global_ctx: &ArcGlobalCtx, global_ctx: &ArcGlobalCtx,
addr: SocketAddr, addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> { ) -> Result<(Endpoint, Connection), TunnelError> {
let ip_version = if addr.ip().is_ipv4() { let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4 IpVersion::V4
} else { } else {
@@ -327,8 +330,9 @@ impl QuicEndpointManager {
let endpoint = Self::client(global_ctx, ip_version)?; let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint let connection = endpoint
.connect(addr, "localhost") .connect(addr, "localhost")
.map_err(std::io::Error::other)? .with_context(|| format!("failed to create connection to {}", addr))?
.await?; .await
.with_context(|| format!("failed to connect to {}", addr))?;
Ok((endpoint, connection)) Ok((endpoint, connection))
} }
@@ -585,10 +589,10 @@ mod tests {
async fn invalid_peer_addr_impl() { async fn invalid_peer_addr_impl() {
let mut connector = let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx()); QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err(); let err = format!("{:?}", connector.connect().await.unwrap_err());
assert!( assert!(
err.to_string().contains("invalid remote address"), err.contains("invalid remote address"),
"unexpected error: {:?}", "unexpected error: {}",
err err
); );
} }
+12 -29
View File
@@ -1,7 +1,7 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo}; use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_socket2; use crate::tunnel::common::bind;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::FuturesUnordered; use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
@@ -59,25 +59,15 @@ impl TcpTunnelListener {
impl TunnelListener for TcpTunnelListener { impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None; self.listener = None;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
self.addr self.addr
.set_port(Some(socket.local_addr()?.port())) .set_port(Some(listener.local_addr()?.port()))
.unwrap(); .unwrap();
self.listener = Some(listener);
self.listener = Some(socket.listen(1024)?);
Ok(()) Ok(())
} }
@@ -167,21 +157,14 @@ impl TcpTunnelConnector {
let futures = FuturesUnordered::new(); let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() { for bind_addr in self.bind_addrs.iter() {
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr"); tracing::info!(?bind_addr, ?addr, "bind addr");
match bind::<TcpSocket>().addr(*bind_addr).only_v6(true).call() {
let socket2_socket = socket2::Socket::new( Ok(socket) => futures.push(socket.connect(addr)),
socket2::Domain::for_address(addr), Err(error) => {
socket2::Type::STREAM, tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
Some(socket2::Protocol::TCP), continue;
)?; }
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
} }
let socket = TcpSocket::from_std_stream(socket2_socket.into());
futures.push(socket.connect(addr));
} }
let ret = wait_for_connect_futures(futures).await; let ret = wait_for_connect_futures(futures).await;
+22 -33
View File
@@ -18,16 +18,16 @@ use tokio::{
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}, sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
task::JoinSet, task::JoinSet,
}; };
use tracing::{Instrument, instrument}; use tracing::{Instrument, instrument};
use super::{ use super::{
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener, FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl, TunnelUrl,
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures}, common::wait_for_connect_futures,
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket}, packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
ring::{RingSink, RingStream}, ring::{RingSink, RingStream},
}; };
use crate::tunnel::common::bind;
use crate::{ use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap}, common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
tunnel::{ tunnel::{
@@ -536,21 +536,14 @@ impl UdpTunnelListener {
impl TunnelListener for UdpTunnelListener { impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into(); let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() { self.socket = Some(Arc::new(
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?; bind()
} else { .addr(addr)
setup_socket2(&socket2_socket, &addr, true)?; .only_v6(true)
} .maybe_dev(tunnel_url.bind_dev())
.call()?,
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); ));
self.data.socket = self.socket.clone(); self.data.socket = self.socket.clone();
self.addr self.addr
@@ -833,17 +826,14 @@ impl UdpTunnelConnector {
let futures = FuturesUnordered::new(); let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() { for bind_addr in self.bind_addrs.iter() {
let socket2_socket = socket2::Socket::new( tracing::info!(?bind_addr, ?addr, "bind addr");
socket2::Domain::for_address(*bind_addr), match bind().addr(*bind_addr).only_v6(true).call() {
socket2::Type::DGRAM, Ok(socket) => futures.push(self.try_connect_with_socket(Arc::new(socket), addr)),
Some(socket2::Protocol::UDP), Err(error) => {
)?; tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) { continue;
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); }
continue;
} }
let socket = UdpSocket::from_std(socket2_socket.into())?;
futures.push(self.try_connect_with_socket(Arc::new(socket), addr));
} }
wait_for_connect_futures(futures).await wait_for_connect_futures(futures).await
} }
@@ -1034,13 +1024,12 @@ mod tests {
) )
.await .await
.unwrap(); .unwrap();
let socket2_socket = socket2::Socket::new( let _ = bind::<UdpSocket>()
socket2::Domain::for_address(addr), .addr(addr)
socket2::Type::DGRAM, .maybe_dev(bind_dev.clone())
Some(socket2::Protocol::UDP), .only_v6(true)
) .call()
.unwrap(); .unwrap();
setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
} }
} }
+18 -28
View File
@@ -1,9 +1,10 @@
use super::{ use super::{
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
common::{TunnelWrapper, setup_socket2, wait_for_connect_futures}, common::{TunnelWrapper, wait_for_connect_futures},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType}, packet_def::{ZCPacket, ZCPacketType},
}; };
use crate::tunnel::common::bind;
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config}; use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use anyhow::Context; use anyhow::Context;
use bytes::BytesMut; use bytes::BytesMut;
@@ -160,20 +161,16 @@ impl WsTunnelListener {
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelListener for WsTunnelListener { impl TunnelListener for WsTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new( let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
self.addr self.addr
.set_port(Some(socket.local_addr()?.port())) .set_port(Some(listener.local_addr()?.port()))
.unwrap(); .unwrap();
self.listener = Some(listener);
self.listener = Some(socket.listen(1024)?);
Ok(()) Ok(())
} }
@@ -283,25 +280,18 @@ impl WsTunnelConnector {
let futures = FuturesUnordered::new(); let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() { for bind_addr in self.bind_addrs.iter() {
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr"); tracing::info!(?bind_addr, ?addr, "bind addr");
match bind().addr(*bind_addr).only_v6(true).call() {
let socket2_socket = socket2::Socket::new( Ok(socket) => futures.push(Self::connect_with(
socket2::Domain::for_address(addr), self.addr.clone(),
socket2::Type::STREAM, self.ip_version,
Some(socket2::Protocol::TCP), socket,
)?; )),
Err(error) => {
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) { tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); continue;
continue; }
} }
let socket = TcpSocket::from_std_stream(socket2_socket.into());
futures.push(Self::connect_with(
self.addr.clone(),
self.ip_version,
socket,
))
} }
wait_for_connect_futures(futures).await wait_for_connect_futures(futures).await
+40 -58
View File
@@ -6,6 +6,23 @@ use std::{
time::Duration, time::Duration,
}; };
use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
common::wait_for_connect_futures,
generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair,
};
use crate::tunnel::common::{BindDev, bind};
use crate::{
common::shrink_dashmap,
tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
},
};
use anyhow::Context; use anyhow::Context;
use async_recursion::async_recursion; use async_recursion::async_recursion;
use async_trait::async_trait; use async_trait::async_trait;
@@ -20,23 +37,6 @@ use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
use rand::RngCore; use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair,
};
use crate::{
common::shrink_dashmap,
tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
},
};
const MAX_PACKET: usize = 2048; const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -555,20 +555,14 @@ impl WgTunnelListener {
impl TunnelListener for WgTunnelListener { impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into(); let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() { self.udp = Some(Arc::new(
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?; bind()
} else { .addr(addr)
setup_socket2(&socket2_socket, &addr, true)?; .only_v6(true)
} .maybe_dev(tunnel_url.bind_dev())
.call()?,
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); ));
self.addr self.addr
.set_port(Some(self.udp.as_ref().unwrap().local_addr()?.port())) .set_port(Some(self.udp.as_ref().unwrap().local_addr()?.port()))
.unwrap(); .unwrap();
@@ -695,13 +689,11 @@ impl WgTunnelConnector {
} }
async fn connect_with_ipv6(&self, addr: SocketAddr) -> Result<Box<dyn Tunnel>, TunnelError> { async fn connect_with_ipv6(&self, addr: SocketAddr) -> Result<Box<dyn Tunnel>, TunnelError> {
let socket2_socket = socket2::Socket::new( let socket = bind()
socket2::Domain::for_address(addr), .addr("[::]:0".parse().unwrap())
socket2::Type::DGRAM, .dev(BindDev::Disabled)
Some(socket2::Protocol::UDP), .only_v6(true)
)?; .call()?;
setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
} }
} }
@@ -723,29 +715,19 @@ impl super::TunnelConnector for WgTunnelConnector {
}; };
let futures = FuturesUnordered::new(); let futures = FuturesUnordered::new();
for bind_addr in bind_addrs.into_iter() { for bind_addr in bind_addrs.into_iter() {
let socket2_socket = socket2::Socket::new( tracing::info!(?bind_addr, ?addr, "bind addr");
socket2::Domain::for_address(bind_addr), match bind().addr(bind_addr).only_v6(true).call() {
socket2::Type::DGRAM, Ok(socket) => futures.push(Self::connect_with_socket(
Some(socket2::Protocol::UDP), self.addr.clone(),
)?; self.config.clone(),
if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) { socket,
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); addr,
continue; )),
} Err(error) => {
let socket = match UdpSocket::from_std(socket2_socket.into()) { tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
Ok(s) => s,
Err(e) => {
tracing::error!(bind_addr = ?bind_addr, ?addr, "create udp socket fail: {:?}", e);
continue; continue;
} }
}; }
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
futures.push(Self::connect_with_socket(
self.addr.clone(),
self.config.clone(),
socket,
addr,
));
} }
wait_for_connect_futures(futures).await wait_for_connect_futures(futures).await