diff --git a/Cargo.lock b/Cargo.lock index 21d4b84..de1e697 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3943,6 +3943,7 @@ dependencies = [ "get_if_addrs", "hickory-resolver", "http-body-util", + "httparse", "hyper", "hyper-util", "ipnet", diff --git a/wstunnel/Cargo.toml b/wstunnel/Cargo.toml index be0b25b..7581808 100644 --- a/wstunnel/Cargo.toml +++ b/wstunnel/Cargo.toml @@ -36,6 +36,7 @@ nix = { version = "0.30.1", features = ["socket", "net", "uio"] } parking_lot = "0.12.4" pin-project = "1" notify = { version = "8.0.0", features = [] } +httparse = { version = "1.10.1", features = [] } rustls-native-certs = { version = "0.8.1", features = [] } rustls-pemfile = { version = "2.2.0", features = [] } diff --git a/wstunnel/src/config.rs b/wstunnel/src/config.rs index 13f41c5..b8f35f1 100644 --- a/wstunnel/src/config.rs +++ b/wstunnel/src/config.rs @@ -539,10 +539,7 @@ mod parsers { let get_proxy_protocol = |options: &BTreeMap| options.contains_key("proxy_protocol"); let Some((proto, tunnel_info)) = arg.split_once("://") else { - return Err(Error::new( - ErrorKind::InvalidInput, - format!("cannot parse protocol from {arg}"), - )); + return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse protocol from {arg}"))); }; match proto { @@ -689,10 +686,7 @@ mod parsers { pub fn parse_sni_override(arg: &str) -> Result, io::Error> { match DnsName::try_from(arg.to_string()) { Ok(val) => Ok(val), - Err(err) => Err(io::Error::new( - ErrorKind::InvalidInput, - format!("Invalid sni override: {err}"), - )), + Err(err) => Err(io::Error::new(ErrorKind::InvalidInput, format!("Invalid sni override: {err}"))), } } diff --git a/wstunnel/src/protocols/http_proxy/server.rs b/wstunnel/src/protocols/http_proxy/server.rs index 9c13ea6..af8280a 100644 --- a/wstunnel/src/protocols/http_proxy/server.rs +++ b/wstunnel/src/protocols/http_proxy/server.rs @@ -3,7 +3,7 @@ use std::future::Future; use bytes::Bytes; use log::{debug, error}; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; @@ -21,7 +21,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::select; use tokio::task::JoinSet; use tracing::log::info; -use url::Host; +use url::{Host, Url}; #[allow(clippy::type_complexity)] pub struct HttpProxyListener { @@ -36,7 +36,7 @@ impl Stream for HttpProxyListener { } } -fn handle_request( +fn handle_http_connect_request( credentials: &Option, dest: &Mutex>, req: Request, @@ -81,9 +81,7 @@ pub async fn run_server( timeout: Option, credentials: Option<(String, String)>, ) -> Result { - info!( - "Starting http proxy server listening cnx on {bind} with credentials {credentials:?}" - ); + info!("Starting http proxy server listening cnx on {bind} with credentials {credentials:?}"); let listener = TcpListener::bind(bind) .await @@ -140,21 +138,47 @@ pub async fn run_server( let handle_new_cnx = { let proxy_cfg = proxy_cfg.clone(); async move { - let http1 = &proxy_cfg.1; - let auth_header = &proxy_cfg.0; - let forward_to = Mutex::new(None); - let conn_fut = http1.serve_connection( - hyper_util::rt::TokioIo::new(&mut stream), - service_fn(|req| handle_request(auth_header, &forward_to, req)), - ); + // We need to know if the http request if a CONNECT method or a regular one. + // HTTP CONNECT requires doing a handshake with client (which is easier) + // While for regular method, we need to replay the request as if it was done by the client. + // Non HTTP CONNECT method only works for non TLS connection/request. + let forward_to = { + let mut buf = [0; 512]; + let buf_size = stream.peek(&mut buf).await.ok()?; + let mut http_parser = httparse::Request::new(&mut []); - match conn_fut.await { - Ok(_) => Some((stream, forward_to.into_inner())), - Err(err) => { - info!("Error while serving connection: {err}"); + let _ = http_parser.parse(&buf[..buf_size]); + if http_parser.method == Some(hyper::Method::CONNECT.as_str()) { None + } else { + let url = Url::parse(http_parser.path.unwrap_or("")).ok()?; + let host = url.host().unwrap_or(Host::Ipv4(Ipv4Addr::UNSPECIFIED)).to_owned(); + let port = url.port_or_known_default().unwrap_or(80); + Some((host, port)) } - } + }; + + // Handle regular http request. Meaning we need to forward it directly as is + return if forward_to.is_some() { + Some((stream, forward_to)) + } else { + // Handle HTTP CONNECT request + let http1 = &proxy_cfg.1; + let auth_header = &proxy_cfg.0; + let forward_to = Mutex::new(None); + let conn_fut = http1.serve_connection( + hyper_util::rt::TokioIo::new(&mut stream), + service_fn(|req| handle_http_connect_request(auth_header, &forward_to, req)), + ); + + match conn_fut.await { + Ok(_) => Some((stream, forward_to.into_inner())), + Err(err) => { + info!("Error while serving connection: {err}"); + None + } + } + }; } }; tasks.spawn(handle_new_cnx); diff --git a/wstunnel/src/protocols/unix_sock/server.rs b/wstunnel/src/protocols/unix_sock/server.rs index 0a0502a..8e186b7 100644 --- a/wstunnel/src/protocols/unix_sock/server.rs +++ b/wstunnel/src/protocols/unix_sock/server.rs @@ -51,8 +51,8 @@ pub async fn run_server(socket_path: &Path) -> Result = port_mapping.split(':').collect(); if port_mapping_parts.len() != 2 { - Err(serde::de::Error::custom(format!( - "Invalid port_mapping entry: {port_mapping}" - ))) + Err(serde::de::Error::custom(format!("Invalid port_mapping entry: {port_mapping}"))) } else { let orig_port = port_mapping_parts[0].parse::().map_err(serde::de::Error::custom)?; let target_port = port_mapping_parts[1].parse::().map_err(serde::de::Error::custom)?;