diff --git a/adapter/outbound/masque.go b/adapter/outbound/masque.go index 79124cd8..5f6d8b56 100644 --- a/adapter/outbound/masque.go +++ b/adapter/outbound/masque.go @@ -17,7 +17,6 @@ import ( "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/contextutils" - "github.com/metacubex/mihomo/common/httputils" "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/resolver" @@ -171,6 +170,7 @@ func NewMasque(option MasqueOption) (*Masque, error) { } return tls.Client(c, tlsConfig), nil }, + ReadIdleTimeout: 30 * time.Second, } } @@ -248,11 +248,11 @@ func (w *Masque) run(ctx context.Context) error { } var pc net.PacketConn - var tr io.Closer + var closer io.Closer var ipConn masque.IpConn var err error if w.h2Transport != nil { - ipConn, err = masque.ConnectTunnelH2(ctx, &http.Client{Transport: w.h2Transport}, w.uri) + closer, ipConn, err = masque.ConnectTunnelH2(ctx, w.h2Transport, w.uri) if err != nil { return err } @@ -276,7 +276,7 @@ func (w *Masque) run(ctx context.Context) error { common.SetCongestionController(quicConn, w.option.CongestionController, w.option.CWND) - tr, ipConn, err = masque.ConnectTunnel(ctx, quicConn, w.uri) + closer, ipConn, err = masque.ConnectTunnel(ctx, quicConn, w.uri) if err != nil { _ = pc.Close() return err @@ -289,9 +289,7 @@ func (w *Masque) run(ctx context.Context) error { contextutils.AfterFunc(runCtx, func() { w.running.Store(false) _ = ipConn.Close() - if tr != nil { - _ = tr.Close() - } + _ = closer.Close() if pc != nil { _ = pc.Close() } @@ -357,9 +355,6 @@ func (w *Masque) Close() error { if w.tunDevice != nil { w.tunDevice.Close() } - if w.h2Transport != nil { - httputils.CloseTransport(w.h2Transport) - } return nil } diff --git a/transport/masque/client_h2.go b/transport/masque/client_h2.go index f0e9d1f2..150cce84 100644 --- a/transport/masque/client_h2.go +++ b/transport/masque/client_h2.go @@ -28,7 +28,7 @@ const ( ipv6HeaderLen = 40 ) -func ConnectTunnelH2(ctx context.Context, h2Client *http.Client, connectUri string) (IpConn, error) { +func ConnectTunnelH2(ctx context.Context, h2Transport *http.Http2Transport, connectUri string) (*http.Http2ClientConn, IpConn, error) { additionalHeaders := http.Header{ "User-Agent": []string{""}, } @@ -39,26 +39,43 @@ func ConnectTunnelH2(ctx context.Context, h2Client *http.Client, connectUri stri // TODO: support PQC h2Headers.Set("pq-enabled", "false") - ipConn, rsp, err := dialH2(ctx, h2Client, template, h2Headers) + conn, err := h2Transport.DialTLSContext(ctx, "tcp", ":0", nil) if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to dial: %w", err) + } + + cc, err := h2Transport.NewClientConn(conn) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err) + } + + if !cc.ReserveNewRequest() { + _ = cc.Close() + return nil, nil, fmt.Errorf("connect-ip: failed to reserve client connection: %w", err) + } + + ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers) + if err != nil { + _ = cc.Close() if strings.Contains(err.Error(), "tls: access denied") { - return nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + return nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") } - return nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err) + return nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err) } if rsp.StatusCode != http.StatusOK { _ = ipConn.Close() - return nil, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status) + _ = cc.Close() + return nil, nil, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status) } - return ipConn, nil + return cc, ipConn, nil } // dialH2 dials a proxied connection over HTTP/2 CONNECT-IP. // // This transport carries proxied packets inside HTTP capsule DATAGRAM frames. -func dialH2(ctx context.Context, client *http.Client, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) { +func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) { if len(template.Varnames()) > 0 { return nil, nil, errors.New("connect-ip: IP flow forwarding not supported") } @@ -86,7 +103,7 @@ func dialH2(ctx context.Context, client *http.Client, template *uritemplate.Temp } stop := contextutils.AfterFunc(ctx, cancel) // temporarily connect ctx with reqCtx when client.Do - rsp, err := client.Do(req) + rsp, err := rt.RoundTrip(req) stop() // disconnect ctx with reqCtx after client.Do if err != nil { cancel()