fix: race in websocket with early data

This commit is contained in:
wwqgtxx
2026-04-13 23:19:57 +08:00
parent b00c985728
commit 299fd33795
+46 -36
View File
@@ -42,7 +42,7 @@ type websocketWithEarlyDataConn struct {
net.Conn
wsWriter N.ExtendedWriter
underlay net.Conn
dialed chan bool
dialed chan struct{}
cancel context.CancelFunc
ctx context.Context
config *WebsocketConfig
@@ -174,7 +174,7 @@ func (wsc *websocketConn) Close() error {
return nil
}
func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
func (wsedc *websocketWithEarlyDataConn) dial(earlyData []byte) error {
base64DataBuf := &bytes.Buffer{}
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
@@ -193,8 +193,8 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
return fmt.Errorf("failed to dial WebSocket: %w", err)
}
wsedc.dialed <- true
wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn)
close(wsedc.dialed)
if earlyDataBuf.Len() != 0 {
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
}
@@ -203,67 +203,68 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
}
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
if wsedc.ctx.Err() != nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(b); err != nil {
case <-wsedc.dialed:
return wsedc.Conn.Write(b)
default:
if err := wsedc.dial(b); err != nil {
return 0, err
}
return len(b), nil
}
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
if wsedc.ctx.Err() != nil {
select {
case <-wsedc.ctx.Done():
return io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(buffer.Bytes()); err != nil {
case <-wsedc.dialed:
return wsedc.wsWriter.WriteBuffer(buffer)
default:
if err := wsedc.dial(buffer.Bytes()); err != nil {
return err
}
return nil
}
return wsedc.wsWriter.WriteBuffer(buffer)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.ctx.Err() != nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrClosedPipe
case <-wsedc.dialed:
return wsedc.Conn.Read(b)
}
if wsedc.Conn == nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-wsedc.dialed:
}
}
return wsedc.Conn.Read(b)
}
func (wsedc *websocketWithEarlyDataConn) Close() error {
wsedc.cancel()
if wsedc.Conn == nil { // is dialing or not dialed
select {
case <-wsedc.dialed:
return wsedc.Conn.Close()
default:
return wsedc.underlay.Close()
}
return wsedc.Conn.Close()
}
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
if wsedc.Conn == nil {
select {
case <-wsedc.dialed:
return wsedc.Conn.LocalAddr()
default:
return wsedc.underlay.LocalAddr()
}
return wsedc.Conn.LocalAddr()
}
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
if wsedc.Conn == nil {
select {
case <-wsedc.dialed:
return wsedc.Conn.RemoteAddr()
default:
return wsedc.underlay.RemoteAddr()
}
return wsedc.Conn.RemoteAddr()
}
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
@@ -274,17 +275,21 @@ func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
}
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
if wsedc.Conn == nil {
select {
case <-wsedc.dialed:
return wsedc.Conn.SetReadDeadline(t)
default:
return nil
}
return wsedc.Conn.SetReadDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
if wsedc.Conn == nil {
select {
case <-wsedc.dialed:
return wsedc.Conn.SetReadDeadline(t)
default:
return nil
}
return wsedc.Conn.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int {
@@ -307,13 +312,18 @@ func (wsedc *websocketWithEarlyDataConn) Upstream() any {
//}
func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool {
return wsedc.Conn == nil
select {
case <-wsedc.dialed:
return false
default:
return true
}
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
dialed: make(chan bool, 1),
dialed: make(chan struct{}),
cancel: cancel,
ctx: ctx,
underlay: conn,