Refactor(dialer): with socket options (#525)

This commit is contained in:
Jason Lyu
2026-02-07 21:34:58 -05:00
committed by GitHub
parent 82cca651a0
commit 9c2ce39979
10 changed files with 174 additions and 230 deletions
+67 -51
View File
@@ -3,81 +3,97 @@ package dialer
import (
"context"
"net"
"sync"
"syscall"
"go.uber.org/atomic"
)
// DefaultDialer is the default Dialer and is used by DialContext and ListenPacket.
var DefaultDialer = &Dialer{
InterfaceName: atomic.NewString(""),
InterfaceIndex: atomic.NewInt32(0),
RoutingMark: atomic.NewInt32(0),
// DefaultDialer is the package-level default Dialer.
// It is used by DialContext and ListenPacket.
var DefaultDialer = &Dialer{}
// RegisterSockOpt registers a socket option on the DefaultDialer.
func RegisterSockOpt(opt SocketOption) {
DefaultDialer.RegisterSockOpt(opt)
}
type Dialer struct {
InterfaceName *atomic.String
InterfaceIndex *atomic.Int32
RoutingMark *atomic.Int32
// Reset removes all registered socket options from the DefaultDialer.
func Reset() {
DefaultDialer.Reset()
}
type Options struct {
// InterfaceName is the name of interface/device to bind.
// If a socket is bound to an interface, only packets received
// from that particular interface are processed by the socket.
InterfaceName string
// InterfaceIndex is the index of interface/device to bind.
// It is almost the same as InterfaceName except it uses the
// index of the interface instead of the name.
InterfaceIndex int
// RoutingMark is the mark for each packet sent through this
// socket. Changing the mark can be used for mark-based routing
// without netfilter or for packet filtering.
RoutingMark int
}
// DialContext is a wrapper around DefaultDialer.DialContext.
// DialContext dials using the DefaultDialer.
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DefaultDialer.DialContext(ctx, network, address)
}
// ListenPacket is a wrapper around DefaultDialer.ListenPacket.
// ListenPacket listens using the DefaultDialer.
func ListenPacket(network, address string) (net.PacketConn, error) {
return DefaultDialer.ListenPacket(network, address)
}
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContextWithOptions(ctx, network, address, &Options{
InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(d.RoutingMark.Load()),
})
// Dialer applies registered SocketOptions to all dials/listens.
type Dialer struct {
optsMu sync.Mutex
atomicOpts atomic.Value
}
func (*Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
d := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
// New creates a new Dialer with the given initial socket options.
func New(opts ...SocketOption) *Dialer {
d := &Dialer{}
for _, opt := range opts {
d.RegisterSockOpt(opt)
}
return d.DialContext(ctx, network, address)
return d
}
// RegisterSockOpt registers a socket option on the Dialer.
func (d *Dialer) RegisterSockOpt(opt SocketOption) {
d.optsMu.Lock()
opts, _ := d.atomicOpts.Load().([]SocketOption)
d.atomicOpts.Store(append(opts, opt))
d.optsMu.Unlock()
}
// Reset removes all registered socket options from the Dialer.
func (d *Dialer) Reset() {
d.optsMu.Lock()
d.atomicOpts.Store([]SocketOption(nil))
d.optsMu.Unlock()
}
func (d *Dialer) applySockOpts(network string, address string, c syscall.RawConn) error {
opts, _ := d.atomicOpts.Load().([]SocketOption)
if len(opts) == 0 {
return nil
}
// Skip non-global-unicast IPs (e.g. loopback, link-local).
if host, _, err := net.SplitHostPort(address); err == nil {
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return nil
}
}
for _, opt := range opts {
if err := opt.Apply(network, address, c); err != nil {
return err
}
}
return nil
}
// DialContext behaves like net.Dialer.DialContext, applying registered SocketOptions.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
nd := &net.Dialer{
Control: d.applySockOpts,
}
return nd.DialContext(ctx, network, address)
}
// ListenPacket behaves like net.ListenConfig.ListenPacket, applying registered SocketOptions.
func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) {
return d.ListenPacketWithOptions(network, address, &Options{
InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(d.RoutingMark.Load()),
})
}
func (*Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
lc := &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
Control: d.applySockOpts,
}
return lc.ListenPacket(context.Background(), network, address)
}
+34 -23
View File
@@ -1,30 +1,41 @@
package dialer
func isTCPSocket(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6":
return true
default:
return false
}
import (
"errors"
"syscall"
)
var _ SocketOption = SocketOptionFunc(nil)
// SocketOption applies a socket-level configuration to a network connection
// during dialing or listening, via syscall.RawConn.
type SocketOption interface {
Apply(network, address string, c syscall.RawConn) error
}
func isUDPSocket(network string) bool {
switch network {
case "udp", "udp4", "udp6":
return true
default:
return false
}
// SocketOptionFunc adapts a function to a SocketOption.
type SocketOptionFunc func(network, address string, c syscall.RawConn) error
func (f SocketOptionFunc) Apply(network, address string, c syscall.RawConn) error {
return f(network, address, c)
}
func isICMPSocket(network string) bool {
switch network {
case "ip:icmp", "ip4:icmp", "ip6:ipv6-icmp":
return true
case "ip4", "ip6":
return true
default:
return false
}
// UnsupportedSocketOption is a sentinel SocketOption that always reports
// ErrUnsupported when applied.
var UnsupportedSocketOption = SocketOptionFunc(unsupportedSocketOpt)
func unsupportedSocketOpt(_, _ string, _ syscall.RawConn) error {
return errors.ErrUnsupported
}
// rawConnControl runs f with the file descriptor obtained via RawConn.Control
// and correctly propagates errors returned from f.
func rawConnControl(c syscall.RawConn, f func(uintptr) error) error {
var innerErr error
if err := c.Control(func(fd uintptr) {
innerErr = f(fd)
}); err != nil {
return err
}
return innerErr
}
+12 -32
View File
@@ -7,39 +7,19 @@ import (
"golang.org/x/sys/unix"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
opts.InterfaceIndex = iface.Index
}
}
if opts.InterfaceIndex != 0 {
func WithBindToInterface(iface *net.Interface) SocketOption {
index := iface.Index
return SocketOptionFunc(func(network, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
switch network {
case "tcp4", "udp4":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, opts.InterfaceIndex)
case "tcp6", "udp6":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, opts.InterfaceIndex)
case "ip4", "tcp4", "udp4":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
case "ip6", "tcp6", "udp6":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index)
}
if innerErr != nil {
return
}
}
return nil
})
})
if innerErr != nil {
err = innerErr
}
return err
}
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
+6 -21
View File
@@ -7,27 +7,12 @@ import (
"golang.org/x/sys/unix"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, opts.RoutingMark); innerErr != nil {
return
}
}
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, mark)
})
})
if innerErr != nil {
err = innerErr
}
return err
}
+14 -33
View File
@@ -7,38 +7,19 @@ import (
"golang.org/x/sys/unix"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceName == "" && opts.InterfaceIndex != 0 {
if iface, err := net.InterfaceByIndex(opts.InterfaceIndex); err == nil {
opts.InterfaceName = iface.Name
}
}
if opts.InterfaceName != "" {
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
return
}
}
if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil {
return
}
}
func WithBindToInterface(iface *net.Interface) SocketOption {
device := iface.Name
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.BindToDevice(int(fd), device)
})
})
}
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
})
})
if innerErr != nil {
err = innerErr
}
return err
}
+6 -21
View File
@@ -7,27 +7,12 @@ import (
"golang.org/x/sys/unix"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, opts.RoutingMark); innerErr != nil {
return
}
}
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, mark)
})
})
if innerErr != nil {
err = innerErr
}
return err
}
+6 -4
View File
@@ -2,8 +2,10 @@
package dialer
import "syscall"
import (
"net"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error {
return nil
}
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
+23 -41
View File
@@ -1,68 +1,50 @@
package dialer
import (
"encoding/binary"
"math/bits"
"net"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
const (
IP_UNICAST_IF = 31
IPV6_UNICAST_IF = 31
IP_UNICAST_IF = 0x1f
IPV6_UNICAST_IF = 0x1f
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
ip := net.ParseIP(host)
if ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
opts.InterfaceIndex = iface.Index
}
}
if opts.InterfaceIndex != 0 {
func WithBindToInterface(iface *net.Interface) SocketOption {
index := uint32(iface.Index)
return SocketOptionFunc(func(network, address string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) (err error) {
switch network {
case "tcp4", "udp4":
innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex))
case "tcp6", "udp6":
innerErr = bindSocketToInterface6(windows.Handle(fd), uint32(opts.InterfaceIndex))
if network == "udp6" && ip == nil {
// The underlying IP net maybe IPv4 even if the `network` param is `udp6`,
// so we should bind socket to interface4 at the same time.
innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex))
case "ip4", "tcp4", "udp4":
err = bindSocketToInterface4(windows.Handle(fd), index)
case "ip6", "tcp6", "udp6":
err = bindSocketToInterface6(windows.Handle(fd), index)
// UDPv6 may still use an IPv4 underlying socket if the destination
// address is unspecified (e.g. ":0").
if network == "udp6" {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() {
_ = bindSocketToInterface4(windows.Handle(fd), index)
}
}
}
}
return
})
})
if innerErr != nil {
err = innerErr
}
return err
}
func bindSocketToInterface4(handle windows.Handle, index uint32) error {
// For IPv4, this parameter must be an interface index in network byte order.
// Ref: https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], index)
index = *(*uint32)(unsafe.Pointer(&bytes[0]))
index = bits.ReverseBytes32(index)
return windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(index))
}
func bindSocketToInterface6(handle windows.Handle, index uint32) error {
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(index))
}
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
+1 -1
View File
@@ -10,5 +10,5 @@ func init() {
// We must use this DialContext to query DNS
// when using net default resolver.
net.DefaultResolver.PreferGo = true
net.DefaultResolver.Dial = dialer.DefaultDialer.DialContext
net.DefaultResolver.Dial = dialer.DialContext
}
+5 -3
View File
@@ -122,18 +122,20 @@ func general(k *Key) error {
}
log.SetLogger(log.Must(log.NewLeveled(level)))
// Reset default dialer before registering options.
dialer.Reset()
if k.Interface != "" {
iface, err := net.InterfaceByName(k.Interface)
if err != nil {
return err
}
dialer.DefaultDialer.InterfaceName.Store(iface.Name)
dialer.DefaultDialer.InterfaceIndex.Store(int32(iface.Index))
dialer.RegisterSockOpt(dialer.WithBindToInterface(iface))
log.Infof("[DIALER] bind to interface: %s", k.Interface)
}
if k.Mark != 0 {
dialer.DefaultDialer.RoutingMark.Store(int32(k.Mark))
dialer.RegisterSockOpt(dialer.WithRoutingMark(k.Mark))
log.Infof("[DIALER] set fwmark: %#x", k.Mark)
}