diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs index 71b9d01e..2b534b45 100644 --- a/easytier/src/tunnel/websocket.rs +++ b/easytier/src/tunnel/websocket.rs @@ -91,7 +91,9 @@ impl WSTunnelListener { } async fn try_accept(&self, stream: TcpStream) -> Result, TunnelError> { - let mut remote_addr = stream.peer_addr()?; + let peer_addr = stream.peer_addr()?; + let mut remote_addr = + super::build_url_from_socket_addr(&peer_addr.to_string(), self.addr.scheme()); let stream = if is_wss(&self.addr)? { init_crypto_provider(); @@ -114,7 +116,7 @@ impl WSTunnelListener { if TRUSTED_PROXIES .iter() - .any(|net| net.contains(remote_addr.ip())) + .any(|net| net.contains(peer_addr.ip())) { if let Some(forwarded) = request .headers() @@ -130,7 +132,12 @@ impl WSTunnelListener { }) { if let Some(ip) = forwarded.remotest_forwarded_for_ip() { - remote_addr = SocketAddr::new(ip, 0); + remote_addr.set_host(Some(&ip.to_string())).map_err(|_| { + TunnelError::InvalidAddr(format!("invalid forwarded ip {}", ip)) + })?; + remote_addr + .query_pairs_mut() + .append_pair("proxy", &peer_addr.to_string()); } } } @@ -140,13 +147,7 @@ impl WSTunnelListener { let info = TunnelInfo { tunnel_type: self.addr.scheme().to_owned(), local_addr: Some(self.local_url().into()), - remote_addr: Some( - super::build_url_from_socket_addr( - &remote_addr.to_string(), - self.addr.scheme().to_string().as_str(), - ) - .into(), - ), + remote_addr: Some(remote_addr.into()), }; Ok(Box::new(TunnelWrapper::new( @@ -403,11 +404,22 @@ pub mod tests { .unwrap(); assert_eq!(remote_addr.host_str().unwrap(), "203.0.113.5"); + let proxy_addr = remote_addr + .query_pairs() + .find(|(k, _)| k == "proxy") + .map(|(_, v)| v.into_owned()) + .unwrap(); + assert_eq!(proxy_addr, "127.0.0.1:25560"); tunnel }); - let mut stream = TcpStream::connect("127.0.0.1:25559").await.unwrap(); + let socket = TcpSocket::new_v4().unwrap(); + socket.bind("127.0.0.1:25560".parse().unwrap()).unwrap(); + let mut stream = socket + .connect("127.0.0.1:25559".parse().unwrap()) + .await + .unwrap(); let handshake = "GET / HTTP/1.1\r\n\ Host: 127.0.0.1:25559\r\n\