mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-04-22 16:17:16 +08:00
fix: race in websocket with early data
This commit is contained in:
@@ -42,7 +42,7 @@ type websocketWithEarlyDataConn struct {
|
||||
net.Conn
|
||||
wsWriter N.ExtendedWriter
|
||||
underlay net.Conn
|
||||
dialed chan bool
|
||||
dialed chan struct{}
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
config *WebsocketConfig
|
||||
@@ -174,7 +174,7 @@ func (wsc *websocketConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
|
||||
func (wsedc *websocketWithEarlyDataConn) dial(earlyData []byte) error {
|
||||
base64DataBuf := &bytes.Buffer{}
|
||||
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
|
||||
|
||||
@@ -193,8 +193,8 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
|
||||
return fmt.Errorf("failed to dial WebSocket: %w", err)
|
||||
}
|
||||
|
||||
wsedc.dialed <- true
|
||||
wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn)
|
||||
close(wsedc.dialed)
|
||||
if earlyDataBuf.Len() != 0 {
|
||||
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
|
||||
}
|
||||
@@ -203,67 +203,68 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
|
||||
if wsedc.ctx.Err() != nil {
|
||||
select {
|
||||
case <-wsedc.ctx.Done():
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if wsedc.Conn == nil {
|
||||
if err := wsedc.Dial(b); err != nil {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.Write(b)
|
||||
default:
|
||||
if err := wsedc.dial(b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
return wsedc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if wsedc.ctx.Err() != nil {
|
||||
select {
|
||||
case <-wsedc.ctx.Done():
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
if wsedc.Conn == nil {
|
||||
if err := wsedc.Dial(buffer.Bytes()); err != nil {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.wsWriter.WriteBuffer(buffer)
|
||||
default:
|
||||
if err := wsedc.dial(buffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return wsedc.wsWriter.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
|
||||
if wsedc.ctx.Err() != nil {
|
||||
select {
|
||||
case <-wsedc.ctx.Done():
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.Read(b)
|
||||
}
|
||||
if wsedc.Conn == nil {
|
||||
select {
|
||||
case <-wsedc.ctx.Done():
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
case <-wsedc.dialed:
|
||||
}
|
||||
}
|
||||
return wsedc.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Close() error {
|
||||
wsedc.cancel()
|
||||
if wsedc.Conn == nil { // is dialing or not dialed
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.Close()
|
||||
default:
|
||||
return wsedc.underlay.Close()
|
||||
}
|
||||
return wsedc.Conn.Close()
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
|
||||
if wsedc.Conn == nil {
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.LocalAddr()
|
||||
default:
|
||||
return wsedc.underlay.LocalAddr()
|
||||
}
|
||||
return wsedc.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
|
||||
if wsedc.Conn == nil {
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.RemoteAddr()
|
||||
default:
|
||||
return wsedc.underlay.RemoteAddr()
|
||||
}
|
||||
return wsedc.Conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
|
||||
@@ -274,17 +275,21 @@ func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
|
||||
if wsedc.Conn == nil {
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.SetReadDeadline(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return wsedc.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
|
||||
if wsedc.Conn == nil {
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return wsedc.Conn.SetReadDeadline(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return wsedc.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int {
|
||||
@@ -307,13 +312,18 @@ func (wsedc *websocketWithEarlyDataConn) Upstream() any {
|
||||
//}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool {
|
||||
return wsedc.Conn == nil
|
||||
select {
|
||||
case <-wsedc.dialed:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn = &websocketWithEarlyDataConn{
|
||||
dialed: make(chan bool, 1),
|
||||
dialed: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
underlay: conn,
|
||||
|
||||
Reference in New Issue
Block a user