mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-04-22 16:17:23 +08:00
tunnel(bind): gather all bind logic to a single function (#2070)
* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Generated
+82
-5565
File diff suppressed because it is too large
Load Diff
@@ -242,6 +242,7 @@ hickory-server = { version = "0.25.2", features = [
|
||||
"resolver",
|
||||
], optional = true }
|
||||
|
||||
bon = "3.9.1"
|
||||
derive_builder = "0.20.2"
|
||||
humantime-serde = "1.1.1"
|
||||
multimap = "0.10.1"
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::gateway::kcp_proxy::NatDstKcpConnector;
|
||||
use crate::{
|
||||
common::{
|
||||
config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background,
|
||||
netns::NetNS, scoped_task::ScopedTask,
|
||||
scoped_task::ScopedTask,
|
||||
},
|
||||
gateway::{
|
||||
fast_socks5::{
|
||||
@@ -30,10 +30,7 @@ use crate::{
|
||||
ip_reassembler::IpReassembler,
|
||||
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
|
||||
},
|
||||
tunnel::{
|
||||
common::setup_socket2,
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
},
|
||||
tunnel::packet_def::{PacketType, ZCPacket},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use dashmap::DashMap;
|
||||
@@ -42,21 +39,21 @@ use pnet::packet::{
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, TcpSocket, UdpSocket},
|
||||
net::{TcpListener, UdpSocket},
|
||||
select,
|
||||
sync::{Mutex, Notify, mpsc},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
#[cfg(feature = "kcp")]
|
||||
use super::tcp_proxy::NatDstConnector as _;
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::GlobalCtx},
|
||||
peers::{PeerPacketFilter, peer_manager::PeerManager},
|
||||
};
|
||||
|
||||
#[cfg(feature = "kcp")]
|
||||
use super::tcp_proxy::NatDstConnector as _;
|
||||
|
||||
enum SocksUdpSocket {
|
||||
UdpSocket(Arc<tokio::net::UdpSocket>),
|
||||
SmolUdpSocket(super::tokio_smoltcp::UdpSocket),
|
||||
@@ -328,38 +325,6 @@ impl AsyncTcpConnector for Socks5AutoConnector {
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error> {
|
||||
let _g = net_ns.guard();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!(?e, "set_nodelay fail in listen");
|
||||
}
|
||||
|
||||
Ok(socket.listen(1024)?)
|
||||
}
|
||||
|
||||
fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error> {
|
||||
let _g = net_ns.guard();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
|
||||
Ok(UdpSocket::from_std(socket2_socket.into())?)
|
||||
}
|
||||
|
||||
struct Socks5ServerNet {
|
||||
ipv4_addr: cidr::Ipv4Inet,
|
||||
auth: Option<SimpleUserPassword>,
|
||||
@@ -702,10 +667,10 @@ impl Socks5Server {
|
||||
proxy_url.port().unwrap()
|
||||
);
|
||||
|
||||
let listener = bind_tcp_socket(
|
||||
bind_addr.parse::<SocketAddr>().unwrap(),
|
||||
self.global_ctx.net_ns.clone(),
|
||||
)?;
|
||||
let listener = bind::<TcpListener>()
|
||||
.addr(bind_addr.parse::<SocketAddr>().unwrap())
|
||||
.net_ns(self.global_ctx.net_ns.clone())
|
||||
.call()?;
|
||||
|
||||
let entries = self.entries.clone();
|
||||
let entry_count = self.entry_count.clone();
|
||||
@@ -838,7 +803,10 @@ impl Socks5Server {
|
||||
|
||||
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
|
||||
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
|
||||
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
|
||||
let listener = bind::<TcpListener>()
|
||||
.addr(bind_addr)
|
||||
.net_ns(self.global_ctx.net_ns.clone())
|
||||
.call()?;
|
||||
|
||||
let net = self.net.clone();
|
||||
let entries = self.entries.clone();
|
||||
@@ -906,7 +874,12 @@ impl Socks5Server {
|
||||
#[tracing::instrument(name = "add_udp_port_forward", skip(self))]
|
||||
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
|
||||
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
|
||||
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
|
||||
let socket = Arc::new(
|
||||
bind::<UdpSocket>()
|
||||
.addr(bind_addr)
|
||||
.net_ns(self.global_ctx.net_ns.clone())
|
||||
.call()?,
|
||||
);
|
||||
|
||||
let entries = self.entries.clone();
|
||||
let entry_count = self.entry_count.clone();
|
||||
|
||||
@@ -24,18 +24,18 @@ use tokio::{
|
||||
|
||||
use tracing::Level;
|
||||
|
||||
use super::{CidrSet, ip_reassembler::IpReassembler};
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::{
|
||||
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
|
||||
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
|
||||
peers::{PeerPacketFilter, peer_manager::PeerManager},
|
||||
tunnel::{
|
||||
common::{reserve_buf, setup_socket2},
|
||||
common::reserve_buf,
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
},
|
||||
};
|
||||
|
||||
use super::{CidrSet, ip_reassembler::IpReassembler};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct UdpNatKey {
|
||||
src_socket: SocketAddr,
|
||||
@@ -63,18 +63,9 @@ impl UdpNatEntry {
|
||||
denied: bool,
|
||||
) -> Result<Self, Error> {
|
||||
// TODO: try use src port, so we will be ip restricted nat type
|
||||
let socket = if denied {
|
||||
None
|
||||
} else {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
|
||||
setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
|
||||
Some(UdpSocket::from_std(socket2_socket.into())?)
|
||||
};
|
||||
let socket = (!denied)
|
||||
.then(|| bind().addr("0.0.0.0:0".parse().unwrap()).call())
|
||||
.transpose()?;
|
||||
|
||||
Ok(Self {
|
||||
src_peer_id,
|
||||
@@ -403,11 +394,10 @@ impl UdpProxy {
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for UdpProxy {
|
||||
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||
if self.try_handle_packet(&packet).await.is_some() {
|
||||
return None;
|
||||
} else {
|
||||
return Some(packet);
|
||||
}
|
||||
self.try_handle_packet(&packet)
|
||||
.await
|
||||
.is_none()
|
||||
.then_some(packet)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+138
-43
@@ -1,3 +1,7 @@
|
||||
use bon::builder;
|
||||
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
any::Any,
|
||||
net::{IpAddr, SocketAddr},
|
||||
@@ -5,26 +9,21 @@ use std::{
|
||||
sync::{Arc, Mutex},
|
||||
task::{Poll, ready},
|
||||
};
|
||||
|
||||
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::poll_write_buf;
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
use super::TunnelInfo;
|
||||
|
||||
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
|
||||
|
||||
use super::{
|
||||
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
||||
buf::BufList,
|
||||
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
|
||||
};
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::poll_write_buf;
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
pub struct TunnelWrapper<R, W> {
|
||||
reader: Arc<Mutex<Option<R>>>,
|
||||
@@ -344,7 +343,70 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_socket2_ext(
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send,
|
||||
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
}
|
||||
|
||||
// region bind
|
||||
|
||||
pub trait Bindable: Sized {
|
||||
const TYPE: socket2::Type;
|
||||
const PROTOCOL: Option<socket2::Protocol>;
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError>;
|
||||
}
|
||||
|
||||
impl Bindable for TcpSocket {
|
||||
const TYPE: socket2::Type = socket2::Type::STREAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
let socket = TcpSocket::from_std_stream(socket.into());
|
||||
|
||||
if let Err(error) = socket.set_nodelay(true) {
|
||||
tracing::warn!(?error, "set_nodelay failed for tcp socket");
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bindable for TcpListener {
|
||||
const TYPE: socket2::Type = socket2::Type::STREAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
Ok(TcpSocket::finalize(socket)?.listen(1024)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bindable for UdpSocket {
|
||||
const TYPE: socket2::Type = socket2::Type::DGRAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::UDP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
Ok(UdpSocket::from_std(socket.into())?)
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_socket2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
#[allow(unused_variables)] bind_dev: Option<String>,
|
||||
@@ -408,38 +470,69 @@ pub(crate) fn setup_socket2_ext(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send,
|
||||
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub enum BindDev {
|
||||
#[default]
|
||||
Auto,
|
||||
Disabled,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
pub(crate) fn setup_socket2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
only_v6: bool,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_socket2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
only_v6,
|
||||
)
|
||||
impl From<String> for BindDev {
|
||||
fn from(value: String) -> Self {
|
||||
if value.is_empty() {
|
||||
Self::Disabled
|
||||
} else {
|
||||
Self::Custom(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for BindDev {
|
||||
fn from(value: &str) -> Self {
|
||||
value.to_string().into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a socket to a specific address and optionally a network interface.
|
||||
///
|
||||
/// This function creates a new socket, applies specific configurations (such as
|
||||
/// binding to a device or setting IPv6-only flags), and finalizes it into the
|
||||
/// requested [`Bindable`] type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `addr` - The `SocketAddr` to bind the socket to.
|
||||
/// * `dev` - The name of the network interface to bind to:
|
||||
/// * **(default) `BindDev::Auto`**: Enables **auto-discovery**. The function will attempt to automatically
|
||||
/// resolve the interface name associated with the provided `addr.ip()`.
|
||||
/// * **empty string or `BindDev::Disabled`**: **Disables** auto-discovery and
|
||||
/// explicitly chooses **not** to bind to any specific device. The routing will be
|
||||
/// left entirely to the OS.
|
||||
/// * **non-empty string or `BindDev::Custom(..)`**: Skips auto-discovery and explicitly binds to
|
||||
/// the specified interface.
|
||||
/// * `net_ns` - An optional network namespace to switch into before creating the socket.
|
||||
/// * `only_v6` - If `true`, sets the `IPV6_V6ONLY` flag on the socket.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns a [`TunnelError`] if socket creation, configuration, or finalization fails.
|
||||
#[builder]
|
||||
pub fn bind<B: Bindable>(
|
||||
addr: SocketAddr,
|
||||
#[builder(default, into)] dev: BindDev,
|
||||
net_ns: Option<NetNS>,
|
||||
#[builder(default)] only_v6: bool,
|
||||
) -> Result<B, TunnelError> {
|
||||
let _g = net_ns.map(|n| n.guard());
|
||||
let dev = match dev {
|
||||
BindDev::Auto => get_interface_name_by_ip(&addr.ip()),
|
||||
BindDev::Disabled => None,
|
||||
BindDev::Custom(s) => Some(s),
|
||||
};
|
||||
let socket = socket2::Socket::new(socket2::Domain::for_address(addr), B::TYPE, B::PROTOCOL)?;
|
||||
setup_socket2_ext(&socket, &addr, dev, only_v6)?;
|
||||
B::finalize(socket)
|
||||
}
|
||||
|
||||
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
@@ -448,6 +541,8 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
pub mod tests {
|
||||
use atomic_shim::AtomicU64;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
+27
-23
@@ -4,9 +4,10 @@
|
||||
|
||||
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::tunnel::{
|
||||
TunnelInfo,
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use derivative::Derivative;
|
||||
@@ -19,6 +20,7 @@ use quinn::{
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::OnceLock;
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
// region config
|
||||
pub fn transport_config() -> Arc<TransportConfig> {
|
||||
@@ -179,27 +181,28 @@ pub struct QuicEndpointManager {
|
||||
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
|
||||
|
||||
impl QuicEndpointManager {
|
||||
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
|
||||
.map_err(std::io::Error::other)?;
|
||||
let socket = std::net::UdpSocket::from(socket);
|
||||
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
|
||||
fn try_create(addr: SocketAddr, dual_stack: bool) -> Result<Endpoint, TunnelError> {
|
||||
let socket = bind::<UdpSocket>()
|
||||
.addr(addr)
|
||||
.only_v6(addr.is_ipv6() && !dual_stack)
|
||||
.call()?;
|
||||
let runtime = default_runtime().ok_or(TunnelError::InternalError(
|
||||
"no async runtime found".to_owned(),
|
||||
))?;
|
||||
let mut endpoint = Endpoint::new_with_abstract_socket(
|
||||
endpoint_config(),
|
||||
None,
|
||||
runtime.wrap_udp_socket(socket)?,
|
||||
runtime.wrap_udp_socket(socket.into_std()?)?,
|
||||
runtime,
|
||||
)?;
|
||||
endpoint.set_default_client_config(client_config());
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
|
||||
fn create<F>(
|
||||
&self,
|
||||
mut selector: F,
|
||||
) -> Result<(&RwPool<Endpoint>, Option<Endpoint>), TunnelError>
|
||||
where
|
||||
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
|
||||
{
|
||||
@@ -210,10 +213,10 @@ impl QuicEndpointManager {
|
||||
};
|
||||
|
||||
let endpoint = Self::try_create(addr, dual_stack);
|
||||
if let Err(e) = endpoint.as_ref()
|
||||
if let Err(error) = endpoint.as_ref()
|
||||
&& dual_stack
|
||||
{
|
||||
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
|
||||
tracing::warn!(?error, "create dual stack quic endpoint failed");
|
||||
self.both.disable();
|
||||
self.ipv4.enable();
|
||||
self.ipv6.enable();
|
||||
@@ -263,7 +266,7 @@ impl QuicEndpointManager {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `addr`: listen address
|
||||
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
|
||||
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> Result<Endpoint, TunnelError> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
@@ -289,7 +292,7 @@ impl QuicEndpointManager {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `ip_version`: the IP version of the remote address
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
@@ -318,7 +321,7 @@ impl QuicEndpointManager {
|
||||
async fn connect(
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
addr: SocketAddr,
|
||||
) -> std::io::Result<(Endpoint, Connection)> {
|
||||
) -> Result<(Endpoint, Connection), TunnelError> {
|
||||
let ip_version = if addr.ip().is_ipv4() {
|
||||
IpVersion::V4
|
||||
} else {
|
||||
@@ -327,8 +330,9 @@ impl QuicEndpointManager {
|
||||
let endpoint = Self::client(global_ctx, ip_version)?;
|
||||
let connection = endpoint
|
||||
.connect(addr, "localhost")
|
||||
.map_err(std::io::Error::other)?
|
||||
.await?;
|
||||
.with_context(|| format!("failed to create connection to {}", addr))?
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to {}", addr))?;
|
||||
|
||||
Ok((endpoint, connection))
|
||||
}
|
||||
@@ -585,10 +589,10 @@ mod tests {
|
||||
async fn invalid_peer_addr_impl() {
|
||||
let mut connector =
|
||||
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
|
||||
let err = connector.connect().await.unwrap_err();
|
||||
let err = format!("{:?}", connector.connect().await.unwrap_err());
|
||||
assert!(
|
||||
err.to_string().contains("invalid remote address"),
|
||||
"unexpected error: {:?}",
|
||||
err.contains("invalid remote address"),
|
||||
"unexpected error: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
+12
-29
@@ -1,7 +1,7 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use super::{FromUrl, TunnelInfo};
|
||||
use crate::tunnel::common::setup_socket2;
|
||||
use crate::tunnel::common::bind;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
@@ -59,25 +59,15 @@ impl TcpTunnelListener {
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
self.listener = None;
|
||||
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!(?e, "set_nodelay fail in listen");
|
||||
}
|
||||
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
|
||||
|
||||
self.addr
|
||||
.set_port(Some(socket.local_addr()?.port()))
|
||||
.set_port(Some(listener.local_addr()?.port()))
|
||||
.unwrap();
|
||||
self.listener = Some(listener);
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -167,21 +157,14 @@ impl TcpTunnelConnector {
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr");
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
tracing::info!(?bind_addr, ?addr, "bind addr");
|
||||
match bind::<TcpSocket>().addr(*bind_addr).only_v6(true).call() {
|
||||
Ok(socket) => futures.push(socket.connect(addr)),
|
||||
Err(error) => {
|
||||
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
futures.push(socket.connect(addr));
|
||||
}
|
||||
|
||||
let ret = wait_for_connect_futures(futures).await;
|
||||
|
||||
+22
-33
@@ -18,16 +18,16 @@ use tokio::{
|
||||
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
|
||||
task::JoinSet,
|
||||
};
|
||||
|
||||
use tracing::{Instrument, instrument};
|
||||
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
|
||||
TunnelUrl,
|
||||
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
|
||||
common::wait_for_connect_futures,
|
||||
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
|
||||
ring::{RingSink, RingStream},
|
||||
};
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::{
|
||||
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
|
||||
tunnel::{
|
||||
@@ -536,21 +536,14 @@ impl UdpTunnelListener {
|
||||
impl TunnelListener for UdpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
|
||||
} else {
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
}
|
||||
|
||||
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.socket = Some(Arc::new(
|
||||
bind()
|
||||
.addr(addr)
|
||||
.only_v6(true)
|
||||
.maybe_dev(tunnel_url.bind_dev())
|
||||
.call()?,
|
||||
));
|
||||
self.data.socket = self.socket.clone();
|
||||
|
||||
self.addr
|
||||
@@ -833,17 +826,14 @@ impl UdpTunnelConnector {
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(*bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
tracing::info!(?bind_addr, ?addr, "bind addr");
|
||||
match bind().addr(*bind_addr).only_v6(true).call() {
|
||||
Ok(socket) => futures.push(self.try_connect_with_socket(Arc::new(socket), addr)),
|
||||
Err(error) => {
|
||||
tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
futures.push(self.try_connect_with_socket(Arc::new(socket), addr));
|
||||
}
|
||||
wait_for_connect_futures(futures).await
|
||||
}
|
||||
@@ -1034,13 +1024,12 @@ mod tests {
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.unwrap();
|
||||
setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
|
||||
let _ = bind::<UdpSocket>()
|
||||
.addr(addr)
|
||||
.maybe_dev(bind_dev.clone())
|
||||
.only_v6(true)
|
||||
.call()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||
common::{TunnelWrapper, setup_socket2, wait_for_connect_futures},
|
||||
common::{TunnelWrapper, wait_for_connect_futures},
|
||||
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
|
||||
packet_def::{ZCPacket, ZCPacketType},
|
||||
};
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
@@ -160,20 +161,16 @@ impl WsTunnelListener {
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for WsTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
self.listener = None;
|
||||
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
|
||||
|
||||
self.addr
|
||||
.set_port(Some(socket.local_addr()?.port()))
|
||||
.set_port(Some(listener.local_addr()?.port()))
|
||||
.unwrap();
|
||||
self.listener = Some(listener);
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -283,25 +280,18 @@ impl WsTunnelConnector {
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr");
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
tracing::info!(?bind_addr, ?addr, "bind addr");
|
||||
match bind().addr(*bind_addr).only_v6(true).call() {
|
||||
Ok(socket) => futures.push(Self::connect_with(
|
||||
self.addr.clone(),
|
||||
self.ip_version,
|
||||
socket,
|
||||
)),
|
||||
Err(error) => {
|
||||
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
futures.push(Self::connect_with(
|
||||
self.addr.clone(),
|
||||
self.ip_version,
|
||||
socket,
|
||||
))
|
||||
}
|
||||
|
||||
wait_for_connect_futures(futures).await
|
||||
|
||||
@@ -6,6 +6,23 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
|
||||
ZCPacketStream,
|
||||
common::wait_for_connect_futures,
|
||||
generate_digest_from_str,
|
||||
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
|
||||
ring::create_ring_tunnel_pair,
|
||||
};
|
||||
use crate::tunnel::common::{BindDev, bind};
|
||||
use crate::{
|
||||
common::shrink_dashmap,
|
||||
tunnel::{
|
||||
build_url_from_socket_addr,
|
||||
common::TunnelWrapper,
|
||||
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
|
||||
},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use async_recursion::async_recursion;
|
||||
use async_trait::async_trait;
|
||||
@@ -20,23 +37,6 @@ use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
|
||||
use rand::RngCore;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
|
||||
use super::{
|
||||
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
|
||||
ZCPacketStream,
|
||||
common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
|
||||
generate_digest_from_str,
|
||||
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
|
||||
ring::create_ring_tunnel_pair,
|
||||
};
|
||||
use crate::{
|
||||
common::shrink_dashmap,
|
||||
tunnel::{
|
||||
build_url_from_socket_addr,
|
||||
common::TunnelWrapper,
|
||||
packet_def::{WG_TUNNEL_HEADER_SIZE, ZCPacket},
|
||||
},
|
||||
};
|
||||
|
||||
const MAX_PACKET: usize = 2048;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -555,20 +555,14 @@ impl WgTunnelListener {
|
||||
impl TunnelListener for WgTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
|
||||
} else {
|
||||
setup_socket2(&socket2_socket, &addr, true)?;
|
||||
}
|
||||
|
||||
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.udp = Some(Arc::new(
|
||||
bind()
|
||||
.addr(addr)
|
||||
.only_v6(true)
|
||||
.maybe_dev(tunnel_url.bind_dev())
|
||||
.call()?,
|
||||
));
|
||||
self.addr
|
||||
.set_port(Some(self.udp.as_ref().unwrap().local_addr()?.port()))
|
||||
.unwrap();
|
||||
@@ -695,13 +689,11 @@ impl WgTunnelConnector {
|
||||
}
|
||||
|
||||
async fn connect_with_ipv6(&self, addr: SocketAddr) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
let socket = bind()
|
||||
.addr("[::]:0".parse().unwrap())
|
||||
.dev(BindDev::Disabled)
|
||||
.only_v6(true)
|
||||
.call()?;
|
||||
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
|
||||
}
|
||||
}
|
||||
@@ -723,29 +715,19 @@ impl super::TunnelConnector for WgTunnelConnector {
|
||||
};
|
||||
let futures = FuturesUnordered::new();
|
||||
for bind_addr in bind_addrs.into_iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
let socket = match UdpSocket::from_std(socket2_socket.into()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(bind_addr = ?bind_addr, ?addr, "create udp socket fail: {:?}", e);
|
||||
tracing::info!(?bind_addr, ?addr, "bind addr");
|
||||
match bind().addr(bind_addr).only_v6(true).call() {
|
||||
Ok(socket) => futures.push(Self::connect_with_socket(
|
||||
self.addr.clone(),
|
||||
self.config.clone(),
|
||||
socket,
|
||||
addr,
|
||||
)),
|
||||
Err(error) => {
|
||||
tracing::error!(?error, ?bind_addr, ?addr, "bind addr fail");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
|
||||
futures.push(Self::connect_with_socket(
|
||||
self.addr.clone(),
|
||||
self.config.clone(),
|
||||
socket,
|
||||
addr,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
wait_for_connect_futures(futures).await
|
||||
|
||||
Reference in New Issue
Block a user