From 9c2ce399799536655a1a1cb59561e270be767b87 Mon Sep 17 00:00:00 2001 From: Jason Lyu Date: Sat, 7 Feb 2026 21:34:58 -0500 Subject: [PATCH] Refactor(dialer): with socket options (#525) --- dialer/dialer.go | 118 ++++++++++++++++++++++---------------- dialer/sockopt.go | 57 ++++++++++-------- dialer/sockopt_darwin.go | 44 ++++---------- dialer/sockopt_freebsd.go | 27 ++------- dialer/sockopt_linux.go | 47 +++++---------- dialer/sockopt_openbsd.go | 27 ++------- dialer/sockopt_others.go | 10 ++-- dialer/sockopt_windows.go | 64 ++++++++------------- dns/resolver.go | 2 +- engine/engine.go | 8 ++- 10 files changed, 174 insertions(+), 230 deletions(-) diff --git a/dialer/dialer.go b/dialer/dialer.go index 4165369..f0c1170 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -3,81 +3,97 @@ package dialer import ( "context" "net" + "sync" "syscall" "go.uber.org/atomic" ) -// DefaultDialer is the default Dialer and is used by DialContext and ListenPacket. -var DefaultDialer = &Dialer{ - InterfaceName: atomic.NewString(""), - InterfaceIndex: atomic.NewInt32(0), - RoutingMark: atomic.NewInt32(0), +// DefaultDialer is the package-level default Dialer. +// It is used by DialContext and ListenPacket. +var DefaultDialer = &Dialer{} + +// RegisterSockOpt registers a socket option on the DefaultDialer. +func RegisterSockOpt(opt SocketOption) { + DefaultDialer.RegisterSockOpt(opt) } -type Dialer struct { - InterfaceName *atomic.String - InterfaceIndex *atomic.Int32 - RoutingMark *atomic.Int32 +// Reset removes all registered socket options from the DefaultDialer. +func Reset() { + DefaultDialer.Reset() } -type Options struct { - // InterfaceName is the name of interface/device to bind. - // If a socket is bound to an interface, only packets received - // from that particular interface are processed by the socket. - InterfaceName string - - // InterfaceIndex is the index of interface/device to bind. - // It is almost the same as InterfaceName except it uses the - // index of the interface instead of the name. - InterfaceIndex int - - // RoutingMark is the mark for each packet sent through this - // socket. Changing the mark can be used for mark-based routing - // without netfilter or for packet filtering. - RoutingMark int -} - -// DialContext is a wrapper around DefaultDialer.DialContext. +// DialContext dials using the DefaultDialer. func DialContext(ctx context.Context, network, address string) (net.Conn, error) { return DefaultDialer.DialContext(ctx, network, address) } -// ListenPacket is a wrapper around DefaultDialer.ListenPacket. +// ListenPacket listens using the DefaultDialer. func ListenPacket(network, address string) (net.PacketConn, error) { return DefaultDialer.ListenPacket(network, address) } -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return d.DialContextWithOptions(ctx, network, address, &Options{ - InterfaceName: d.InterfaceName.Load(), - InterfaceIndex: int(d.InterfaceIndex.Load()), - RoutingMark: int(d.RoutingMark.Load()), - }) +// Dialer applies registered SocketOptions to all dials/listens. +type Dialer struct { + optsMu sync.Mutex + atomicOpts atomic.Value } -func (*Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) { - d := &net.Dialer{ - Control: func(network, address string, c syscall.RawConn) error { - return setSocketOptions(network, address, c, opts) - }, +// New creates a new Dialer with the given initial socket options. +func New(opts ...SocketOption) *Dialer { + d := &Dialer{} + for _, opt := range opts { + d.RegisterSockOpt(opt) } - return d.DialContext(ctx, network, address) + return d } +// RegisterSockOpt registers a socket option on the Dialer. +func (d *Dialer) RegisterSockOpt(opt SocketOption) { + d.optsMu.Lock() + opts, _ := d.atomicOpts.Load().([]SocketOption) + d.atomicOpts.Store(append(opts, opt)) + d.optsMu.Unlock() +} + +// Reset removes all registered socket options from the Dialer. +func (d *Dialer) Reset() { + d.optsMu.Lock() + d.atomicOpts.Store([]SocketOption(nil)) + d.optsMu.Unlock() +} + +func (d *Dialer) applySockOpts(network string, address string, c syscall.RawConn) error { + opts, _ := d.atomicOpts.Load().([]SocketOption) + if len(opts) == 0 { + return nil + } + // Skip non-global-unicast IPs (e.g. loopback, link-local). + if host, _, err := net.SplitHostPort(address); err == nil { + if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { + return nil + } + } + for _, opt := range opts { + if err := opt.Apply(network, address, c); err != nil { + return err + } + } + return nil +} + +// DialContext behaves like net.Dialer.DialContext, applying registered SocketOptions. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + nd := &net.Dialer{ + Control: d.applySockOpts, + } + return nd.DialContext(ctx, network, address) +} + +// ListenPacket behaves like net.ListenConfig.ListenPacket, applying registered SocketOptions. func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) { - return d.ListenPacketWithOptions(network, address, &Options{ - InterfaceName: d.InterfaceName.Load(), - InterfaceIndex: int(d.InterfaceIndex.Load()), - RoutingMark: int(d.RoutingMark.Load()), - }) -} - -func (*Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) { lc := &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - return setSocketOptions(network, address, c, opts) - }, + Control: d.applySockOpts, } return lc.ListenPacket(context.Background(), network, address) } diff --git a/dialer/sockopt.go b/dialer/sockopt.go index ac28bde..d67e616 100644 --- a/dialer/sockopt.go +++ b/dialer/sockopt.go @@ -1,30 +1,41 @@ package dialer -func isTCPSocket(network string) bool { - switch network { - case "tcp", "tcp4", "tcp6": - return true - default: - return false - } +import ( + "errors" + "syscall" +) + +var _ SocketOption = SocketOptionFunc(nil) + +// SocketOption applies a socket-level configuration to a network connection +// during dialing or listening, via syscall.RawConn. +type SocketOption interface { + Apply(network, address string, c syscall.RawConn) error } -func isUDPSocket(network string) bool { - switch network { - case "udp", "udp4", "udp6": - return true - default: - return false - } +// SocketOptionFunc adapts a function to a SocketOption. +type SocketOptionFunc func(network, address string, c syscall.RawConn) error + +func (f SocketOptionFunc) Apply(network, address string, c syscall.RawConn) error { + return f(network, address, c) } -func isICMPSocket(network string) bool { - switch network { - case "ip:icmp", "ip4:icmp", "ip6:ipv6-icmp": - return true - case "ip4", "ip6": - return true - default: - return false - } +// UnsupportedSocketOption is a sentinel SocketOption that always reports +// ErrUnsupported when applied. +var UnsupportedSocketOption = SocketOptionFunc(unsupportedSocketOpt) + +func unsupportedSocketOpt(_, _ string, _ syscall.RawConn) error { + return errors.ErrUnsupported +} + +// rawConnControl runs f with the file descriptor obtained via RawConn.Control +// and correctly propagates errors returned from f. +func rawConnControl(c syscall.RawConn, f func(uintptr) error) error { + var innerErr error + if err := c.Control(func(fd uintptr) { + innerErr = f(fd) + }); err != nil { + return err + } + return innerErr } diff --git a/dialer/sockopt_darwin.go b/dialer/sockopt_darwin.go index 640227c..799c743 100644 --- a/dialer/sockopt_darwin.go +++ b/dialer/sockopt_darwin.go @@ -7,39 +7,19 @@ import ( "golang.org/x/sys/unix" ) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { - if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) { - return err - } - - var innerErr error - err = c.Control(func(fd uintptr) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return - } - - if opts.InterfaceIndex == 0 && opts.InterfaceName != "" { - if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil { - opts.InterfaceIndex = iface.Index - } - } - - if opts.InterfaceIndex != 0 { +func WithBindToInterface(iface *net.Interface) SocketOption { + index := iface.Index + return SocketOptionFunc(func(network, _ string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) error { switch network { - case "tcp4", "udp4": - innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, opts.InterfaceIndex) - case "tcp6", "udp6": - innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, opts.InterfaceIndex) + case "ip4", "tcp4", "udp4": + return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) + case "ip6", "tcp6", "udp6": + return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index) } - if innerErr != nil { - return - } - } + return nil + }) }) - - if innerErr != nil { - err = innerErr - } - return err } + +func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption } diff --git a/dialer/sockopt_freebsd.go b/dialer/sockopt_freebsd.go index 00339be..77643fd 100644 --- a/dialer/sockopt_freebsd.go +++ b/dialer/sockopt_freebsd.go @@ -7,27 +7,12 @@ import ( "golang.org/x/sys/unix" ) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { - if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) { - return err - } +func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption } - var innerErr error - err = c.Control(func(fd uintptr) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return - } - - if opts.RoutingMark != 0 { - if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, opts.RoutingMark); innerErr != nil { - return - } - } +func WithRoutingMark(mark int) SocketOption { + return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) error { + return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, mark) + }) }) - - if innerErr != nil { - err = innerErr - } - return err } diff --git a/dialer/sockopt_linux.go b/dialer/sockopt_linux.go index 9af1479..0e08652 100644 --- a/dialer/sockopt_linux.go +++ b/dialer/sockopt_linux.go @@ -7,38 +7,19 @@ import ( "golang.org/x/sys/unix" ) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { - if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) { - return err - } - - var innerErr error - err = c.Control(func(fd uintptr) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return - } - - if opts.InterfaceName == "" && opts.InterfaceIndex != 0 { - if iface, err := net.InterfaceByIndex(opts.InterfaceIndex); err == nil { - opts.InterfaceName = iface.Name - } - } - - if opts.InterfaceName != "" { - if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil { - return - } - } - if opts.RoutingMark != 0 { - if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil { - return - } - } +func WithBindToInterface(iface *net.Interface) SocketOption { + device := iface.Name + return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) error { + return unix.BindToDevice(int(fd), device) + }) + }) +} + +func WithRoutingMark(mark int) SocketOption { + return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) error { + return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark) + }) }) - - if innerErr != nil { - err = innerErr - } - return err } diff --git a/dialer/sockopt_openbsd.go b/dialer/sockopt_openbsd.go index 18ea7e0..1fadaa5 100644 --- a/dialer/sockopt_openbsd.go +++ b/dialer/sockopt_openbsd.go @@ -7,27 +7,12 @@ import ( "golang.org/x/sys/unix" ) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { - if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) { - return err - } +func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption } - var innerErr error - err = c.Control(func(fd uintptr) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return - } - - if opts.RoutingMark != 0 { - if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, opts.RoutingMark); innerErr != nil { - return - } - } +func WithRoutingMark(mark int) SocketOption { + return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) error { + return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, mark) + }) }) - - if innerErr != nil { - err = innerErr - } - return err } diff --git a/dialer/sockopt_others.go b/dialer/sockopt_others.go index 1695529..067dd0f 100644 --- a/dialer/sockopt_others.go +++ b/dialer/sockopt_others.go @@ -2,8 +2,10 @@ package dialer -import "syscall" +import ( + "net" +) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error { - return nil -} +func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption } + +func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption } diff --git a/dialer/sockopt_windows.go b/dialer/sockopt_windows.go index 2836b66..fb00f5d 100644 --- a/dialer/sockopt_windows.go +++ b/dialer/sockopt_windows.go @@ -1,68 +1,50 @@ package dialer import ( - "encoding/binary" + "math/bits" "net" "syscall" - "unsafe" "golang.org/x/sys/windows" ) const ( - IP_UNICAST_IF = 31 - IPV6_UNICAST_IF = 31 + IP_UNICAST_IF = 0x1f + IPV6_UNICAST_IF = 0x1f ) -func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { - if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) { - return err - } - - var innerErr error - err = c.Control(func(fd uintptr) { - host, _, _ := net.SplitHostPort(address) - ip := net.ParseIP(host) - if ip != nil && !ip.IsGlobalUnicast() { - return - } - - if opts.InterfaceIndex == 0 && opts.InterfaceName != "" { - if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil { - opts.InterfaceIndex = iface.Index - } - } - - if opts.InterfaceIndex != 0 { +func WithBindToInterface(iface *net.Interface) SocketOption { + index := uint32(iface.Index) + return SocketOptionFunc(func(network, address string, c syscall.RawConn) error { + return rawConnControl(c, func(fd uintptr) (err error) { switch network { - case "tcp4", "udp4": - innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex)) - case "tcp6", "udp6": - innerErr = bindSocketToInterface6(windows.Handle(fd), uint32(opts.InterfaceIndex)) - if network == "udp6" && ip == nil { - // The underlying IP net maybe IPv4 even if the `network` param is `udp6`, - // so we should bind socket to interface4 at the same time. - innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex)) + case "ip4", "tcp4", "udp4": + err = bindSocketToInterface4(windows.Handle(fd), index) + case "ip6", "tcp6", "udp6": + err = bindSocketToInterface6(windows.Handle(fd), index) + // UDPv6 may still use an IPv4 underlying socket if the destination + // address is unspecified (e.g. ":0"). + if network == "udp6" { + host, _, _ := net.SplitHostPort(address) + if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() { + _ = bindSocketToInterface4(windows.Handle(fd), index) + } } } - } + return + }) }) - - if innerErr != nil { - err = innerErr - } - return err } func bindSocketToInterface4(handle windows.Handle, index uint32) error { // For IPv4, this parameter must be an interface index in network byte order. // Ref: https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options - var bytes [4]byte - binary.BigEndian.PutUint32(bytes[:], index) - index = *(*uint32)(unsafe.Pointer(&bytes[0])) + index = bits.ReverseBytes32(index) return windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(index)) } func bindSocketToInterface6(handle windows.Handle, index uint32) error { return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(index)) } + +func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption } diff --git a/dns/resolver.go b/dns/resolver.go index 858223b..8748aa4 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -10,5 +10,5 @@ func init() { // We must use this DialContext to query DNS // when using net default resolver. net.DefaultResolver.PreferGo = true - net.DefaultResolver.Dial = dialer.DefaultDialer.DialContext + net.DefaultResolver.Dial = dialer.DialContext } diff --git a/engine/engine.go b/engine/engine.go index 57bb165..e9c6fc1 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -122,18 +122,20 @@ func general(k *Key) error { } log.SetLogger(log.Must(log.NewLeveled(level))) + // Reset default dialer before registering options. + dialer.Reset() + if k.Interface != "" { iface, err := net.InterfaceByName(k.Interface) if err != nil { return err } - dialer.DefaultDialer.InterfaceName.Store(iface.Name) - dialer.DefaultDialer.InterfaceIndex.Store(int32(iface.Index)) + dialer.RegisterSockOpt(dialer.WithBindToInterface(iface)) log.Infof("[DIALER] bind to interface: %s", k.Interface) } if k.Mark != 0 { - dialer.DefaultDialer.RoutingMark.Store(int32(k.Mark)) + dialer.RegisterSockOpt(dialer.WithRoutingMark(k.Mark)) log.Infof("[DIALER] set fwmark: %#x", k.Mark) }