mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-04-23 00:27:06 +08:00
709 lines
21 KiB
Rust
709 lines
21 KiB
Rust
use std::{
|
|
any::Any,
|
|
net::{IpAddr, SocketAddr},
|
|
pin::Pin,
|
|
sync::{Arc, Mutex},
|
|
task::{Poll, ready},
|
|
};
|
|
|
|
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
|
|
use network_interface::NetworkInterfaceConfig as _;
|
|
use pin_project_lite::pin_project;
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
|
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
|
use tokio_stream::StreamExt;
|
|
use tokio_util::io::poll_write_buf;
|
|
use zerocopy::FromBytes as _;
|
|
|
|
use super::TunnelInfo;
|
|
|
|
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
|
|
|
|
use super::{
|
|
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
|
buf::BufList,
|
|
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
|
|
};
|
|
|
|
/// A [`Tunnel`] assembled from an already-separated reader (stream) half and writer
/// (sink) half.
///
/// Both halves sit behind `Arc<Mutex<Option<_>>>` so that `split(&self)` can move them
/// out through a shared reference; as a consequence they can only be taken once.
pub struct TunnelWrapper<R, W> {
    /// Reader half; becomes `None` once `split` has taken it.
    reader: Arc<Mutex<Option<R>>>,
    /// Writer half; becomes `None` once `split` has taken it.
    writer: Arc<Mutex<Option<W>>>,
    /// Connection metadata handed out by `info()`.
    info: Option<TunnelInfo>,
    /// Opaque data owned by the wrapper and never read in this file; presumably held so
    /// that it lives as long as the tunnel — confirm with callers.
    associate_data: Option<Box<dyn Any + Send + 'static>>,
}
|
|
|
|
impl<R, W> TunnelWrapper<R, W> {
|
|
pub fn new(reader: R, writer: W, info: Option<TunnelInfo>) -> Self {
|
|
Self::new_with_associate_data(reader, writer, info, None)
|
|
}
|
|
|
|
pub fn new_with_associate_data(
|
|
reader: R,
|
|
writer: W,
|
|
info: Option<TunnelInfo>,
|
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
|
) -> Self {
|
|
TunnelWrapper {
|
|
reader: Arc::new(Mutex::new(Some(reader))),
|
|
writer: Arc::new(Mutex::new(Some(writer))),
|
|
info,
|
|
associate_data,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R, W> Tunnel for TunnelWrapper<R, W>
|
|
where
|
|
R: ZCPacketStream + Send + 'static,
|
|
W: ZCPacketSink + Send + 'static,
|
|
{
|
|
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
|
let reader = self.reader.lock().unwrap().take().unwrap();
|
|
let writer = self.writer.lock().unwrap().take().unwrap();
|
|
(Box::pin(reader), Box::pin(writer))
|
|
}
|
|
|
|
fn info(&self) -> Option<TunnelInfo> {
|
|
self.info.clone()
|
|
}
|
|
}
|
|
|
|
// a length delimited codec for async reader
|
|
pin_project! {
    /// Length-delimited frame decoder over an `AsyncRead`; frames are produced by its
    /// `Stream` implementation.
    pub struct FramedReader<R> {
        #[pin]
        reader: R,
        // Bytes read from `reader` that have not yet been cut into complete frames.
        buf: BytesMut,
        // Largest accepted frame body length; also the buffer's initial capacity.
        max_packet_size: usize,
        // Opaque data owned by the reader and never read in this file; presumably kept
        // alive for the reader's lifetime — confirm with callers.
        associate_data: Option<Box<dyn Any + Send + 'static>>,
        // Set once an oversized frame is seen; afterwards the stream only yields `None`.
        error: Option<TunnelError>,
    }
}
|
|
|
|
impl<R> FramedReader<R> {
|
|
pub fn new(reader: R, max_packet_size: usize) -> Self {
|
|
Self::new_with_associate_data(reader, max_packet_size, None)
|
|
}
|
|
|
|
pub fn new_with_associate_data(
|
|
reader: R,
|
|
max_packet_size: usize,
|
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
|
) -> Self {
|
|
FramedReader {
|
|
reader,
|
|
buf: BytesMut::with_capacity(max_packet_size),
|
|
max_packet_size,
|
|
associate_data,
|
|
error: None,
|
|
}
|
|
}
|
|
|
|
fn extract_one_packet(
|
|
buf: &mut BytesMut,
|
|
max_packet_size: usize,
|
|
) -> Option<Result<ZCPacket, TunnelError>> {
|
|
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
|
|
// header is not complete
|
|
return None;
|
|
}
|
|
|
|
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
|
|
let body_len = header.len.get() as usize;
|
|
if body_len > max_packet_size {
|
|
// body is too long
|
|
return Some(Err(TunnelError::InvalidPacket("body too long".to_string())));
|
|
}
|
|
|
|
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
|
|
// body is not complete
|
|
return None;
|
|
}
|
|
|
|
// extract one packet
|
|
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
|
|
Some(Ok(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)))
|
|
}
|
|
}
|
|
|
|
impl<R> Stream for FramedReader<R>
|
|
where
|
|
R: AsyncRead + Send + 'static + Unpin,
|
|
{
|
|
type Item = StreamItem;
|
|
|
|
fn poll_next(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Option<Self::Item>> {
|
|
let mut self_mut = self.project();
|
|
|
|
loop {
|
|
if let Some(e) = self_mut.error.as_ref() {
|
|
tracing::warn!("poll_next on a failed FramedReader, {:?}", e);
|
|
return Poll::Ready(None);
|
|
}
|
|
|
|
if let Some(packet) = Self::extract_one_packet(self_mut.buf, *self_mut.max_packet_size)
|
|
{
|
|
if let Err(TunnelError::InvalidPacket(msg)) = packet.as_ref() {
|
|
self_mut
|
|
.error
|
|
.replace(TunnelError::InvalidPacket(msg.clone()));
|
|
}
|
|
return Poll::Ready(Some(packet));
|
|
}
|
|
|
|
reserve_buf(
|
|
self_mut.buf,
|
|
*self_mut.max_packet_size,
|
|
*self_mut.max_packet_size * 2,
|
|
);
|
|
|
|
let cap = self_mut.buf.capacity() - self_mut.buf.len();
|
|
let buf = self_mut.buf.chunk_mut().as_mut_ptr();
|
|
let buf = unsafe { std::slice::from_raw_parts_mut(buf, cap) };
|
|
let mut buf = ReadBuf::new(buf);
|
|
|
|
let ret = ready!(self_mut.reader.as_mut().poll_read(cx, &mut buf));
|
|
let len = buf.filled().len();
|
|
unsafe { self_mut.buf.advance_mut(len) };
|
|
|
|
match ret {
|
|
Ok(_) => {
|
|
if len == 0 {
|
|
return Poll::Ready(None);
|
|
}
|
|
}
|
|
Err(e) => {
|
|
return Poll::Ready(Some(Err(TunnelError::IOError(e))));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Encodes a `ZCPacket` into the bytes that `FramedWriter` queues for writing.
pub trait ZCPacketToBytes {
    /// Consumes `zc_packet` and returns its wire bytes, or an error if it cannot be
    /// encoded.
    fn zcpacket_into_bytes(&self, zc_packet: ZCPacket) -> Result<Bytes, TunnelError>;
}
|
|
|
|
pub struct TcpZCPacketToBytes;
|
|
impl ZCPacketToBytes for TcpZCPacketToBytes {
|
|
fn zcpacket_into_bytes(&self, item: ZCPacket) -> Result<Bytes, TunnelError> {
|
|
let mut item = item.convert_type(ZCPacketType::TCP);
|
|
|
|
let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len();
|
|
let Some(header) = item.mut_tcp_tunnel_header() else {
|
|
return Err(TunnelError::InvalidPacket("packet too short".to_string()));
|
|
};
|
|
header.len.set(tcp_len.try_into().unwrap());
|
|
|
|
Ok(item.into_bytes())
|
|
}
|
|
}
|
|
|
|
pin_project! {
    /// Buffered frame writer: a `Sink<ZCPacket>` that turns each packet into bytes with
    /// `converter`, queues them, and writes them to `writer` when flushed.
    pub struct FramedWriter<W, C> {
        #[pin]
        writer: W,
        // Encoded frames waiting to be written to `writer`.
        sending_bufs: BufList<Bytes>,
        // Opaque data owned by the writer and never read in this file; presumably kept
        // alive for the writer's lifetime — confirm with callers.
        associate_data: Option<Box<dyn Any + Send + 'static>>,

        // Encodes each outgoing `ZCPacket` into the bytes that get queued.
        converter: C,
    }
}
|
|
|
|
impl<W, C> FramedWriter<W, C> {
|
|
fn max_buffer_count(&self) -> usize {
|
|
64
|
|
}
|
|
}
|
|
|
|
impl<W> FramedWriter<W, TcpZCPacketToBytes> {
|
|
pub fn new(writer: W) -> Self {
|
|
Self::new_with_associate_data(writer, None)
|
|
}
|
|
|
|
pub fn new_with_associate_data(
|
|
writer: W,
|
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
|
) -> Self {
|
|
FramedWriter {
|
|
writer,
|
|
sending_bufs: BufList::new(),
|
|
associate_data,
|
|
converter: TcpZCPacketToBytes {},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<W, C: ZCPacketToBytes + Send + 'static> FramedWriter<W, C> {
|
|
pub fn new_with_converter(writer: W, converter: C) -> Self {
|
|
Self::new_with_converter_and_associate_data(writer, converter, None)
|
|
}
|
|
|
|
pub fn new_with_converter_and_associate_data(
|
|
writer: W,
|
|
converter: C,
|
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
|
) -> Self {
|
|
FramedWriter {
|
|
writer,
|
|
sending_bufs: BufList::new(),
|
|
associate_data,
|
|
converter,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<W, C> Sink<SinkItem> for FramedWriter<W, C>
|
|
where
|
|
W: AsyncWrite + Send + 'static,
|
|
C: ZCPacketToBytes + Send + 'static,
|
|
{
|
|
type Error = TunnelError;
|
|
|
|
fn poll_ready(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
let max_buffer_count = self.max_buffer_count();
|
|
if self.sending_bufs.bufs_cnt() >= max_buffer_count {
|
|
self.as_mut().poll_flush(cx)
|
|
} else {
|
|
tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send");
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
fn start_send(self: Pin<&mut Self>, item: ZCPacket) -> Result<(), Self::Error> {
|
|
let pinned = self.project();
|
|
pinned
|
|
.sending_bufs
|
|
.push(pinned.converter.zcpacket_into_bytes(item)?);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn poll_flush(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), Self::Error>> {
|
|
let mut pinned = self.project();
|
|
let mut remaining = pinned.sending_bufs.remaining();
|
|
while remaining != 0 {
|
|
let n = ready!(poll_write_buf(
|
|
pinned.writer.as_mut(),
|
|
cx,
|
|
pinned.sending_bufs
|
|
))?;
|
|
if n == 0 {
|
|
return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new(
|
|
std::io::ErrorKind::WriteZero,
|
|
"failed to \
|
|
write frame to transport",
|
|
))));
|
|
}
|
|
remaining -= n;
|
|
}
|
|
|
|
tracing::trace!(?remaining, "flushed");
|
|
|
|
// Try flushing the underlying IO
|
|
ready!(pinned.writer.poll_flush(cx))?;
|
|
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn poll_close(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), Self::Error>> {
|
|
ready!(self.as_mut().poll_flush(cx))?;
|
|
ready!(self.project().writer.poll_shutdown(cx))?;
|
|
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
|
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
|
return None;
|
|
}
|
|
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
|
for iface in ifaces {
|
|
for addr in iface.addr {
|
|
if addr.ip() == *local_ip {
|
|
return Some(iface.name);
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::error!(?local_ip, "can not find interface name by ip");
|
|
None
|
|
}
|
|
|
|
pub(crate) fn setup_socket2_ext(
|
|
socket2_socket: &socket2::Socket,
|
|
bind_addr: &SocketAddr,
|
|
#[allow(unused_variables)] bind_dev: Option<String>,
|
|
only_v6: bool,
|
|
) -> Result<(), TunnelError> {
|
|
#[cfg(target_os = "windows")]
|
|
{
|
|
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
|
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
|
}
|
|
|
|
if bind_addr.is_ipv6() {
|
|
socket2_socket.set_only_v6(only_v6)?;
|
|
}
|
|
|
|
socket2_socket.set_nonblocking(true)?;
|
|
socket2_socket.set_reuse_address(true)?;
|
|
if let Err(e) = socket2_socket.bind(&socket2::SockAddr::from(*bind_addr)) {
|
|
if bind_addr.is_ipv4() {
|
|
return Err(e.into());
|
|
} else {
|
|
tracing::warn!(?e, "bind failed, do not return error for ipv6");
|
|
}
|
|
}
|
|
|
|
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
|
// socket2_socket.set_reuse_port(true)?;
|
|
|
|
if bind_addr.ip().is_unspecified() {
|
|
return Ok(());
|
|
}
|
|
|
|
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
|
// win can handle this with bind correctly
|
|
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
|
if let Some(dev_name) = bind_dev {
|
|
// use IP_BOUND_IF to bind device
|
|
unsafe {
|
|
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
|
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
|
if bind_addr.is_ipv4() {
|
|
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
|
} else {
|
|
socket2_socket.bind_device_by_index_v6(std::num::NonZeroU32::new(dev_idx))?;
|
|
}
|
|
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
|
}
|
|
}
|
|
|
|
#[cfg(any(
|
|
target_os = "android",
|
|
target_os = "fuchsia",
|
|
target_os = "linux",
|
|
target_env = "ohos"
|
|
))]
|
|
if let Some(dev_name) = bind_dev {
|
|
tracing::trace!(dev_name = ?dev_name, "bind device");
|
|
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
|
mut futures: FuturesUnordered<Fut>,
|
|
) -> Result<Ret, TunnelError>
|
|
where
|
|
Fut: Future<Output = Result<Ret, E>> + Send,
|
|
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
|
{
|
|
// return last error
|
|
let mut last_err = None;
|
|
|
|
while let Some(ret) = futures.next().await {
|
|
if let Err(e) = ret {
|
|
last_err = Some(e.into());
|
|
} else {
|
|
return ret.map_err(|e| e.into());
|
|
}
|
|
}
|
|
|
|
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
|
}
|
|
|
|
pub(crate) fn setup_socket2(
|
|
socket2_socket: &socket2::Socket,
|
|
bind_addr: &SocketAddr,
|
|
only_v6: bool,
|
|
) -> Result<(), TunnelError> {
|
|
setup_socket2_ext(
|
|
socket2_socket,
|
|
bind_addr,
|
|
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
|
only_v6,
|
|
)
|
|
}
|
|
|
|
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
|
if buf.capacity() < min_size {
|
|
buf.reserve(max_size);
|
|
}
|
|
}
|
|
|
|
pub mod tests {
|
|
use atomic_shim::AtomicU64;
|
|
use std::{sync::Arc, time::Instant};
|
|
|
|
use futures::{Future, SinkExt, StreamExt};
|
|
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
|
|
|
use crate::{
|
|
common::netns::NetNS,
|
|
tunnel::{TunnelConnector, TunnelListener, packet_def::ZCPacket},
|
|
};
|
|
|
|
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
|
let (mut recv, mut send) = tunnel.split();
|
|
|
|
if !once {
|
|
while let Some(item) = recv.next().await {
|
|
let Ok(msg) = item else {
|
|
continue;
|
|
};
|
|
tracing::debug!(?msg, "recv a msg, try echo back");
|
|
if send.send(msg).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
let Some(ret) = recv.next().await else {
|
|
panic!("recv error");
|
|
};
|
|
|
|
if ret.is_err() {
|
|
tracing::debug!(?ret, "recv error");
|
|
return;
|
|
}
|
|
|
|
let res = ret.unwrap();
|
|
tracing::debug!(?res, "recv a msg, try echo back");
|
|
send.send(res).await.unwrap();
|
|
}
|
|
let _ = send.flush().await;
|
|
let _ = send.close().await;
|
|
|
|
tracing::warn!("echo server exit...");
|
|
}
|
|
|
|
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
|
where
|
|
L: TunnelListener + Send + Sync + 'static,
|
|
C: TunnelConnector + Send + Sync + 'static,
|
|
{
|
|
_tunnel_pingpong_netns_with_timeout(
|
|
listener,
|
|
connector,
|
|
NetNS::new(None),
|
|
NetNS::new(None),
|
|
"12345678abcdefg".as_bytes().to_vec(),
|
|
// only used by tunnel test, so set a long timeout
|
|
tokio::time::Duration::from_secs(5),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
async fn _tunnel_pingpong_netns<L, C>(
|
|
mut listener: L,
|
|
mut connector: C,
|
|
l_netns: NetNS,
|
|
c_netns: NetNS,
|
|
buf: Vec<u8>,
|
|
) where
|
|
L: TunnelListener + Send + Sync + 'static,
|
|
C: TunnelConnector + Send + Sync + 'static,
|
|
{
|
|
l_netns
|
|
.run_async(|| async {
|
|
listener.listen().await.unwrap();
|
|
})
|
|
.await;
|
|
|
|
let lis = tokio::spawn(async move {
|
|
let ret = listener.accept().await.unwrap();
|
|
println!("accept: {:?}", ret.info());
|
|
assert_eq!(
|
|
url::Url::from(ret.info().unwrap().local_addr.unwrap()),
|
|
listener.local_url()
|
|
);
|
|
_tunnel_echo_server(ret, false).await
|
|
});
|
|
|
|
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
|
println!("connect: {:?}", tunnel.info());
|
|
|
|
if connector.remote_url().scheme() == "faketcp" {
|
|
// listener need some time to start capturing packet
|
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
|
}
|
|
|
|
assert_eq!(
|
|
url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()),
|
|
connector.remote_url(),
|
|
);
|
|
|
|
let (mut recv, mut send) = tunnel.split();
|
|
|
|
send.send(ZCPacket::new_with_payload(buf.as_slice()))
|
|
.await
|
|
.unwrap();
|
|
|
|
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
|
.await
|
|
.unwrap()
|
|
.unwrap()
|
|
.unwrap();
|
|
println!("echo back: {:?}", ret);
|
|
assert_eq!(ret.payload(), Bytes::from(buf));
|
|
|
|
send.close().await.unwrap();
|
|
|
|
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
|
lis.abort();
|
|
} else {
|
|
// lis should finish in 1 second
|
|
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
|
assert!(ret.is_ok());
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn _tunnel_pingpong_netns_with_timeout<L, C>(
|
|
listener: L,
|
|
connector: C,
|
|
l_netns: NetNS,
|
|
c_netns: NetNS,
|
|
buf: Vec<u8>,
|
|
timeout: std::time::Duration,
|
|
) -> Result<(), anyhow::Error>
|
|
where
|
|
L: TunnelListener + Send + Sync + 'static,
|
|
C: TunnelConnector + Send + Sync + 'static,
|
|
{
|
|
let handle = tokio::spawn(async move {
|
|
_tunnel_pingpong_netns(listener, connector, l_netns, c_netns, buf).await;
|
|
});
|
|
|
|
match tokio::time::timeout(timeout, handle).await {
|
|
Ok(join_res) => match join_res {
|
|
Ok(_) => Ok(()),
|
|
Err(join_err) => {
|
|
if join_err.is_panic() {
|
|
let payload = join_err.into_panic();
|
|
let msg = match payload.downcast::<String>() {
|
|
Ok(s) => *s,
|
|
Err(payload) => match payload.downcast::<&str>() {
|
|
Ok(s) => (*s).to_string(),
|
|
Err(_) => "non-string panic payload".to_string(),
|
|
},
|
|
};
|
|
Err(anyhow::anyhow!("task panicked: {}", msg))
|
|
} else {
|
|
Err(anyhow::anyhow!("task cancelled"))
|
|
}
|
|
}
|
|
},
|
|
Err(elapsed) => Err(elapsed.into()),
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn _tunnel_bench<L, C>(listener: L, connector: C)
|
|
where
|
|
L: TunnelListener + Send + Sync + 'static,
|
|
C: TunnelConnector + Send + Sync + 'static,
|
|
{
|
|
_tunnel_bench_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await;
|
|
}
|
|
|
|
pub(crate) async fn _tunnel_bench_netns<L, C>(
|
|
mut listener: L,
|
|
mut connector: C,
|
|
netns_l: NetNS,
|
|
netns_c: NetNS,
|
|
) -> usize
|
|
where
|
|
L: TunnelListener + Send + Sync + 'static,
|
|
C: TunnelConnector + Send + Sync + 'static,
|
|
{
|
|
{
|
|
let _g = netns_l.guard();
|
|
listener.listen().await.unwrap();
|
|
}
|
|
|
|
let bps = Arc::new(AtomicU64::new(0));
|
|
let bps_clone = bps.clone();
|
|
|
|
let lis = tokio::spawn(async move {
|
|
let ret = listener.accept().await.unwrap();
|
|
// _tunnel_echo_server(ret, false).await
|
|
let (mut r, _s) = ret.split();
|
|
let now = Instant::now();
|
|
let mut count = 0;
|
|
while let Some(Ok(p)) = r.next().await {
|
|
count += p.payload_len();
|
|
let elapsed_sec = now.elapsed().as_secs();
|
|
if elapsed_sec > 0 {
|
|
bps_clone.store(
|
|
count as u64 / now.elapsed().as_secs(),
|
|
std::sync::atomic::Ordering::Relaxed,
|
|
);
|
|
}
|
|
}
|
|
});
|
|
|
|
let tunnel = {
|
|
let _g = netns_c.guard();
|
|
connector.connect().await.unwrap()
|
|
};
|
|
|
|
let (_recv, mut send) = tunnel.split();
|
|
|
|
// prepare a 4k buffer with random data
|
|
let mut send_buf = BytesMut::new();
|
|
for _ in 0..64 {
|
|
send_buf.put_i128(rand::random::<i128>());
|
|
}
|
|
|
|
let now = Instant::now();
|
|
while now.elapsed().as_secs() < 10 {
|
|
// send.feed(item)
|
|
let item = ZCPacket::new_with_payload(send_buf.as_ref());
|
|
send.feed(item).await.unwrap();
|
|
}
|
|
|
|
send.close().await.unwrap();
|
|
drop(send);
|
|
drop(connector);
|
|
drop(tunnel);
|
|
|
|
tracing::warn!("wait for recv to finish...");
|
|
let bps = bps.load(std::sync::atomic::Ordering::Acquire);
|
|
println!("bps: {}", bps);
|
|
|
|
lis.abort();
|
|
bps as usize
|
|
}
|
|
|
|
pub async fn wait_for_condition<F, FRet>(mut condition: F, timeout: std::time::Duration)
|
|
where
|
|
F: FnMut() -> FRet + Send,
|
|
FRet: Future<Output = bool>,
|
|
{
|
|
let now = std::time::Instant::now();
|
|
while now.elapsed() < timeout {
|
|
if condition().await {
|
|
return;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
}
|
|
assert!(condition().await, "Timeout")
|
|
}
|
|
}
|