Files
Archive/mihomo/transport/xhttp/client.go
T
2026-04-17 21:15:10 +02:00

699 lines
17 KiB
Go

package xhttp
import (
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"strconv"
"sync"
"time"
"github.com/metacubex/mihomo/common/httputils"
"github.com/metacubex/http"
"github.com/metacubex/http/httptrace"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/http3"
"github.com/metacubex/tls"
"golang.org/x/sync/semaphore"
)
// ConnIdleTimeout defines the maximum time an idle TCP session can survive in the tunnel,
// so it should be consistent across HTTP versions and with other transports.
const ConnIdleTimeout = 300 * time.Second
// QuicgoH3KeepAlivePeriod consistent with quic-go
const QuicgoH3KeepAlivePeriod = 10 * time.Second
// ChromeH2KeepAlivePeriod consistent with chrome
const ChromeH2KeepAlivePeriod = 45 * time.Second
type DialRawFunc func(ctx context.Context) (net.Conn, error)
type WrapTLSFunc func(ctx context.Context, conn net.Conn, isH2 bool) (net.Conn, error)
type DialQUICFunc func(ctx context.Context, cfg *quic.Config) (*quic.Conn, error)
type TransportMaker func() http.RoundTripper
type PacketUpWriter struct {
ctx context.Context
cancel context.CancelFunc
cfg *Config
scMaxEachPostBytes int
scMinPostsIntervalMs Range
sessionID string
transport http.RoundTripper
writeMu sync.Mutex
writeCond sync.Cond
seq uint64
buf []byte
timer *time.Timer
flushErr error
}
func (c *PacketUpWriter) Write(b []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if err := c.flushErr; err != nil {
return 0, err
}
data := bytes.NewBuffer(b)
for data.Len() > 0 {
if c.timer == nil { // start a timer to flush the buffer
c.timer = time.AfterFunc(time.Duration(c.scMinPostsIntervalMs.Rand())*time.Millisecond, c.flush)
}
c.buf = append(c.buf, data.Next(c.scMaxEachPostBytes-len(c.buf))...) // let buffer fill up to scMaxEachPostBytes
if len(c.buf) >= c.scMaxEachPostBytes { // too much data in buffer, wait the flush complete
c.writeCond.Wait()
if err := c.flushErr; err != nil {
return 0, err
}
}
}
return len(b), nil
}
func (c *PacketUpWriter) flush() {
c.writeMu.Lock()
defer c.writeMu.Unlock()
defer c.writeCond.Broadcast() // wake up the waited Write() call
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
if c.flushErr != nil {
return
}
if len(c.buf) == 0 {
return
}
_, err := c.write(c.buf)
c.buf = c.buf[:0] // reset buffer
if err != nil {
c.flushErr = err
return
}
}
func (c *PacketUpWriter) write(b []byte) (int, error) {
u := url.URL{
Scheme: "https",
Host: c.cfg.Host,
Path: c.cfg.NormalizedPath(),
}
req, err := http.NewRequestWithContext(c.ctx, http.MethodPost, u.String(), nil)
if err != nil {
return 0, err
}
seqStr := strconv.FormatUint(c.seq, 10)
c.seq++
if err := c.cfg.FillPacketRequest(req, c.sessionID, seqStr, b); err != nil {
return 0, err
}
req.Host = c.cfg.Host
resp, err := c.transport.RoundTrip(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body)
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("xhttp packet-up bad status: %s", resp.Status)
}
return len(b), nil
}
func (c *PacketUpWriter) Close() error {
ch := make(chan struct{})
go func() { // flush in the background
defer close(ch)
c.flush()
}()
select {
case <-ch:
case <-time.After(time.Second):
}
c.cancel()
httputils.CloseTransport(c.transport)
return nil
}
func NewTransport(dialRaw DialRawFunc, wrapTLS WrapTLSFunc, dialQUIC DialQUICFunc, alpn []string, keepAlivePeriod time.Duration) http.RoundTripper {
if len(alpn) == 1 && alpn[0] == "h3" { // `alpn: [h3]` means using h3 mode
if keepAlivePeriod == 0 {
keepAlivePeriod = QuicgoH3KeepAlivePeriod
}
if keepAlivePeriod < 0 {
keepAlivePeriod = 0
}
return &http3.Transport{
QUICConfig: &quic.Config{
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
KeepAlivePeriod: keepAlivePeriod,
MaxIdleTimeout: ConnIdleTimeout,
},
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
return dialQUIC(ctx, cfg)
},
}
}
if len(alpn) == 1 && alpn[0] == "http/1.1" { // `alpn: [http/1.1]` means using http/1.1 mode
w := semaphore.NewWeighted(20) // limit concurrent dialing to avoid WSAECONNREFUSED on Windows
dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
if err := w.Acquire(ctx, 1); err != nil {
return nil, err
}
defer w.Release(1)
raw, err := dialRaw(ctx)
if err != nil {
return nil, err
}
wrapped, err := wrapTLS(ctx, raw, false)
if err != nil {
_ = raw.Close()
return nil, err
}
return wrapped, nil
}
return &http.Transport{
DialContext: dialContext,
DialTLSContext: dialContext,
IdleConnTimeout: ConnIdleTimeout,
ForceAttemptHTTP2: false, // only http/1.1
}
}
if keepAlivePeriod == 0 {
keepAlivePeriod = ChromeH2KeepAlivePeriod
}
if keepAlivePeriod < 0 {
keepAlivePeriod = 0
}
return &http.Http2Transport{
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
raw, err := dialRaw(ctx)
if err != nil {
return nil, err
}
wrapped, err := wrapTLS(ctx, raw, true)
if err != nil {
_ = raw.Close()
return nil, err
}
return wrapped, nil
},
IdleConnTimeout: ConnIdleTimeout,
ReadIdleTimeout: keepAlivePeriod,
}
}
type Client struct {
ctx context.Context
cancel context.CancelFunc
mode string
cfg *Config
scMaxEachPostBytes Range
scMinPostsIntervalMs Range
makeTransport TransportMaker
makeDownloadTransport TransportMaker
uploadManager *ReuseManager
downloadManager *ReuseManager
}
func NewClient(cfg *Config, makeTransport TransportMaker, makeDownloadTransport TransportMaker, hasReality bool) (*Client, error) {
mode := cfg.EffectiveMode(hasReality)
switch mode {
case "stream-one", "stream-up", "packet-up":
default:
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", mode)
}
scMaxEachPostBytes, err := cfg.GetNormalizedScMaxEachPostBytes()
if err != nil {
return nil, err
}
scMinPostsIntervalMs, err := cfg.GetNormalizedScMinPostsIntervalMs()
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
client := &Client{
mode: mode,
cfg: cfg,
scMaxEachPostBytes: scMaxEachPostBytes,
scMinPostsIntervalMs: scMinPostsIntervalMs,
makeTransport: makeTransport,
makeDownloadTransport: makeDownloadTransport,
ctx: ctx,
cancel: cancel,
}
if cfg.ReuseConfig != nil {
client.uploadManager, err = NewReuseManager(cfg.ReuseConfig, makeTransport)
if err != nil {
return nil, err
}
client.makeTransport = client.uploadManager.GetTransport
if cfg.DownloadConfig != nil {
if makeDownloadTransport == nil {
return nil, fmt.Errorf("xhttp: download manager requires download transport maker")
}
client.downloadManager, err = NewReuseManager(cfg.DownloadConfig.ReuseConfig, makeDownloadTransport)
if err != nil {
return nil, err
}
client.makeDownloadTransport = client.downloadManager.GetTransport
}
}
return client, nil
}
func (c *Client) Close() error {
c.cancel()
var errs []error
if c.uploadManager != nil {
err := c.uploadManager.Close()
if err != nil {
errs = append(errs, err)
}
}
if c.downloadManager != nil {
err := c.downloadManager.Close()
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
func (c *Client) Dial() (net.Conn, error) {
switch c.mode {
case "stream-one":
return c.DialStreamOne()
case "stream-up":
return c.DialStreamUp()
case "packet-up":
return c.DialPacketUp()
default:
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", c.mode)
}
}
// onlyRoundTripper is a wrapper that prevents the underlying transport from being closed.
type onlyRoundTripper struct {
http.RoundTripper
}
func (c *Client) getTransport() (uploadTransport http.RoundTripper, downloadTransport http.RoundTripper, err error) {
uploadTransport = c.makeTransport()
downloadTransport = onlyRoundTripper{uploadTransport}
if c.makeDownloadTransport != nil {
downloadTransport = c.makeDownloadTransport()
}
return
}
func (c *Client) DialStreamOne() (net.Conn, error) {
transport, _, err := c.getTransport()
if err != nil {
return nil, err
}
requestURL := url.URL{
Scheme: "https",
Host: c.cfg.Host,
Path: c.cfg.NormalizedPath(),
}
pr, pw := io.Pipe()
conn := &Conn{writer: pw}
// Use gotConn to detect when TCP connection is established, so we can
// return the conn immediately without waiting for the HTTP response.
// This breaks the deadlock where CDN buffers response headers until the
// server sends body data, but the server waits for our request body,
// which can't be sent because we haven't returned the conn yet.
gotConn := make(chan bool, 1)
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
ctx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
gotConn <- true
},
})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), pr)
if err != nil {
_ = pr.Close()
_ = pw.Close()
httputils.CloseTransport(transport)
return nil, err
}
req.Host = c.cfg.Host
if err = c.cfg.FillStreamRequest(req, ""); err != nil {
_ = pr.Close()
_ = pw.Close()
httputils.CloseTransport(transport)
return nil, err
}
wrc := NewWaitReadCloser()
go func() {
resp, err := transport.RoundTrip(req)
if err != nil {
wrc.CloseWithError(err)
close(gotConn)
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
_ = resp.Body.Close()
wrc.CloseWithError(fmt.Errorf("xhttp stream-one bad status: %s", resp.Status))
return
}
wrc.Set(resp.Body)
}()
if !<-gotConn {
// RoundTrip failed before TCP connected (e.g. DNS failure)
_ = pr.Close()
_ = pw.Close()
httputils.CloseTransport(transport)
var buf [0]byte
_, err = wrc.Read(buf[:])
return nil, err
}
conn.reader = wrc
conn.onClose = func() {
_ = pr.Close()
httputils.CloseTransport(transport)
}
return conn, nil
}
func (c *Client) DialStreamUp() (net.Conn, error) {
uploadTransport, downloadTransport, err := c.getTransport()
if err != nil {
return nil, err
}
downloadCfg := c.cfg
if ds := c.cfg.DownloadConfig; ds != nil {
downloadCfg = ds
}
streamURL := url.URL{
Scheme: "https",
Host: c.cfg.Host,
Path: c.cfg.NormalizedPath(),
}
downloadURL := url.URL{
Scheme: "https",
Host: downloadCfg.Host,
Path: downloadCfg.NormalizedPath(),
}
pr, pw := io.Pipe()
conn := &Conn{writer: pw}
sessionID := newSessionID()
// Async download: avoid blocking on CDN response header buffering
gotConn := make(chan bool, 1)
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
downloadCtx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
gotConn <- true
},
})
downloadReq, err := http.NewRequestWithContext(
downloadCtx,
http.MethodGet,
downloadURL.String(),
nil,
)
if err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
if err := downloadCfg.FillDownloadRequest(downloadReq, sessionID); err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
downloadReq.Host = downloadCfg.Host
uploadReq, err := http.NewRequestWithContext(
c.ctx,
http.MethodPost,
streamURL.String(),
pr,
)
if err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
if err = c.cfg.FillStreamRequest(uploadReq, sessionID); err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
uploadReq.Host = c.cfg.Host
wrc := NewWaitReadCloser()
go func() {
resp, err := downloadTransport.RoundTrip(downloadReq)
if err != nil {
wrc.CloseWithError(err)
close(gotConn)
return
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
wrc.CloseWithError(fmt.Errorf("xhttp stream-up download bad status: %s", resp.Status))
return
}
wrc.Set(resp.Body)
}()
if !<-gotConn {
_ = pr.Close()
_ = pw.Close()
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
var buf [0]byte
_, err = wrc.Read(buf[:])
return nil, err
}
// Start upload after download TCP is connected, so the server has likely
// already processed the GET and created the session. This preserves the
// original ordering (download before upload) while still being async.
go func() {
resp, err := uploadTransport.RoundTrip(uploadReq)
if err != nil {
_ = pw.CloseWithError(err)
return
}
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
_ = pw.CloseWithError(fmt.Errorf("xhttp stream-up upload bad status: %s", resp.Status))
}
}()
conn.reader = wrc
conn.onClose = func() {
_ = pr.Close()
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
}
return conn, nil
}
func (c *Client) DialPacketUp() (net.Conn, error) {
uploadTransport, downloadTransport, err := c.getTransport()
if err != nil {
return nil, err
}
downloadCfg := c.cfg
if ds := c.cfg.DownloadConfig; ds != nil {
downloadCfg = ds
}
sessionID := newSessionID()
downloadURL := url.URL{
Scheme: "https",
Host: downloadCfg.Host,
Path: downloadCfg.NormalizedPath(),
}
writerCtx, writerCancel := context.WithCancel(c.ctx)
writer := &PacketUpWriter{
ctx: writerCtx,
cancel: writerCancel,
cfg: c.cfg,
scMaxEachPostBytes: c.scMaxEachPostBytes.Rand(),
scMinPostsIntervalMs: c.scMinPostsIntervalMs,
sessionID: sessionID,
transport: uploadTransport,
seq: 0,
}
writer.writeCond = sync.Cond{L: &writer.writeMu}
conn := &Conn{writer: writer}
// Async download: avoid blocking on CDN response header buffering
gotConn := make(chan bool, 1)
addrCtx := httputils.NewAddrContext(&conn.NetAddr, c.ctx)
downloadCtx := httptrace.WithClientTrace(addrCtx, &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
gotConn <- true
},
})
downloadReq, err := http.NewRequestWithContext(
downloadCtx,
http.MethodGet,
downloadURL.String(),
nil,
)
if err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
if err = downloadCfg.FillDownloadRequest(downloadReq, sessionID); err != nil {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
return nil, err
}
downloadReq.Host = downloadCfg.Host
wrc := NewWaitReadCloser()
go func() {
resp, err := downloadTransport.RoundTrip(downloadReq)
if err != nil {
wrc.CloseWithError(err)
close(gotConn)
return
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
wrc.CloseWithError(fmt.Errorf("xhttp packet-up download bad status: %s", resp.Status))
return
}
wrc.Set(resp.Body)
}()
if !<-gotConn {
httputils.CloseTransport(uploadTransport)
httputils.CloseTransport(downloadTransport)
var buf [0]byte
_, err = wrc.Read(buf[:])
return nil, err
}
conn.reader = wrc
conn.onClose = func() {
// uploadTransport already closed by writer
httputils.CloseTransport(downloadTransport)
}
return conn, nil
}
func newSessionID() string {
var b [16]byte
_, _ = rand.Read(b[:])
return hex.EncodeToString(b[:])
}
// WaitReadCloser is an io.ReadCloser that blocks on Read() until the underlying
// ReadCloser is provided via Set(). This enables returning a reader immediately
// while the actual HTTP response body is obtained asynchronously in a goroutine,
// breaking the synchronous RoundTrip deadlock with CDN header buffering.
type WaitReadCloser struct {
wait chan struct{}
once sync.Once
rc io.ReadCloser
err error
}
func NewWaitReadCloser() *WaitReadCloser {
return &WaitReadCloser{wait: make(chan struct{})}
}
// Set provides the underlying ReadCloser and unblocks any pending Read calls.
// Must be called at most once. If Close was already called, rc is closed to
// prevent leaks.
func (w *WaitReadCloser) Set(rc io.ReadCloser) {
w.setup(rc, nil)
}
// CloseWithError records an error and unblocks any pending Read calls.
func (w *WaitReadCloser) CloseWithError(err error) {
w.setup(nil, err)
}
// setup sets the underlying ReadCloser and error.
func (w *WaitReadCloser) setup(rc io.ReadCloser, err error) {
w.once.Do(func() {
w.rc = rc
w.err = err
close(w.wait)
})
if w.err != nil && rc != nil {
_ = rc.Close()
}
}
func (w *WaitReadCloser) Read(b []byte) (int, error) {
<-w.wait
if w.rc == nil {
return 0, w.err
}
return w.rc.Read(b)
}
func (w *WaitReadCloser) Close() error {
w.setup(nil, net.ErrClosed)
<-w.wait
if w.rc != nil {
return w.rc.Close()
}
return nil
}