chore: better lifecycle management

This commit is contained in:
wwqgtxx
2026-04-01 09:33:52 +08:00
parent 73465feabc
commit 018a2918fa
2 changed files with 99 additions and 56 deletions
+17 -32
View File
@@ -2,6 +2,7 @@ package outbound
import (
"context"
"errors"
"fmt"
"net"
"strconv"
@@ -37,7 +38,7 @@ type Vless struct {
// for gun mux
gunTransport *gun.Transport
// for xhttp
dialXHTTPConn func() (net.Conn, error)
xhttpClient *xhttp.Client
realityConfig *tlsC.RealityConfig
echConfig *ech.Config
@@ -188,7 +189,7 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
case "grpc":
break // already handle in gun transport
case "xhttp":
break // already handle in dialXHTTPConn
break // already handle in xhttp client
default:
// default tcp network
// handle TLS
@@ -272,7 +273,7 @@ func (v *Vless) dialContext(ctx context.Context) (c net.Conn, err error) {
case "grpc": // gun transport
return v.gunTransport.Dial()
case "xhttp":
return v.dialXHTTPConn()
return v.xhttpClient.Dial()
default:
}
return v.dialer.DialContext(ctx, "tcp", v.addr)
@@ -347,10 +348,18 @@ func (v *Vless) ProxyInfo() C.ProxyInfo {
// Close implements C.ProxyAdapter
func (v *Vless) Close() error {
var errs []error
if v.gunTransport != nil {
return v.gunTransport.Close()
if err := v.gunTransport.Close(); err != nil {
errs = append(errs, err)
}
}
return nil
if v.xhttpClient != nil {
if err := v.xhttpClient.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr {
@@ -607,33 +616,9 @@ func NewVless(option VlessOption) (*Vless, error) {
}
}
mode := cfg.EffectiveMode(v.realityConfig != nil)
switch mode {
case "stream-one":
v.dialXHTTPConn = func() (net.Conn, error) {
transport := makeTransport()
return xhttp.DialStreamOne(cfg, transport)
}
case "stream-up":
v.dialXHTTPConn = func() (net.Conn, error) {
transport := makeTransport()
downloadTransport := transport
if makeDownloadTransport != nil {
downloadTransport = makeDownloadTransport()
}
return xhttp.DialStreamUp(cfg, transport, downloadTransport)
}
case "packet-up":
v.dialXHTTPConn = func() (net.Conn, error) {
transport := makeTransport()
downloadTransport := transport
if makeDownloadTransport != nil {
downloadTransport = makeDownloadTransport()
}
return xhttp.DialPacketUp(cfg, transport, downloadTransport)
}
default:
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", mode)
v.xhttpClient, err = xhttp.NewClient(cfg, makeTransport, makeDownloadTransport, v.realityConfig != nil)
if err != nil {
return nil, err
}
}
+82 -24
View File
@@ -20,6 +20,8 @@ import (
type DialRawFunc func(ctx context.Context) (net.Conn, error)
type WrapTLSFunc func(ctx context.Context, conn net.Conn, isH2 bool) (net.Conn, error)
type TransportMaker func() http.RoundTripper
type PacketUpWriter struct {
ctx context.Context
cfg *Config
@@ -88,26 +90,72 @@ func NewTransport(dialRaw DialRawFunc, wrapTLS WrapTLSFunc) http.RoundTripper {
}
}
func DialStreamOne(cfg *Config, transport http.RoundTripper) (net.Conn, error) {
type Client struct {
mode string
cfg *Config
makeTransport TransportMaker
makeDownloadTransport TransportMaker
ctx context.Context
cancel context.CancelFunc
}
func NewClient(cfg *Config, makeTransport TransportMaker, makeDownloadTransport TransportMaker, hasReality bool) (*Client, error) {
mode := cfg.EffectiveMode(hasReality)
switch mode {
case "stream-one", "stream-up", "packet-up":
default:
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", mode)
}
ctx, cancel := context.WithCancel(context.Background())
return &Client{
mode: mode,
cfg: cfg,
makeTransport: makeTransport,
makeDownloadTransport: makeDownloadTransport,
ctx: ctx,
cancel: cancel,
}, nil
}
func (c *Client) Dial() (net.Conn, error) {
switch c.mode {
case "stream-one":
return c.DialStreamOne()
case "stream-up":
return c.DialStreamUp()
case "packet-up":
return c.DialPacketUp()
default:
return nil, fmt.Errorf("xhttp mode %s is not implemented yet", c.mode)
}
}
func (c *Client) Close() error {
c.cancel()
return nil
}
func (c *Client) DialStreamOne() (net.Conn, error) {
transport := c.makeTransport()
requestURL := url.URL{
Scheme: "https",
Host: cfg.Host,
Path: cfg.NormalizedPath(),
Host: c.cfg.Host,
Path: c.cfg.NormalizedPath(),
}
pr, pw := io.Pipe()
ctx := context.Background()
conn := &Conn{writer: pw}
req, err := http.NewRequestWithContext(httputils.NewAddrContext(&conn.NetAddr, ctx), http.MethodPost, requestURL.String(), pr)
req, err := http.NewRequestWithContext(httputils.NewAddrContext(&conn.NetAddr, c.ctx), http.MethodPost, requestURL.String(), pr)
if err != nil {
_ = pr.Close()
_ = pw.Close()
return nil, err
}
req.Host = cfg.Host
req.Host = c.cfg.Host
if err := cfg.FillStreamRequest(req, ""); err != nil {
if err := c.cfg.FillStreamRequest(req, ""); err != nil {
_ = pr.Close()
_ = pw.Close()
return nil, err
@@ -136,16 +184,22 @@ func DialStreamOne(cfg *Config, transport http.RoundTripper) (net.Conn, error) {
return conn, nil
}
func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransport http.RoundTripper) (net.Conn, error) {
downloadCfg := cfg
if ds := cfg.DownloadConfig; ds != nil {
func (c *Client) DialStreamUp() (net.Conn, error) {
uploadTransport := c.makeTransport()
downloadTransport := uploadTransport
if c.makeDownloadTransport != nil {
downloadTransport = c.makeDownloadTransport()
}
downloadCfg := c.cfg
if ds := c.cfg.DownloadConfig; ds != nil {
downloadCfg = ds
}
streamURL := url.URL{
Scheme: "https",
Host: cfg.Host,
Path: cfg.NormalizedPath(),
Host: c.cfg.Host,
Path: c.cfg.NormalizedPath(),
}
downloadURL := url.URL{
@@ -155,13 +209,12 @@ func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
}
pr, pw := io.Pipe()
ctx := context.Background()
conn := &Conn{writer: pw}
sessionID := newSessionID()
downloadReq, err := http.NewRequestWithContext(
httputils.NewAddrContext(&conn.NetAddr, ctx),
httputils.NewAddrContext(&conn.NetAddr, c.ctx),
http.MethodGet,
downloadURL.String(),
nil,
@@ -201,7 +254,7 @@ func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
}
uploadReq, err := http.NewRequestWithContext(
ctx,
c.ctx,
http.MethodPost,
streamURL.String(),
pr,
@@ -217,7 +270,7 @@ func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
return nil, err
}
if err := cfg.FillStreamRequest(uploadReq, sessionID); err != nil {
if err := c.cfg.FillStreamRequest(uploadReq, sessionID); err != nil {
_ = downloadResp.Body.Close()
_ = pr.Close()
_ = pw.Close()
@@ -227,7 +280,7 @@ func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
}
return nil, err
}
uploadReq.Host = cfg.Host
uploadReq.Host = c.cfg.Host
go func() {
resp, err := uploadTransport.RoundTrip(uploadReq)
@@ -255,9 +308,15 @@ func DialStreamUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
return conn, nil
}
func DialPacketUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransport http.RoundTripper) (net.Conn, error) {
downloadCfg := cfg
if ds := cfg.DownloadConfig; ds != nil {
func (c *Client) DialPacketUp() (net.Conn, error) {
uploadTransport := c.makeTransport()
downloadTransport := uploadTransport
if c.makeDownloadTransport != nil {
downloadTransport = c.makeDownloadTransport()
}
downloadCfg := c.cfg
if ds := c.cfg.DownloadConfig; ds != nil {
downloadCfg = ds
}
sessionID := newSessionID()
@@ -268,17 +327,16 @@ func DialPacketUp(cfg *Config, uploadTransport http.RoundTripper, downloadTransp
Path: downloadCfg.NormalizedPath(),
}
ctx := context.Background()
writer := &PacketUpWriter{
ctx: ctx,
cfg: cfg,
ctx: c.ctx,
cfg: c.cfg,
sessionID: sessionID,
transport: uploadTransport,
seq: 0,
}
conn := &Conn{writer: writer}
downloadReq, err := http.NewRequestWithContext(httputils.NewAddrContext(&conn.NetAddr, ctx), http.MethodGet, downloadURL.String(), nil)
downloadReq, err := http.NewRequestWithContext(httputils.NewAddrContext(&conn.NetAddr, c.ctx), http.MethodGet, downloadURL.String(), nil)
if err != nil {
return nil, err
}