diff --git a/Cargo.lock b/Cargo.lock index de1e697..463aa84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1100,9 +1100,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" dependencies = [ "atomic-waker", "bytes", @@ -2519,6 +2519,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "schemars" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1375ba8ef45a6f15d83fa8748f1079428295d403d6ea991d09ab100155fbc06d" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -2633,16 +2645,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf65a400f8f66fb7b0552869ad70157166676db75ed8181f8104ea91cf9d0b42" +checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", "indexmap 2.10.0", - "schemars", + "schemars 0.9.0", + "schemars 1.0.3", "serde", "serde_derive", "serde_json", @@ -2652,9 +2665,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81679d9ed988d5e9a5e6531dc3f2c28efbd639cbd1dfb628df08edea6004da77" +checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ "darling", "proc-macro2", @@ -3923,7 +3936,7 @@ checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wstunnel" -version = "10.4.3" +version = "10.4.4" dependencies = [ "ahash", "anyhow", @@ -3981,7 +3994,7 @@ dependencies = [ [[package]] name = "wstunnel-cli" -version = "10.4.3" +version = "10.4.4" dependencies = [ "anyhow", "clap", diff --git a/wstunnel/src/protocols/http_proxy/server.rs b/wstunnel/src/protocols/http_proxy/server.rs index af8280a..3105c8f 100644 --- a/wstunnel/src/protocols/http_proxy/server.rs +++ b/wstunnel/src/protocols/http_proxy/server.rs @@ -1,8 +1,7 @@ use anyhow::Context; -use std::future::Future; - use bytes::Bytes; use log::{debug, error}; +use std::future::Future; use std::net::{Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; @@ -41,14 +40,13 @@ fn handle_http_connect_request( dest: &Mutex>, req: Request, ) -> impl Future>, &'static str>> { - const PROXY_AUTHORIZATION_PREFIX: &str = "Basic "; let ok_response = |forward_to: Option<(Host, u16)>| -> Result>, _> { *dest.lock() = forward_to; Ok(Response::builder().status(200).body(Empty::new()).unwrap()) }; fn err_response() -> Result>, &'static str> { info!("Un-authorized connection to http proxy"); - Err("Un-authorized") + Ok(Response::builder().status(401).body(Empty::new()).unwrap()) } if req.method() != hyper::Method::CONNECT { @@ -60,20 +58,33 @@ fn handle_http_connect_request( .ok() .map(|h| (h, req.uri().port_u16().unwrap_or(443))); - let Some(token) = credentials else { - return future::ready(ok_response(forward_to)); - }; + let header = req + .headers() + .get(hyper::header::PROXY_AUTHORIZATION) + .and_then(|h| h.to_str().ok()); - let Some(auth) = req.headers().get(hyper::header::PROXY_AUTHORIZATION) else { + if !verify_credentials(credentials, &header) { return future::ready(err_response()); - }; - - let auth = auth.to_str().unwrap_or_default().trim(); - if auth.starts_with(PROXY_AUTHORIZATION_PREFIX) && &auth[PROXY_AUTHORIZATION_PREFIX.len()..] == token { - return future::ready(ok_response(forward_to)); } - future::ready(err_response()) + future::ready(ok_response(forward_to)) +} + +fn verify_credentials(credentials: &Option, header_value: &Option<&str>) -> bool { + const PROXY_AUTHORIZATION_PREFIX: &str = "Basic "; + + // no creds set, that's ok + let Some(token) = credentials else { + return true; + }; + + // creds set, and no auth provided, that's forbidden + let Some(header_value) = header_value else { + return false; + }; + + let auth = header_value.trim(); + auth.starts_with(PROXY_AUTHORIZATION_PREFIX) && &auth[PROXY_AUTHORIZATION_PREFIX.len()..] == token } pub async fn run_server( @@ -108,8 +119,8 @@ pub async fn run_server( cnx = tasks.join_next(), if !tasks.is_empty() => { match cnx { Some(Ok(Some((stream, Some(f))))) => (stream, Some(f)), - Some(Ok(Some((_, None)))) =>{ - error!("Error while trying to parse connect request"); + Some(Ok(Some((_, None)))) => { + // Bad request or UnAuthorized request continue }, None | Some(Ok(None)) => continue, @@ -143,18 +154,21 @@ pub async fn run_server( // While for regular method, we need to replay the request as if it was done by the client. // Non HTTP CONNECT method only works for non TLS connection/request. let forward_to = { - let mut buf = [0; 512]; - let buf_size = stream.peek(&mut buf).await.ok()?; - let mut http_parser = httparse::Request::new(&mut []); + // Get a pick at data to analyze http request + let mut request_buf = [0; 512]; + let buf_size = stream.peek(&mut request_buf).await.ok()?; - let _ = http_parser.parse(&buf[..buf_size]); + // Parse http request. If no creds/auth is expected don't bother with headers + let mut headers = { + let headers_len = if proxy_cfg.0.is_some() { 32 } else { 0 }; + vec![httparse::EMPTY_HEADER; headers_len] + }; + let mut http_parser = httparse::Request::new(&mut headers); + let _ = http_parser.parse(&request_buf[..buf_size]); if http_parser.method == Some(hyper::Method::CONNECT.as_str()) { None } else { - let url = Url::parse(http_parser.path.unwrap_or("")).ok()?; - let host = url.host().unwrap_or(Host::Ipv4(Ipv4Addr::UNSPECIFIED)).to_owned(); - let port = url.port_or_known_default().unwrap_or(80); - Some((host, port)) + handle_regular_http_request(&http_parser, &proxy_cfg.0) } }; @@ -190,6 +204,28 @@ pub async fn run_server( }) } +fn handle_regular_http_request(http_parser: &httparse::Request, auth_header: &Option) -> Option<(Host, u16)> { + const DEFAULT_HTTP_PORT: u16 = 80; + + let header = http_parser.headers.iter().find_map(|h| { + if h.name == hyper::header::PROXY_AUTHORIZATION { + Some(String::from_utf8_lossy(h.value)) + } else { + None + } + }); + + if !verify_credentials(auth_header, &header.as_deref()) { + return None; + } + + let url = Url::parse(http_parser.path.unwrap_or("")).ok()?; + let host = url.host().unwrap_or(Host::Ipv4(Ipv4Addr::UNSPECIFIED)).to_owned(); + let port = url.port_or_known_default().unwrap_or(DEFAULT_HTTP_PORT); + + Some((host, port)) +} + //#[cfg(test)] //mod tests { // use super::*;