mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-04-22 16:17:16 +08:00
chore: handle context cancellation in xhttp client dialing
This commit is contained in:
@@ -316,7 +316,7 @@ func (v *Vless) dialContext(ctx context.Context) (c net.Conn, err error) {
|
||||
case "grpc": // gun transport
|
||||
return v.gunClient.Dial()
|
||||
case "xhttp":
|
||||
return v.xhttpClient.Dial()
|
||||
return v.xhttpClient.Dial(ctx)
|
||||
default:
|
||||
}
|
||||
return v.dialer.DialContext(ctx, "tcp", v.addr)
|
||||
|
||||
+40
-13
@@ -14,6 +14,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/common/contextutils"
|
||||
"github.com/metacubex/mihomo/common/httputils"
|
||||
|
||||
"github.com/metacubex/http"
|
||||
@@ -297,14 +298,14 @@ func (c *Client) Close() error {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (c *Client) Dial() (net.Conn, error) {
|
||||
func (c *Client) Dial(ctx context.Context) (net.Conn, error) {
|
||||
switch c.mode {
|
||||
case "stream-one":
|
||||
return c.DialStreamOne()
|
||||
return c.DialStreamOne(ctx)
|
||||
case "stream-up":
|
||||
return c.DialStreamUp()
|
||||
return c.DialStreamUp(ctx)
|
||||
case "packet-up":
|
||||
return c.DialPacketUp()
|
||||
return c.DialPacketUp(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", c.mode)
|
||||
}
|
||||
@@ -324,7 +325,7 @@ func (c *Client) getTransport() (uploadTransport http.RoundTripper, downloadTran
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) DialStreamOne() (net.Conn, error) {
|
||||
func (c *Client) DialStreamOne(ctx context.Context) (net.Conn, error) {
|
||||
transport, _, err := c.getTransport()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -346,18 +347,23 @@ func (c *Client) DialStreamOne() (net.Conn, error) {
|
||||
// which can't be sent because we haven't returned the conn yet.
|
||||
gotConn := make(chan bool, 1)
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
|
||||
ctx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
|
||||
reqCtx, reqCancel := context.WithCancel(c.ctx) // reqCtx must alive during conn not closed
|
||||
stop := contextutils.AfterFunc(ctx, reqCancel) // temporarily connect ctx with reqCtx when dialing
|
||||
defer stop() // disconnect ctx with reqCtx after dialing
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, reqCtx)
|
||||
streamCtx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
gotConn <- true
|
||||
},
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, c.cfg.GetNormalizedUplinkHTTPMethod(), requestURL.String(), pr)
|
||||
req, err := http.NewRequestWithContext(streamCtx, c.cfg.GetNormalizedUplinkHTTPMethod(), requestURL.String(), pr)
|
||||
if err != nil {
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
httputils.CloseTransport(transport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
req.Host = c.cfg.Host
|
||||
@@ -366,6 +372,7 @@ func (c *Client) DialStreamOne() (net.Conn, error) {
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
httputils.CloseTransport(transport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -391,6 +398,7 @@ func (c *Client) DialStreamOne() (net.Conn, error) {
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
httputils.CloseTransport(transport)
|
||||
reqCancel()
|
||||
var buf [0]byte
|
||||
_, err = wrc.Read(buf[:])
|
||||
return nil, err
|
||||
@@ -400,12 +408,13 @@ func (c *Client) DialStreamOne() (net.Conn, error) {
|
||||
conn.onClose = func() {
|
||||
_ = pr.Close()
|
||||
httputils.CloseTransport(transport)
|
||||
reqCancel()
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
func (c *Client) DialStreamUp(ctx context.Context) (net.Conn, error) {
|
||||
uploadTransport, downloadTransport, err := c.getTransport()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -436,7 +445,11 @@ func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
// Async download: avoid blocking on CDN response header buffering
|
||||
gotConn := make(chan bool, 1)
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
|
||||
reqCtx, reqCancel := context.WithCancel(c.ctx) // reqCtx must alive during conn not closed
|
||||
stop := contextutils.AfterFunc(ctx, reqCancel) // temporarily connect ctx with reqCtx when dialing
|
||||
defer stop() // disconnect ctx with reqCtx after dialing
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, reqCtx)
|
||||
downloadCtx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
gotConn <- true
|
||||
@@ -452,18 +465,20 @@ func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
if err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := downloadCfg.FillDownloadRequest(downloadReq, sessionID); err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
downloadReq.Host = downloadCfg.Host
|
||||
|
||||
uploadReq, err := http.NewRequestWithContext(
|
||||
c.ctx,
|
||||
reqCtx,
|
||||
c.cfg.GetNormalizedUplinkHTTPMethod(),
|
||||
streamURL.String(),
|
||||
pr,
|
||||
@@ -471,12 +486,14 @@ func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
if err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = c.cfg.FillStreamRequest(uploadReq, sessionID); err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
uploadReq.Host = c.cfg.Host
|
||||
@@ -503,6 +520,7 @@ func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
_ = pw.Close()
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
var buf [0]byte
|
||||
_, err = wrc.Read(buf[:])
|
||||
return nil, err
|
||||
@@ -530,12 +548,13 @@ func (c *Client) DialStreamUp() (net.Conn, error) {
|
||||
_ = pr.Close()
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialPacketUp() (net.Conn, error) {
|
||||
func (c *Client) DialPacketUp(ctx context.Context) (net.Conn, error) {
|
||||
uploadTransport, downloadTransport, err := c.getTransport()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -570,7 +589,11 @@ func (c *Client) DialPacketUp() (net.Conn, error) {
|
||||
// Async download: avoid blocking on CDN response header buffering
|
||||
gotConn := make(chan bool, 1)
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
|
||||
reqCtx, reqCancel := context.WithCancel(c.ctx) // reqCtx must alive during conn not closed
|
||||
stop := contextutils.AfterFunc(ctx, reqCancel) // temporarily connect ctx with reqCtx when dialing
|
||||
defer stop() // disconnect ctx with reqCtx after dialing
|
||||
|
||||
addrCtx := httputils.NewAddrContext(&conn.NetAddr, reqCtx)
|
||||
downloadCtx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
gotConn <- true
|
||||
@@ -586,11 +609,13 @@ func (c *Client) DialPacketUp() (net.Conn, error) {
|
||||
if err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
if err = downloadCfg.FillDownloadRequest(downloadReq, sessionID); err != nil {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
return nil, err
|
||||
}
|
||||
downloadReq.Host = downloadCfg.Host
|
||||
@@ -615,6 +640,7 @@ func (c *Client) DialPacketUp() (net.Conn, error) {
|
||||
if !<-gotConn {
|
||||
httputils.CloseTransport(uploadTransport)
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
var buf [0]byte
|
||||
_, err = wrc.Read(buf[:])
|
||||
return nil, err
|
||||
@@ -624,6 +650,7 @@ func (c *Client) DialPacketUp() (net.Conn, error) {
|
||||
conn.onClose = func() {
|
||||
// uploadTransport already closed by writer
|
||||
httputils.CloseTransport(downloadTransport)
|
||||
reqCancel()
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
Reference in New Issue
Block a user