improve executor trait
/ Build - Windows x86 (push) Has been cancelled
/ Build - Windows x86_64 (push) Has been cancelled
/ Build - Linux x86 (push) Has been cancelled
/ Build - Android aarch64 (push) Has been cancelled
/ Build - Linux aarch64 (push) Has been cancelled
/ Build - Linux x86_64 (push) Has been cancelled
/ Build - MacOS aarch64 (push) Has been cancelled
/ Build - MacOS x86_64 (push) Has been cancelled
/ Build - Freebsd x86_64 (push) Has been cancelled
/ Build - Android armv7 (push) Has been cancelled
/ Build - Freebsd x86 (push) Has been cancelled
/ Build - Linux armv7hf (push) Has been cancelled
/ Release (push) Has been cancelled

This commit is contained in:
Σrebe - Romain GERARD
2025-06-01 19:32:38 +02:00
parent 5b42f64601
commit 0daa3fba68
9 changed files with 91 additions and 23 deletions
+72 -4
View File
@@ -1,15 +1,23 @@
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, JoinSet};
pub trait TokioExecutor: Clone + Send + Sync + 'static {
pub trait TokioExecutorRef: Clone + Send + Sync + 'static {
fn spawn<F>(&self, f: F) -> AbortHandle
where
F: Future + Send + 'static,
F::Output: Send + 'static;
}
pub trait TokioExecutor: TokioExecutorRef {
type Ref: TokioExecutorRef;
fn ref_clone(&self) -> Self::Ref;
}
// ///////////////////////////////
// Default TokioExecutor
// ///////////////////////////////
#[derive(Clone)]
pub struct DefaultTokioExecutor {
handle: Handle,
@@ -19,13 +27,14 @@ impl DefaultTokioExecutor {
Self { handle }
}
}
impl Default for DefaultTokioExecutor {
fn default() -> Self {
Self::new(Handle::current())
}
}
impl TokioExecutor for DefaultTokioExecutor {
impl TokioExecutorRef for DefaultTokioExecutor {
fn spawn<F>(&self, f: F) -> AbortHandle
where
F: Future + Send + 'static,
@@ -35,10 +44,23 @@ impl TokioExecutor for DefaultTokioExecutor {
}
}
impl TokioExecutor for DefaultTokioExecutor {
type Ref = DefaultTokioExecutor;
fn ref_clone(&self) -> DefaultTokioExecutor {
self.clone()
}
}
// ///////////////////////////////
// JoinSetTokioExecutor
// ///////////////////////////////
#[derive(Clone)]
pub struct JoinSetTokioExecutor {
join_set: Arc<Mutex<JoinSet<()>>>,
}
impl JoinSetTokioExecutor {
pub fn new(join_set: JoinSet<()>) -> Self {
Self {
@@ -50,13 +72,18 @@ impl JoinSetTokioExecutor {
}
}
impl Drop for JoinSetTokioExecutor {
fn drop(&mut self) {
self.abort_all();
}
}
impl Default for JoinSetTokioExecutor {
fn default() -> Self {
Self::new(JoinSet::new())
}
}
impl TokioExecutor for JoinSetTokioExecutor {
impl TokioExecutorRef for JoinSetTokioExecutor {
fn spawn<F>(&self, f: F) -> AbortHandle
where
F: Future + Send + 'static,
@@ -67,3 +94,44 @@ impl TokioExecutor for JoinSetTokioExecutor {
})
}
}
impl TokioExecutor for JoinSetTokioExecutor {
type Ref = JoinSetTokioExecutorRef;
fn ref_clone(&self) -> Self::Ref {
JoinSetTokioExecutorRef::new(self)
}
}
#[derive(Clone)]
pub struct JoinSetTokioExecutorRef {
join_set: Weak<Mutex<JoinSet<()>>>,
default_abort_handle: AbortHandle,
}
impl JoinSetTokioExecutorRef {
fn new(exec: &JoinSetTokioExecutor) -> Self {
let default_abort_handle = exec.join_set.lock().spawn(futures_util::future::pending());
let join_set = Arc::downgrade(&exec.join_set);
Self {
join_set,
default_abort_handle,
}
}
}
impl TokioExecutorRef for JoinSetTokioExecutorRef {
fn spawn<F>(&self, f: F) -> AbortHandle
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.join_set
.upgrade()
.map(|l| {
l.lock().spawn(async {
f.await;
})
})
.unwrap_or_else(|| self.default_abort_handle.clone())
}
}
+5 -5
View File
@@ -9,7 +9,7 @@ mod test_integrations;
mod tunnel;
use crate::config::{Client, DEFAULT_CLIENT_UPGRADE_PATH_PREFIX, Server};
use crate::executor::TokioExecutor;
use crate::executor::{TokioExecutor, TokioExecutorRef};
use crate::protocols::dns::DnsResolver;
use crate::protocols::tls;
use crate::restrictions::types::RestrictionsRules;
@@ -39,7 +39,7 @@ use tracing::{error, info};
use url::Url;
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
let tunnels = create_client_tunnels(args, executor.clone()).await?;
let tunnels = create_client_tunnels(args, executor.ref_clone()).await?;
// Start all tunnels
let (tx, rx) = oneshot::channel();
@@ -55,7 +55,7 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
async fn create_client_tunnels(
args: Client,
executor: impl TokioExecutor,
executor: impl TokioExecutorRef,
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
@@ -426,7 +426,7 @@ async fn create_client_tunnels(
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
let (tx, rx) = oneshot::channel();
let exec = executor.clone();
let exec = executor.ref_clone();
executor.spawn(async move {
let ret = run_server_impl(args, exec).await;
let _ = tx.send(ret);
@@ -435,7 +435,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
rx.await?
}
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
async fn run_server_impl(args: Server, executor: impl TokioExecutorRef) -> anyhow::Result<()> {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
+4 -4
View File
@@ -1,4 +1,5 @@
use crate::executor::DefaultTokioExecutor;
use crate::executor::{DefaultTokioExecutor, TokioExecutorRef};
use crate::tunnel;
use crate::tunnel::RemoteAddr;
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::client::cnx_pool::WsConnection;
@@ -7,7 +8,6 @@ use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::io::{TunnelReader, TunnelWriter};
use crate::tunnel::transport::{TransportScheme, jwt_token_to_tunnel};
use crate::{TokioExecutor, tunnel};
use anyhow::Context;
use futures_util::pin_mut;
use hyper::header::COOKIE;
@@ -23,7 +23,7 @@ use url::Host;
use uuid::Uuid;
#[derive(Clone)]
pub struct WsClient<E: TokioExecutor = DefaultTokioExecutor> {
pub struct WsClient<E: TokioExecutorRef = DefaultTokioExecutor> {
pub config: Arc<WsClientConfig>,
pub cnx_pool: bb8::Pool<WsConnection>,
reverse_tunnel_connection_retry_max_backoff: Duration,
@@ -31,7 +31,7 @@ pub struct WsClient<E: TokioExecutor = DefaultTokioExecutor> {
pub(crate) executor: E,
}
impl<E: TokioExecutor> WsClient<E> {
impl<E: TokioExecutorRef> WsClient<E> {
pub async fn new(
config: WsClientConfig,
connection_min_idle: u32,
+2 -2
View File
@@ -1,4 +1,4 @@
use crate::TokioExecutor;
use crate::executor::TokioExecutorRef;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::server::WsServer;
use crate::tunnel::server::utils::{HttpResponse, bad_request, inject_cookie};
@@ -18,7 +18,7 @@ use tokio_stream::wrappers::ReceiverStream;
use tracing::{Instrument, Span};
pub(super) async fn http_server_upgrade(
server: WsServer<impl TokioExecutor>,
server: WsServer<impl TokioExecutorRef>,
restrictions: Arc<RestrictionsRules>,
restrict_path_prefix: Option<String>,
client_addr: SocketAddr,
@@ -1,4 +1,4 @@
use crate::TokioExecutor;
use crate::executor::TokioExecutorRef;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::server::WsServer;
use crate::tunnel::server::utils::{HttpResponse, bad_request, inject_cookie};
@@ -16,7 +16,7 @@ use tokio::sync::oneshot;
use tracing::{Instrument, Span, error, warn};
pub(super) async fn ws_server_upgrade(
server: WsServer<impl TokioExecutor>,
server: WsServer<impl TokioExecutorRef>,
restrictions: Arc<RestrictionsRules>,
restrict_path_prefix: Option<String>,
client_addr: SocketAddr,
+2 -2
View File
@@ -1,4 +1,4 @@
use crate::TokioExecutor;
use crate::executor::TokioExecutorRef;
use crate::tunnel::RemoteAddr;
use crate::tunnel::listeners::TunnelListener;
use ahash::AHashMap;
@@ -51,7 +51,7 @@ impl<T: TunnelListener> ReverseTunnelServer<T> {
pub async fn run_listening_server(
&self,
executor: &impl TokioExecutor,
executor: &impl TokioExecutorRef,
bind_addr: SocketAddr,
idle_timeout: Duration,
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
+2 -2
View File
@@ -66,12 +66,12 @@ pub struct WsServerConfig {
}
#[derive(Clone)]
pub struct WsServer<E: crate::TokioExecutor = DefaultTokioExecutor> {
pub struct WsServer<E: crate::TokioExecutorRef = DefaultTokioExecutor> {
pub config: Arc<WsServerConfig>,
pub executor: E,
}
impl<E: crate::TokioExecutor> WsServer<E> {
impl<E: crate::TokioExecutorRef> WsServer<E> {
pub fn new(config: WsServerConfig, executor: E) -> Self {
Self {
config: Arc::new(config),
+1 -1
View File
@@ -122,7 +122,7 @@ impl TunnelWrite for Http2TunnelWrite {
pub async fn connect(
request_id: Uuid,
client: &WsClient<impl crate::TokioExecutor>,
client: &WsClient<impl crate::TokioExecutorRef>,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
let mut pooled_cnx = match client.cnx_pool.get().await {
+1 -1
View File
@@ -236,7 +236,7 @@ impl TunnelRead for WebsocketTunnelRead {
pub async fn connect(
request_id: Uuid,
client: &WsClient<impl crate::TokioExecutor>,
client: &WsClient<impl crate::TokioExecutorRef>,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
let client_cfg = &client.config;