diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 085f7dd4..7b52ad2f 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -42,7 +42,7 @@ type websocketWithEarlyDataConn struct { net.Conn wsWriter N.ExtendedWriter underlay net.Conn - dialed chan bool + dialed chan struct{} cancel context.CancelFunc ctx context.Context config *WebsocketConfig @@ -174,7 +174,7 @@ func (wsc *websocketConn) Close() error { return nil } -func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { +func (wsedc *websocketWithEarlyDataConn) dial(earlyData []byte) error { base64DataBuf := &bytes.Buffer{} base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) @@ -193,8 +193,8 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { return fmt.Errorf("failed to dial WebSocket: %w", err) } - wsedc.dialed <- true wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn) + close(wsedc.dialed) if earlyDataBuf.Len() != 0 { _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) } @@ -203,67 +203,68 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { } func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { - if wsedc.ctx.Err() != nil { + select { + case <-wsedc.ctx.Done(): return 0, io.ErrClosedPipe - } - if wsedc.Conn == nil { - if err := wsedc.Dial(b); err != nil { + case <-wsedc.dialed: + return wsedc.Conn.Write(b) + default: + if err := wsedc.dial(b); err != nil { return 0, err } return len(b), nil } - - return wsedc.Conn.Write(b) } func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error { - if wsedc.ctx.Err() != nil { + select { + case <-wsedc.ctx.Done(): return io.ErrClosedPipe - } - if wsedc.Conn == nil { - if err := wsedc.Dial(buffer.Bytes()); err != nil { + case <-wsedc.dialed: + return wsedc.wsWriter.WriteBuffer(buffer) + default: + if err := wsedc.dial(buffer.Bytes()); err != nil { return err } return nil } - - return wsedc.wsWriter.WriteBuffer(buffer) } func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { - if wsedc.ctx.Err() != nil { + select { + case <-wsedc.ctx.Done(): return 0, io.ErrClosedPipe + case <-wsedc.dialed: + return wsedc.Conn.Read(b) } - if wsedc.Conn == nil { - select { - case <-wsedc.ctx.Done(): - return 0, io.ErrUnexpectedEOF - case <-wsedc.dialed: - } - } - return wsedc.Conn.Read(b) } func (wsedc *websocketWithEarlyDataConn) Close() error { wsedc.cancel() - if wsedc.Conn == nil { // is dialing or not dialed + select { + case <-wsedc.dialed: + return wsedc.Conn.Close() + default: return wsedc.underlay.Close() } - return wsedc.Conn.Close() } func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { - if wsedc.Conn == nil { + select { + case <-wsedc.dialed: + return wsedc.Conn.LocalAddr() + default: return wsedc.underlay.LocalAddr() } - return wsedc.Conn.LocalAddr() } func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { - if wsedc.Conn == nil { + select { + case <-wsedc.dialed: + return wsedc.Conn.RemoteAddr() + default: return wsedc.underlay.RemoteAddr() } - return wsedc.Conn.RemoteAddr() } func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { @@ -274,17 +275,21 @@ func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { } func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { - if wsedc.Conn == nil { + select { + case <-wsedc.dialed: + return wsedc.Conn.SetReadDeadline(t) + default: return nil } - return wsedc.Conn.SetReadDeadline(t) } func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { - if wsedc.Conn == nil { + select { + case <-wsedc.dialed: + return wsedc.Conn.SetReadDeadline(t) + default: return nil } - return wsedc.Conn.SetWriteDeadline(t) } func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int { @@ -307,13 +312,18 @@ func (wsedc *websocketWithEarlyDataConn) Upstream() any { //} func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool { - return wsedc.Conn == nil + select { + case <-wsedc.dialed: + return false + default: + return true + } } func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { ctx, cancel := context.WithCancel(context.Background()) conn = &websocketWithEarlyDataConn{ - dialed: make(chan bool, 1), + dialed: make(chan struct{}), cancel: cancel, ctx: ctx, underlay: conn,