mirror of
https://github.com/erebe/wstunnel.git
synced 2026-04-22 16:27:13 +08:00
improve executor trait
/ Build - Windows x86 (push) Has been cancelled
/ Build - Windows x86_64 (push) Has been cancelled
/ Build - Linux x86 (push) Has been cancelled
/ Build - Android aarch64 (push) Has been cancelled
/ Build - Linux aarch64 (push) Has been cancelled
/ Build - Linux x86_64 (push) Has been cancelled
/ Build - MacOS aarch64 (push) Has been cancelled
/ Build - MacOS x86_64 (push) Has been cancelled
/ Build - Freebsd x86_64 (push) Has been cancelled
/ Build - Android armv7 (push) Has been cancelled
/ Build - Freebsd x86 (push) Has been cancelled
/ Build - Linux armv7hf (push) Has been cancelled
/ Release (push) Has been cancelled
/ Build - Windows x86 (push) Has been cancelled
/ Build - Windows x86_64 (push) Has been cancelled
/ Build - Linux x86 (push) Has been cancelled
/ Build - Android aarch64 (push) Has been cancelled
/ Build - Linux aarch64 (push) Has been cancelled
/ Build - Linux x86_64 (push) Has been cancelled
/ Build - MacOS aarch64 (push) Has been cancelled
/ Build - MacOS x86_64 (push) Has been cancelled
/ Build - Freebsd x86_64 (push) Has been cancelled
/ Build - Android armv7 (push) Has been cancelled
/ Build - Freebsd x86 (push) Has been cancelled
/ Build - Linux armv7hf (push) Has been cancelled
/ Release (push) Has been cancelled
This commit is contained in:
@@ -1,15 +1,23 @@
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::task::{AbortHandle, JoinSet};
|
||||
|
||||
pub trait TokioExecutor: Clone + Send + Sync + 'static {
|
||||
pub trait TokioExecutorRef: Clone + Send + Sync + 'static {
|
||||
fn spawn<F>(&self, f: F) -> AbortHandle
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static;
|
||||
}
|
||||
|
||||
pub trait TokioExecutor: TokioExecutorRef {
|
||||
type Ref: TokioExecutorRef;
|
||||
fn ref_clone(&self) -> Self::Ref;
|
||||
}
|
||||
|
||||
// ///////////////////////////////
|
||||
// Default TokioExecutor
|
||||
// ///////////////////////////////
|
||||
#[derive(Clone)]
|
||||
pub struct DefaultTokioExecutor {
|
||||
handle: Handle,
|
||||
@@ -19,13 +27,14 @@ impl DefaultTokioExecutor {
|
||||
Self { handle }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultTokioExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new(Handle::current())
|
||||
}
|
||||
}
|
||||
|
||||
impl TokioExecutor for DefaultTokioExecutor {
|
||||
impl TokioExecutorRef for DefaultTokioExecutor {
|
||||
fn spawn<F>(&self, f: F) -> AbortHandle
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
@@ -35,10 +44,23 @@ impl TokioExecutor for DefaultTokioExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
impl TokioExecutor for DefaultTokioExecutor {
|
||||
type Ref = DefaultTokioExecutor;
|
||||
|
||||
fn ref_clone(&self) -> DefaultTokioExecutor {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// ///////////////////////////////
|
||||
// JoinSetTokioExecutor
|
||||
// ///////////////////////////////
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JoinSetTokioExecutor {
|
||||
join_set: Arc<Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl JoinSetTokioExecutor {
|
||||
pub fn new(join_set: JoinSet<()>) -> Self {
|
||||
Self {
|
||||
@@ -50,13 +72,18 @@ impl JoinSetTokioExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for JoinSetTokioExecutor {
|
||||
fn drop(&mut self) {
|
||||
self.abort_all();
|
||||
}
|
||||
}
|
||||
impl Default for JoinSetTokioExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new(JoinSet::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl TokioExecutor for JoinSetTokioExecutor {
|
||||
impl TokioExecutorRef for JoinSetTokioExecutor {
|
||||
fn spawn<F>(&self, f: F) -> AbortHandle
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
@@ -67,3 +94,44 @@ impl TokioExecutor for JoinSetTokioExecutor {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TokioExecutor for JoinSetTokioExecutor {
|
||||
type Ref = JoinSetTokioExecutorRef;
|
||||
|
||||
fn ref_clone(&self) -> Self::Ref {
|
||||
JoinSetTokioExecutorRef::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JoinSetTokioExecutorRef {
|
||||
join_set: Weak<Mutex<JoinSet<()>>>,
|
||||
default_abort_handle: AbortHandle,
|
||||
}
|
||||
impl JoinSetTokioExecutorRef {
|
||||
fn new(exec: &JoinSetTokioExecutor) -> Self {
|
||||
let default_abort_handle = exec.join_set.lock().spawn(futures_util::future::pending());
|
||||
let join_set = Arc::downgrade(&exec.join_set);
|
||||
Self {
|
||||
join_set,
|
||||
default_abort_handle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokioExecutorRef for JoinSetTokioExecutorRef {
|
||||
fn spawn<F>(&self, f: F) -> AbortHandle
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
self.join_set
|
||||
.upgrade()
|
||||
.map(|l| {
|
||||
l.lock().spawn(async {
|
||||
f.await;
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| self.default_abort_handle.clone())
|
||||
}
|
||||
}
|
||||
|
||||
+5
-5
@@ -9,7 +9,7 @@ mod test_integrations;
|
||||
mod tunnel;
|
||||
|
||||
use crate::config::{Client, DEFAULT_CLIENT_UPGRADE_PATH_PREFIX, Server};
|
||||
use crate::executor::TokioExecutor;
|
||||
use crate::executor::{TokioExecutor, TokioExecutorRef};
|
||||
use crate::protocols::dns::DnsResolver;
|
||||
use crate::protocols::tls;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
@@ -39,7 +39,7 @@ use tracing::{error, info};
|
||||
use url::Url;
|
||||
|
||||
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
let tunnels = create_client_tunnels(args, executor.clone()).await?;
|
||||
let tunnels = create_client_tunnels(args, executor.ref_clone()).await?;
|
||||
|
||||
// Start all tunnels
|
||||
let (tx, rx) = oneshot::channel();
|
||||
@@ -55,7 +55,7 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
|
||||
|
||||
async fn create_client_tunnels(
|
||||
args: Client,
|
||||
executor: impl TokioExecutor,
|
||||
executor: impl TokioExecutorRef,
|
||||
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
|
||||
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
|
||||
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
|
||||
@@ -426,7 +426,7 @@ async fn create_client_tunnels(
|
||||
|
||||
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let exec = executor.clone();
|
||||
let exec = executor.ref_clone();
|
||||
executor.spawn(async move {
|
||||
let ret = run_server_impl(args, exec).await;
|
||||
let _ = tx.send(ret);
|
||||
@@ -435,7 +435,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
|
||||
rx.await?
|
||||
}
|
||||
|
||||
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
async fn run_server_impl(args: Server, executor: impl TokioExecutorRef) -> anyhow::Result<()> {
|
||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
|
||||
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::executor::DefaultTokioExecutor;
|
||||
use crate::executor::{DefaultTokioExecutor, TokioExecutorRef};
|
||||
use crate::tunnel;
|
||||
use crate::tunnel::RemoteAddr;
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::client::cnx_pool::WsConnection;
|
||||
@@ -7,7 +8,6 @@ use crate::tunnel::listeners::TunnelListener;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::transport::io::{TunnelReader, TunnelWriter};
|
||||
use crate::tunnel::transport::{TransportScheme, jwt_token_to_tunnel};
|
||||
use crate::{TokioExecutor, tunnel};
|
||||
use anyhow::Context;
|
||||
use futures_util::pin_mut;
|
||||
use hyper::header::COOKIE;
|
||||
@@ -23,7 +23,7 @@ use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClient<E: TokioExecutor = DefaultTokioExecutor> {
|
||||
pub struct WsClient<E: TokioExecutorRef = DefaultTokioExecutor> {
|
||||
pub config: Arc<WsClientConfig>,
|
||||
pub cnx_pool: bb8::Pool<WsConnection>,
|
||||
reverse_tunnel_connection_retry_max_backoff: Duration,
|
||||
@@ -31,7 +31,7 @@ pub struct WsClient<E: TokioExecutor = DefaultTokioExecutor> {
|
||||
pub(crate) executor: E,
|
||||
}
|
||||
|
||||
impl<E: TokioExecutor> WsClient<E> {
|
||||
impl<E: TokioExecutorRef> WsClient<E> {
|
||||
pub async fn new(
|
||||
config: WsClientConfig,
|
||||
connection_min_idle: u32,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::TokioExecutor;
|
||||
use crate::executor::TokioExecutorRef;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
use crate::tunnel::server::WsServer;
|
||||
use crate::tunnel::server::utils::{HttpResponse, bad_request, inject_cookie};
|
||||
@@ -18,7 +18,7 @@ use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{Instrument, Span};
|
||||
|
||||
pub(super) async fn http_server_upgrade(
|
||||
server: WsServer<impl TokioExecutor>,
|
||||
server: WsServer<impl TokioExecutorRef>,
|
||||
restrictions: Arc<RestrictionsRules>,
|
||||
restrict_path_prefix: Option<String>,
|
||||
client_addr: SocketAddr,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::TokioExecutor;
|
||||
use crate::executor::TokioExecutorRef;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
use crate::tunnel::server::WsServer;
|
||||
use crate::tunnel::server::utils::{HttpResponse, bad_request, inject_cookie};
|
||||
@@ -16,7 +16,7 @@ use tokio::sync::oneshot;
|
||||
use tracing::{Instrument, Span, error, warn};
|
||||
|
||||
pub(super) async fn ws_server_upgrade(
|
||||
server: WsServer<impl TokioExecutor>,
|
||||
server: WsServer<impl TokioExecutorRef>,
|
||||
restrictions: Arc<RestrictionsRules>,
|
||||
restrict_path_prefix: Option<String>,
|
||||
client_addr: SocketAddr,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::TokioExecutor;
|
||||
use crate::executor::TokioExecutorRef;
|
||||
use crate::tunnel::RemoteAddr;
|
||||
use crate::tunnel::listeners::TunnelListener;
|
||||
use ahash::AHashMap;
|
||||
@@ -51,7 +51,7 @@ impl<T: TunnelListener> ReverseTunnelServer<T> {
|
||||
|
||||
pub async fn run_listening_server(
|
||||
&self,
|
||||
executor: &impl TokioExecutor,
|
||||
executor: &impl TokioExecutorRef,
|
||||
bind_addr: SocketAddr,
|
||||
idle_timeout: Duration,
|
||||
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
|
||||
|
||||
@@ -66,12 +66,12 @@ pub struct WsServerConfig {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsServer<E: crate::TokioExecutor = DefaultTokioExecutor> {
|
||||
pub struct WsServer<E: crate::TokioExecutorRef = DefaultTokioExecutor> {
|
||||
pub config: Arc<WsServerConfig>,
|
||||
pub executor: E,
|
||||
}
|
||||
|
||||
impl<E: crate::TokioExecutor> WsServer<E> {
|
||||
impl<E: crate::TokioExecutorRef> WsServer<E> {
|
||||
pub fn new(config: WsServerConfig, executor: E) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
|
||||
@@ -122,7 +122,7 @@ impl TunnelWrite for Http2TunnelWrite {
|
||||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client: &WsClient<impl crate::TokioExecutor>,
|
||||
client: &WsClient<impl crate::TokioExecutorRef>,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
|
||||
let mut pooled_cnx = match client.cnx_pool.get().await {
|
||||
|
||||
@@ -236,7 +236,7 @@ impl TunnelRead for WebsocketTunnelRead {
|
||||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client: &WsClient<impl crate::TokioExecutor>,
|
||||
client: &WsClient<impl crate::TokioExecutorRef>,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
|
||||
let client_cfg = &client.config;
|
||||
|
||||
Reference in New Issue
Block a user