tunnel(bind): gather all bind logic to a single function (#2070)

* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Luna Yao
2026-04-12 16:16:58 +02:00
committed by GitHub
parent 869e1b89f5
commit 6f3e708679
10 changed files with 370 additions and 5846 deletions
Generated
+82 -5565
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -242,6 +242,7 @@ hickory-server = { version = "0.25.2", features = [
"resolver",
], optional = true }
bon = "3.9.1"
derive_builder = "0.20.2"
humantime-serde = "1.1.1"
multimap = "0.10.1"
+20 -47
View File
@@ -18,7 +18,7 @@ use crate::gateway::kcp_proxy::NatDstKcpConnector;
use crate::{
common::{
config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background,
netns::NetNS, scoped_task::ScopedTask,
scoped_task::ScopedTask,
},
gateway::{
fast_socks5::{
@@ -30,10 +30,7 @@ use crate::{
ip_reassembler::IpReassembler,
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
},
tunnel::{
common::setup_socket2,
packet_def::{PacketType, ZCPacket},
},
tunnel::packet_def::{PacketType, ZCPacket},
};
use anyhow::Context;
use dashmap::DashMap;
@@ -42,21 +39,21 @@ use pnet::packet::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpSocket, UdpSocket},
net::{TcpListener, UdpSocket},
select,
sync::{Mutex, Notify, mpsc},
task::JoinSet,
time::timeout,
};
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
use crate::tunnel::common::bind;
use crate::{
common::{error::Error, global_ctx::GlobalCtx},
peers::{PeerPacketFilter, peer_manager::PeerManager},
};
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
enum SocksUdpSocket {
UdpSocket(Arc<tokio::net::UdpSocket>),
SmolUdpSocket(super::tokio_smoltcp::UdpSocket),
@@ -328,38 +325,6 @@ impl AsyncTcpConnector for Socks5AutoConnector {
}
}
fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
Ok(socket.listen(1024)?)
}
fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
Ok(UdpSocket::from_std(socket2_socket.into())?)
}
struct Socks5ServerNet {
ipv4_addr: cidr::Ipv4Inet,
auth: Option<SimpleUserPassword>,
@@ -702,10 +667,10 @@ impl Socks5Server {
proxy_url.port().unwrap()
);
let listener = bind_tcp_socket(
bind_addr.parse::<SocketAddr>().unwrap(),
self.global_ctx.net_ns.clone(),
)?;
let listener = bind::<TcpListener>()
.addr(bind_addr.parse::<SocketAddr>().unwrap())
.net_ns(self.global_ctx.net_ns.clone())
.call()?;
let entries = self.entries.clone();
let entry_count = self.entry_count.clone();
@@ -838,7 +803,10 @@ impl Socks5Server {
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
let listener = bind::<TcpListener>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?;
let net = self.net.clone();
let entries = self.entries.clone();
@@ -906,7 +874,12 @@ impl Socks5Server {
#[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
let socket = Arc::new(
bind::<UdpSocket>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?,
);
let entries = self.entries.clone();
let entry_count = self.entry_count.clone();
+10 -20
View File
@@ -24,18 +24,18 @@ use tokio::{
use tracing::Level;
use super::{CidrSet, ip_reassembler::IpReassembler};
use crate::tunnel::common::bind;
use crate::{
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{
common::{reserve_buf, setup_socket2},
common::reserve_buf,
packet_def::{PacketType, ZCPacket},
},
};
use super::{CidrSet, ip_reassembler::IpReassembler};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey {
src_socket: SocketAddr,
@@ -63,18 +63,9 @@ impl UdpNatEntry {
denied: bool,
) -> Result<Self, Error> {
// TODO: try use src port, so we will be ip restricted nat type
let socket = if denied {
None
} else {
let socket2_socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
Some(UdpSocket::from_std(socket2_socket.into())?)
};
let socket = (!denied)
.then(|| bind().addr("0.0.0.0:0".parse().unwrap()).call())
.transpose()?;
Ok(Self {
src_peer_id,
@@ -403,11 +394,10 @@ impl UdpProxy {
#[async_trait::async_trait]
impl PeerPacketFilter for UdpProxy {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
if self.try_handle_packet(&packet).await.is_some() {
return None;
} else {
return Some(packet);
}
self.try_handle_packet(&packet)
.await
.is_none()
.then_some(packet)
}
}
+138 -43
View File
@@ -1,3 +1,7 @@
use bon::builder;
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use std::{
any::Any,
net::{IpAddr, SocketAddr},
@@ -5,26 +9,21 @@ use std::{
sync::{Arc, Mutex},
task::{Poll, ready},
};
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
use super::TunnelInfo;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use super::{
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
buf::BufList,
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
};
use crate::common::netns::NetNS;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
pub struct TunnelWrapper<R, W> {
reader: Arc<Mutex<Option<R>>>,
@@ -344,7 +343,70 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None
}
pub(crate) fn setup_socket2_ext(
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
}
// region bind
pub trait Bindable: Sized {
const TYPE: socket2::Type;
const PROTOCOL: Option<socket2::Protocol>;
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError>;
}
impl Bindable for TcpSocket {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
let socket = TcpSocket::from_std_stream(socket.into());
if let Err(error) = socket.set_nodelay(true) {
tracing::warn!(?error, "set_nodelay failed for tcp socket");
}
Ok(socket)
}
}
impl Bindable for TcpListener {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(TcpSocket::finalize(socket)?.listen(1024)?)
}
}
impl Bindable for UdpSocket {
const TYPE: socket2::Type = socket2::Type::DGRAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::UDP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(UdpSocket::from_std(socket.into())?)
}
}
fn setup_socket2_ext(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
#[allow(unused_variables)] bind_dev: Option<String>,
@@ -408,38 +470,69 @@ pub(crate) fn setup_socket2_ext(
Ok(())
}
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
#[derive(Debug, Default, Clone)]
pub enum BindDev {
#[default]
Auto,
Disabled,
Custom(String),
}
pub(crate) fn setup_socket2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
only_v6: bool,
) -> Result<(), TunnelError> {
setup_socket2_ext(
socket2_socket,
bind_addr,
super::common::get_interface_name_by_ip(&bind_addr.ip()),
only_v6,
)
impl From<String> for BindDev {
fn from(value: String) -> Self {
if value.is_empty() {
Self::Disabled
} else {
Self::Custom(value)
}
}
}
impl From<&str> for BindDev {
fn from(value: &str) -> Self {
value.to_string().into()
}
}
/// Binds a socket to a specific address and optionally a network interface.
///
/// This function creates a new socket, applies specific configurations (such as
/// binding to a device or setting IPv6-only flags), and finalizes it into the
/// requested [`Bindable`] type.
///
/// # Arguments
///
/// * `addr` - The `SocketAddr` to bind the socket to.
/// * `dev` - The name of the network interface to bind to:
/// * **(default) `BindDev::Auto`**: Enables **auto-discovery**. The function will attempt to automatically
/// resolve the interface name associated with the provided `addr.ip()`.
/// * **empty string or `BindDev::Disabled`**: **Disables** auto-discovery and
/// explicitly chooses **not** to bind to any specific device. The routing will be
/// left entirely to the OS.
/// * **non-empty string or `BindDev::Custom(..)`**: Skips auto-discovery and explicitly binds to
/// the specified interface.
/// * `net_ns` - An optional network namespace to switch into before creating the socket.
/// * `only_v6` - If `true`, sets the `IPV6_V6ONLY` flag on the socket.
///
/// # Errors
///
/// Returns a [`TunnelError`] if socket creation, configuration, or finalization fails.
#[builder]
pub fn bind<B: Bindable>(
addr: SocketAddr,
#[builder(default, into)] dev: BindDev,
net_ns: Option<NetNS>,
#[builder(default)] only_v6: bool,
) -> Result<B, TunnelError> {
let _g = net_ns.map(|n| n.guard());
let dev = match dev {
BindDev::Auto => get_interface_name_by_ip(&addr.ip()),
BindDev::Disabled => None,
BindDev::Custom(s) => Some(s),
};
let socket = socket2::Socket::new(socket2::Domain::for_address(addr), B::TYPE, B::PROTOCOL)?;
setup_socket2_ext(&socket, &addr, dev, only_v6)?;
B::finalize(socket)
}
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
@@ -448,6 +541,8 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
}
}
// endregion
pub mod tests {
use atomic_shim::AtomicU64;
use std::{sync::Arc, time::Instant};
+27 -23
View File
@@ -4,9 +4,10 @@
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use crate::common::global_ctx::ArcGlobalCtx;
use crate::tunnel::common::bind;
use crate::tunnel::{
TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
common::{FramedReader, FramedWriter, TunnelWrapper},
};
use anyhow::Context;
use derivative::Derivative;
@@ -19,6 +20,7 @@ use quinn::{
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::net::UdpSocket;
// region config
pub fn transport_config() -> Arc<TransportConfig> {
@@ -179,27 +181,28 @@ pub struct QuicEndpointManager {
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
let socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
.map_err(std::io::Error::other)?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
fn try_create(addr: SocketAddr, dual_stack: bool) -> Result<Endpoint, TunnelError> {
let socket = bind::<UdpSocket>()
.addr(addr)
.only_v6(addr.is_ipv6() && !dual_stack)
.call()?;
let runtime = default_runtime().ok_or(TunnelError::InternalError(
"no async runtime found".to_owned(),
))?;
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(),
None,
runtime.wrap_udp_socket(socket)?,
runtime.wrap_udp_socket(socket.into_std()?)?,
runtime,
)?;
endpoint.set_default_client_config(client_config());
Ok(endpoint)
}
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
fn create<F>(
&self,
mut selector: F,
) -> Result<(&RwPool<Endpoint>, Option<Endpoint>), TunnelError>
where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{
@@ -210,10 +213,10 @@ impl QuicEndpointManager {
};
let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref()
if let Err(error) = endpoint.as_ref()
&& dual_stack
{
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
tracing::warn!(?error, "create dual stack quic endpoint failed");
self.both.disable();
self.ipv4.enable();
self.ipv6.enable();
@@ -263,7 +266,7 @@ impl QuicEndpointManager {
///
/// # Arguments
/// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
@@ -289,7 +292,7 @@ impl QuicEndpointManager {
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
@@ -318,7 +321,7 @@ impl QuicEndpointManager {
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> {
) -> Result<(Endpoint, Connection), TunnelError> {
let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4
} else {
@@ -327,8 +330,9 @@ impl QuicEndpointManager {
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.map_err(std::io::Error::other)?
.await?;
.with_context(|| format!("failed to create connection to {}", addr))?
.await
.with_context(|| format!("failed to connect to {}", addr))?;
Ok((endpoint, connection))
}
@@ -585,10 +589,10 @@ mod tests {
async fn invalid_peer_addr_impl() {
let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err();
let err = format!("{:?}", connector.connect().await.unwrap_err());
assert!(
err.to_string().contains("invalid remote address"),
"unexpected error: {:?}",
err.contains("invalid remote address"),
"unexpected error: {}",
err
);
}
+12 -29
View File
@@ -1,7 +1,7 @@
use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_socket2;
use crate::tunnel::common::bind;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
@@ -59,25 +59,15 @@ impl TcpTunnelListener {
impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
self.addr
.set_port(Some(socket.local_addr()?.port()))
.set_port(Some(listener.local_addr()?.port()))
.unwrap();
self.listener = Some(listener);
self.listener = Some(socket.listen(1024)?);
Ok(())
}
@@ -167,21 +157,14 @@ impl TcpTunnelConnector {
let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr");
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind::<TcpSocket>().addr(*bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(socket.connect(addr)),
Err(error) => {
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
continue;
}
}
let socket = TcpSocket::from_std_stream(socket2_socket.into());
futures.push(socket.connect(addr));
}
let ret = wait_for_connect_futures(futures).await;
+22 -33
View File
@@ -18,16 +18,16 @@ use tokio::{
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
task::JoinSet,
};
use tracing::{Instrument, instrument};
use super::{
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
common::wait_for_connect_futures,
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
ring::{RingSink, RingStream},
};
use crate::tunnel::common::bind;
use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
tunnel::{
@@ -536,21 +536,14 @@ impl UdpTunnelListener {
impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else {
setup_socket2(&socket2_socket, &addr, true)?;
}
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
self.socket = Some(Arc::new(
bind()
.addr(addr)
.only_v6(true)
.maybe_dev(tunnel_url.bind_dev())
.call()?,
));
self.data.socket = self.socket.clone();
self.addr
@@ -833,17 +826,14 @@ impl UdpTunnelConnector {
let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(*bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind().addr(*bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(self.try_connect_with_socket(Arc::new(socket), addr)),
Err(error) => {
tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
continue;
}
}
let socket = UdpSocket::from_std(socket2_socket.into())?;
futures.push(self.try_connect_with_socket(Arc::new(socket), addr));
}
wait_for_connect_futures(futures).await
}
@@ -1034,13 +1024,12 @@ mod tests {
)
.await
.unwrap();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)
.unwrap();
setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
let _ = bind::<UdpSocket>()
.addr(addr)
.maybe_dev(bind_dev.clone())
.only_v6(true)
.call()
.unwrap();
}
}
+18 -28
View File
@@ -1,9 +1,10 @@
use super::{
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
common::{TunnelWrapper, setup_socket2, wait_for_connect_futures},
common::{TunnelWrapper, wait_for_connect_futures},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
};
use crate::tunnel::common::bind;
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use anyhow::Context;
use bytes::BytesMut;
@@ -160,20 +161,16 @@ impl WsTunnelListener {
#[async_trait::async_trait]
impl TunnelListener for WsTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
self.addr
.set_port(Some(socket.local_addr()?.port()))
.set_port(Some(listener.local_addr()?.port()))
.unwrap();
self.listener = Some(listener);
self.listener = Some(socket.listen(1024)?);
Ok(())
}
@@ -283,25 +280,18 @@ impl WsTunnelConnector {
let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr");
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind().addr(*bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(Self::connect_with(
self.addr.clone(),
self.ip_version,
socket,
)),
Err(error) => {
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
continue;
}
}
let socket = TcpSocket::from_std_stream(socket2_socket.into());
futures.push(Self::connect_with(
self.addr.clone(),
self.ip_version,
socket,
))
}
wait_for_connect_futures(futures).await
+40 -58
View File
@@ -6,6 +6,23 @@ use std::{
time::Duration,
};
use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
common::wait_for_connect_futures,
generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair,
};
use crate::tunnel::common::{BindDev, bind};
use crate::{
common::shrink_dashmap,
tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
},
};
use anyhow::Context;
use async_recursion::async_recursion;
use async_trait::async_trait;
@@ -20,23 +37,6 @@ use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair,
};
use crate::{
common::shrink_dashmap,
tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
},
};
const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)]
@@ -555,20 +555,14 @@ impl WgTunnelListener {
impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else {
setup_socket2(&socket2_socket, &addr, true)?;
}
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
self.udp = Some(Arc::new(
bind()
.addr(addr)
.only_v6(true)
.maybe_dev(tunnel_url.bind_dev())
.call()?,
));
self.addr
.set_port(Some(self.udp.as_ref().unwrap().local_addr()?.port()))
.unwrap();
@@ -695,13 +689,11 @@ impl WgTunnelConnector {
}
async fn connect_with_ipv6(&self, addr: SocketAddr) -> Result<Box<dyn Tunnel>, TunnelError> {
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
let socket = bind()
.addr("[::]:0".parse().unwrap())
.dev(BindDev::Disabled)
.only_v6(true)
.call()?;
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
}
}
@@ -723,29 +715,19 @@ impl super::TunnelConnector for WgTunnelConnector {
};
let futures = FuturesUnordered::new();
for bind_addr in bind_addrs.into_iter() {
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue;
}
let socket = match UdpSocket::from_std(socket2_socket.into()) {
Ok(s) => s,
Err(e) => {
tracing::error!(bind_addr = ?bind_addr, ?addr, "create udp socket fail: {:?}", e);
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind().addr(bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(Self::connect_with_socket(
self.addr.clone(),
self.config.clone(),
socket,
addr,
)),
Err(error) => {
tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
continue;
}
};
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
futures.push(Self::connect_with_socket(
self.addr.clone(),
self.config.clone(),
socket,
addr,
));
}
}
wait_for_connect_futures(futures).await