mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2026-04-22 15:57:28 +08:00
Refactor(dialer): with socket options (#525)
This commit is contained in:
+67
-51
@@ -3,81 +3,97 @@ package dialer
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// DefaultDialer is the default Dialer and is used by DialContext and ListenPacket.
|
||||
var DefaultDialer = &Dialer{
|
||||
InterfaceName: atomic.NewString(""),
|
||||
InterfaceIndex: atomic.NewInt32(0),
|
||||
RoutingMark: atomic.NewInt32(0),
|
||||
// DefaultDialer is the package-level default Dialer.
|
||||
// It is used by DialContext and ListenPacket.
|
||||
var DefaultDialer = &Dialer{}
|
||||
|
||||
// RegisterSockOpt registers a socket option on the DefaultDialer.
|
||||
func RegisterSockOpt(opt SocketOption) {
|
||||
DefaultDialer.RegisterSockOpt(opt)
|
||||
}
|
||||
|
||||
type Dialer struct {
|
||||
InterfaceName *atomic.String
|
||||
InterfaceIndex *atomic.Int32
|
||||
RoutingMark *atomic.Int32
|
||||
// Reset removes all registered socket options from the DefaultDialer.
|
||||
func Reset() {
|
||||
DefaultDialer.Reset()
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
// InterfaceName is the name of interface/device to bind.
|
||||
// If a socket is bound to an interface, only packets received
|
||||
// from that particular interface are processed by the socket.
|
||||
InterfaceName string
|
||||
|
||||
// InterfaceIndex is the index of interface/device to bind.
|
||||
// It is almost the same as InterfaceName except it uses the
|
||||
// index of the interface instead of the name.
|
||||
InterfaceIndex int
|
||||
|
||||
// RoutingMark is the mark for each packet sent through this
|
||||
// socket. Changing the mark can be used for mark-based routing
|
||||
// without netfilter or for packet filtering.
|
||||
RoutingMark int
|
||||
}
|
||||
|
||||
// DialContext is a wrapper around DefaultDialer.DialContext.
|
||||
// DialContext dials using the DefaultDialer.
|
||||
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return DefaultDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenPacket is a wrapper around DefaultDialer.ListenPacket.
|
||||
// ListenPacket listens using the DefaultDialer.
|
||||
func ListenPacket(network, address string) (net.PacketConn, error) {
|
||||
return DefaultDialer.ListenPacket(network, address)
|
||||
}
|
||||
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.DialContextWithOptions(ctx, network, address, &Options{
|
||||
InterfaceName: d.InterfaceName.Load(),
|
||||
InterfaceIndex: int(d.InterfaceIndex.Load()),
|
||||
RoutingMark: int(d.RoutingMark.Load()),
|
||||
})
|
||||
// Dialer applies registered SocketOptions to all dials/listens.
|
||||
type Dialer struct {
|
||||
optsMu sync.Mutex
|
||||
atomicOpts atomic.Value
|
||||
}
|
||||
|
||||
func (*Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
|
||||
d := &net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return setSocketOptions(network, address, c, opts)
|
||||
},
|
||||
// New creates a new Dialer with the given initial socket options.
|
||||
func New(opts ...SocketOption) *Dialer {
|
||||
d := &Dialer{}
|
||||
for _, opt := range opts {
|
||||
d.RegisterSockOpt(opt)
|
||||
}
|
||||
return d.DialContext(ctx, network, address)
|
||||
return d
|
||||
}
|
||||
|
||||
// RegisterSockOpt registers a socket option on the Dialer.
|
||||
func (d *Dialer) RegisterSockOpt(opt SocketOption) {
|
||||
d.optsMu.Lock()
|
||||
opts, _ := d.atomicOpts.Load().([]SocketOption)
|
||||
d.atomicOpts.Store(append(opts, opt))
|
||||
d.optsMu.Unlock()
|
||||
}
|
||||
|
||||
// Reset removes all registered socket options from the Dialer.
|
||||
func (d *Dialer) Reset() {
|
||||
d.optsMu.Lock()
|
||||
d.atomicOpts.Store([]SocketOption(nil))
|
||||
d.optsMu.Unlock()
|
||||
}
|
||||
|
||||
func (d *Dialer) applySockOpts(network string, address string, c syscall.RawConn) error {
|
||||
opts, _ := d.atomicOpts.Load().([]SocketOption)
|
||||
if len(opts) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Skip non-global-unicast IPs (e.g. loopback, link-local).
|
||||
if host, _, err := net.SplitHostPort(address); err == nil {
|
||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.Apply(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DialContext behaves like net.Dialer.DialContext, applying registered SocketOptions.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
nd := &net.Dialer{
|
||||
Control: d.applySockOpts,
|
||||
}
|
||||
return nd.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenPacket behaves like net.ListenConfig.ListenPacket, applying registered SocketOptions.
|
||||
func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) {
|
||||
return d.ListenPacketWithOptions(network, address, &Options{
|
||||
InterfaceName: d.InterfaceName.Load(),
|
||||
InterfaceIndex: int(d.InterfaceIndex.Load()),
|
||||
RoutingMark: int(d.RoutingMark.Load()),
|
||||
})
|
||||
}
|
||||
|
||||
func (*Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
|
||||
lc := &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return setSocketOptions(network, address, c, opts)
|
||||
},
|
||||
Control: d.applySockOpts,
|
||||
}
|
||||
return lc.ListenPacket(context.Background(), network, address)
|
||||
}
|
||||
|
||||
+34
-23
@@ -1,30 +1,41 @@
|
||||
package dialer
|
||||
|
||||
func isTCPSocket(network string) bool {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var _ SocketOption = SocketOptionFunc(nil)
|
||||
|
||||
// SocketOption applies a socket-level configuration to a network connection
|
||||
// during dialing or listening, via syscall.RawConn.
|
||||
type SocketOption interface {
|
||||
Apply(network, address string, c syscall.RawConn) error
|
||||
}
|
||||
|
||||
func isUDPSocket(network string) bool {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// SocketOptionFunc adapts a function to a SocketOption.
|
||||
type SocketOptionFunc func(network, address string, c syscall.RawConn) error
|
||||
|
||||
func (f SocketOptionFunc) Apply(network, address string, c syscall.RawConn) error {
|
||||
return f(network, address, c)
|
||||
}
|
||||
|
||||
func isICMPSocket(network string) bool {
|
||||
switch network {
|
||||
case "ip:icmp", "ip4:icmp", "ip6:ipv6-icmp":
|
||||
return true
|
||||
case "ip4", "ip6":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// UnsupportedSocketOption is a sentinel SocketOption that always reports
|
||||
// ErrUnsupported when applied.
|
||||
var UnsupportedSocketOption = SocketOptionFunc(unsupportedSocketOpt)
|
||||
|
||||
func unsupportedSocketOpt(_, _ string, _ syscall.RawConn) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
// rawConnControl runs f with the file descriptor obtained via RawConn.Control
|
||||
// and correctly propagates errors returned from f.
|
||||
func rawConnControl(c syscall.RawConn, f func(uintptr) error) error {
|
||||
var innerErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
innerErr = f(fd)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return innerErr
|
||||
}
|
||||
|
||||
+12
-32
@@ -7,39 +7,19 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
|
||||
return err
|
||||
}
|
||||
|
||||
var innerErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
|
||||
if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
|
||||
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
|
||||
opts.InterfaceIndex = iface.Index
|
||||
}
|
||||
}
|
||||
|
||||
if opts.InterfaceIndex != 0 {
|
||||
func WithBindToInterface(iface *net.Interface) SocketOption {
|
||||
index := iface.Index
|
||||
return SocketOptionFunc(func(network, _ string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) error {
|
||||
switch network {
|
||||
case "tcp4", "udp4":
|
||||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, opts.InterfaceIndex)
|
||||
case "tcp6", "udp6":
|
||||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, opts.InterfaceIndex)
|
||||
case "ip4", "tcp4", "udp4":
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||
case "ip6", "tcp6", "udp6":
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index)
|
||||
}
|
||||
if innerErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
@@ -7,27 +7,12 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
|
||||
return err
|
||||
}
|
||||
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
var innerErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
|
||||
if opts.RoutingMark != 0 {
|
||||
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, opts.RoutingMark); innerErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
func WithRoutingMark(mark int) SocketOption {
|
||||
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) error {
|
||||
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, mark)
|
||||
})
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
+14
-33
@@ -7,38 +7,19 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
|
||||
return err
|
||||
}
|
||||
|
||||
var innerErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
|
||||
if opts.InterfaceName == "" && opts.InterfaceIndex != 0 {
|
||||
if iface, err := net.InterfaceByIndex(opts.InterfaceIndex); err == nil {
|
||||
opts.InterfaceName = iface.Name
|
||||
}
|
||||
}
|
||||
|
||||
if opts.InterfaceName != "" {
|
||||
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if opts.RoutingMark != 0 {
|
||||
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
func WithBindToInterface(iface *net.Interface) SocketOption {
|
||||
device := iface.Name
|
||||
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) error {
|
||||
return unix.BindToDevice(int(fd), device)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func WithRoutingMark(mark int) SocketOption {
|
||||
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) error {
|
||||
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
|
||||
})
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,27 +7,12 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
|
||||
return err
|
||||
}
|
||||
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
var innerErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
|
||||
if opts.RoutingMark != 0 {
|
||||
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, opts.RoutingMark); innerErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
func WithRoutingMark(mark int) SocketOption {
|
||||
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) error {
|
||||
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, mark)
|
||||
})
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
package dialer
|
||||
|
||||
import "syscall"
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error {
|
||||
return nil
|
||||
}
|
||||
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
+23
-41
@@ -1,68 +1,50 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
IP_UNICAST_IF = 31
|
||||
IPV6_UNICAST_IF = 31
|
||||
IP_UNICAST_IF = 0x1f
|
||||
IPV6_UNICAST_IF = 0x1f
|
||||
)
|
||||
|
||||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
|
||||
return err
|
||||
}
|
||||
|
||||
var innerErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && !ip.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
|
||||
if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
|
||||
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
|
||||
opts.InterfaceIndex = iface.Index
|
||||
}
|
||||
}
|
||||
|
||||
if opts.InterfaceIndex != 0 {
|
||||
func WithBindToInterface(iface *net.Interface) SocketOption {
|
||||
index := uint32(iface.Index)
|
||||
return SocketOptionFunc(func(network, address string, c syscall.RawConn) error {
|
||||
return rawConnControl(c, func(fd uintptr) (err error) {
|
||||
switch network {
|
||||
case "tcp4", "udp4":
|
||||
innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex))
|
||||
case "tcp6", "udp6":
|
||||
innerErr = bindSocketToInterface6(windows.Handle(fd), uint32(opts.InterfaceIndex))
|
||||
if network == "udp6" && ip == nil {
|
||||
// The underlying IP net maybe IPv4 even if the `network` param is `udp6`,
|
||||
// so we should bind socket to interface4 at the same time.
|
||||
innerErr = bindSocketToInterface4(windows.Handle(fd), uint32(opts.InterfaceIndex))
|
||||
case "ip4", "tcp4", "udp4":
|
||||
err = bindSocketToInterface4(windows.Handle(fd), index)
|
||||
case "ip6", "tcp6", "udp6":
|
||||
err = bindSocketToInterface6(windows.Handle(fd), index)
|
||||
// UDPv6 may still use an IPv4 underlying socket if the destination
|
||||
// address is unspecified (e.g. ":0").
|
||||
if network == "udp6" {
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() {
|
||||
_ = bindSocketToInterface4(windows.Handle(fd), index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func bindSocketToInterface4(handle windows.Handle, index uint32) error {
|
||||
// For IPv4, this parameter must be an interface index in network byte order.
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
|
||||
var bytes [4]byte
|
||||
binary.BigEndian.PutUint32(bytes[:], index)
|
||||
index = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||
index = bits.ReverseBytes32(index)
|
||||
return windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(index))
|
||||
}
|
||||
|
||||
func bindSocketToInterface6(handle windows.Handle, index uint32) error {
|
||||
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(index))
|
||||
}
|
||||
|
||||
func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
|
||||
|
||||
+1
-1
@@ -10,5 +10,5 @@ func init() {
|
||||
// We must use this DialContext to query DNS
|
||||
// when using net default resolver.
|
||||
net.DefaultResolver.PreferGo = true
|
||||
net.DefaultResolver.Dial = dialer.DefaultDialer.DialContext
|
||||
net.DefaultResolver.Dial = dialer.DialContext
|
||||
}
|
||||
|
||||
+5
-3
@@ -122,18 +122,20 @@ func general(k *Key) error {
|
||||
}
|
||||
log.SetLogger(log.Must(log.NewLeveled(level)))
|
||||
|
||||
// Reset default dialer before registering options.
|
||||
dialer.Reset()
|
||||
|
||||
if k.Interface != "" {
|
||||
iface, err := net.InterfaceByName(k.Interface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dialer.DefaultDialer.InterfaceName.Store(iface.Name)
|
||||
dialer.DefaultDialer.InterfaceIndex.Store(int32(iface.Index))
|
||||
dialer.RegisterSockOpt(dialer.WithBindToInterface(iface))
|
||||
log.Infof("[DIALER] bind to interface: %s", k.Interface)
|
||||
}
|
||||
|
||||
if k.Mark != 0 {
|
||||
dialer.DefaultDialer.RoutingMark.Store(int32(k.Mark))
|
||||
dialer.RegisterSockOpt(dialer.WithRoutingMark(k.Mark))
|
||||
log.Infof("[DIALER] set fwmark: %#x", k.Mark)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user