Add pre-matching support for auto redirect

This commit is contained in:
世界
2025-12-26 05:56:29 +08:00
parent 6516c2d8f1
commit a850c4f8a1
9 changed files with 480 additions and 12 deletions
+2 -1
View File
@@ -3,8 +3,10 @@ module github.com/sagernet/sing-tun
go 1.24.7
require (
github.com/florianl/go-nfqueue/v2 v2.0.2
github.com/go-ole/go-ole v1.3.0
github.com/google/btree v1.1.3
github.com/mdlayher/netlink v1.7.2
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
@@ -22,7 +24,6 @@ require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
+2
View File
@@ -1,5 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/florianl/go-nfqueue/v2 v2.0.2 h1:FL5lQTeetgpCvac1TRwSfgaXUn0YSO7WzGvWNIp3JPE=
github.com/florianl/go-nfqueue/v2 v2.0.2/go.mod h1:VA09+iPOT43OMoCKNfXHyzujQUty2xmzyCRkBOlmabc=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
+244
View File
@@ -0,0 +1,244 @@
//go:build linux
package tun
import (
"context"
"errors"
"sync"
"sync/atomic"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/florianl/go-nfqueue/v2"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
const nfqueueMaxPacketLen = 512
type nfqueueHandler struct {
ctx context.Context
cancel context.CancelFunc
handler Handler
logger logger.Logger
nfq *nfqueue.Nfqueue
queue uint16
outputMark uint32
resetMark uint32
wg sync.WaitGroup
closed atomic.Bool
}
type nfqueueOptions struct {
Context context.Context
Handler Handler
Logger logger.Logger
Queue uint16
OutputMark uint32
ResetMark uint32
}
func newNFQueueHandler(options nfqueueOptions) (*nfqueueHandler, error) {
ctx, cancel := context.WithCancel(options.Context)
return &nfqueueHandler{
ctx: ctx,
cancel: cancel,
handler: options.Handler,
logger: options.Logger,
queue: options.Queue,
outputMark: options.OutputMark,
resetMark: options.ResetMark,
}, nil
}
func (h *nfqueueHandler) setVerdict(packetID uint32, verdict int, mark uint32) {
var err error
if mark != 0 {
err = h.nfq.SetVerdictWithOption(packetID, verdict, nfqueue.WithMark(mark))
} else {
err = h.nfq.SetVerdict(packetID, verdict)
}
if err != nil && !h.closed.Load() && h.ctx.Err() == nil {
h.logger.Trace(E.Cause(err, "set verdict"))
}
}
func (h *nfqueueHandler) Start() error {
config := nfqueue.Config{
NfQueue: h.queue,
MaxPacketLen: nfqueueMaxPacketLen,
MaxQueueLen: 4096,
Copymode: nfqueue.NfQnlCopyPacket,
AfFamily: unix.AF_UNSPEC,
Flags: nfqueue.NfQaCfgFlagFailOpen,
}
nfq, err := nfqueue.Open(&config)
if err != nil {
return E.Cause(err, "open nfqueue")
}
h.nfq = nfq
if err = nfq.SetOption(netlink.NoENOBUFS, true); err != nil {
h.nfq.Close()
return E.Cause(err, "set nfqueue option")
}
h.wg.Add(1)
go func() {
defer h.wg.Done()
err := nfq.RegisterWithErrorFunc(h.ctx, h.handlePacket, func(e error) int {
if h.ctx.Err() != nil {
return 1
}
h.logger.Error("nfqueue error: ", e)
return 0
})
if err != nil && h.ctx.Err() == nil {
h.logger.Error("nfqueue register error: ", err)
}
}()
return nil
}
func parseIPv6TransportHeader(payload []byte) (transportProto uint8, transportOffset int, ok bool) {
if len(payload) < header.IPv6MinimumSize {
return 0, 0, false
}
ipv6 := header.IPv6(payload)
nextHeader := ipv6.NextHeader()
offset := header.IPv6MinimumSize
for {
switch nextHeader {
case unix.IPPROTO_HOPOPTS,
unix.IPPROTO_ROUTING,
unix.IPPROTO_DSTOPTS:
if len(payload) < offset+2 {
return 0, 0, false
}
nextHeader = payload[offset]
extLen := int(payload[offset+1]+1) * 8
if len(payload) < offset+extLen {
return 0, 0, false
}
offset += extLen
case unix.IPPROTO_FRAGMENT:
if len(payload) < offset+8 {
return 0, 0, false
}
nextHeader = payload[offset]
offset += 8
case unix.IPPROTO_AH:
if len(payload) < offset+2 {
return 0, 0, false
}
nextHeader = payload[offset]
extLen := int(payload[offset+1]+2) * 4
if len(payload) < offset+extLen {
return 0, 0, false
}
offset += extLen
case unix.IPPROTO_NONE:
return 0, 0, false
default:
return nextHeader, offset, true
}
}
}
func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int {
if h.closed.Load() {
return 0
}
if attr.PacketID == nil || attr.Payload == nil {
return 0
}
packetID := *attr.PacketID
payload := *attr.Payload
if len(payload) < header.IPv4MinimumSize {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
var srcAddr, dstAddr M.Socksaddr
var tcpOffset int
version := payload[0] >> 4
if version == 4 {
ipv4 := header.IPv4(payload)
if !ipv4.IsValid(len(payload)) || ipv4.Protocol() != uint8(unix.IPPROTO_TCP) {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
srcAddr = M.SocksaddrFrom(ipv4.SourceAddr(), 0)
dstAddr = M.SocksaddrFrom(ipv4.DestinationAddr(), 0)
tcpOffset = int(ipv4.HeaderLength())
} else if version == 6 {
transportProto, transportOffset, ok := parseIPv6TransportHeader(payload)
if !ok || transportProto != unix.IPPROTO_TCP {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
ipv6 := header.IPv6(payload)
srcAddr = M.SocksaddrFrom(ipv6.SourceAddr(), 0)
dstAddr = M.SocksaddrFrom(ipv6.DestinationAddr(), 0)
tcpOffset = transportOffset
} else {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
if len(payload) < tcpOffset+header.TCPMinimumSize {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
tcp := header.TCP(payload[tcpOffset:])
srcAddr = M.SocksaddrFrom(srcAddr.Addr, tcp.SourcePort())
dstAddr = M.SocksaddrFrom(dstAddr.Addr, tcp.DestinationPort())
flags := tcp.Flags()
if !flags.Contains(header.TCPFlagSyn) || flags.Contains(header.TCPFlagAck) {
h.setVerdict(packetID, nfqueue.NfAccept, 0)
return 0
}
_, pErr := h.handler.PrepareConnection(N.NetworkTCP, srcAddr, dstAddr, nil, 0)
switch {
case errors.Is(pErr, ErrBypass):
h.setVerdict(packetID, nfqueue.NfAccept, h.outputMark)
case errors.Is(pErr, ErrReset):
h.setVerdict(packetID, nfqueue.NfAccept, h.resetMark)
case errors.Is(pErr, ErrDrop):
h.setVerdict(packetID, nfqueue.NfDrop, 0)
default:
h.setVerdict(packetID, nfqueue.NfAccept, 0)
}
return 0
}
func (h *nfqueueHandler) Close() error {
h.closed.Store(true)
h.cancel()
if h.nfq != nil {
h.nfq.Close()
}
h.wg.Wait()
return nil
}
+3 -2
View File
@@ -5,7 +5,6 @@ import (
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"go4.org/netipx"
)
@@ -13,6 +12,8 @@ import (
const (
DefaultAutoRedirectInputMark = 0x2023
DefaultAutoRedirectOutputMark = 0x2024
DefaultAutoRedirectResetMark = 0x2025
DefaultAutoRedirectNFQueue = 100
)
type AutoRedirect interface {
@@ -24,7 +25,7 @@ type AutoRedirect interface {
type AutoRedirectOptions struct {
TunOptions *Options
Context context.Context
Handler N.TCPConnectionHandlerEx
Handler Handler
Logger logger.Logger
NetworkMonitor NetworkUpdateMonitor
InterfaceFinder control.InterfaceFinder
+49 -3
View File
@@ -13,7 +13,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"go4.org/netipx"
@@ -22,7 +21,7 @@ import (
type autoRedirect struct {
tunOptions *Options
ctx context.Context
handler N.TCPConnectionHandlerEx
handler Handler
logger logger.Logger
tableName string
networkMonitor NetworkUpdateMonitor
@@ -41,6 +40,8 @@ type autoRedirect struct {
suPath string
routeAddressSet *[]*netipx.IPSet
routeExcludeAddressSet *[]*netipx.IPSet
nfqueueHandler *nfqueueHandler
nfqueueEnabled bool
}
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
@@ -125,13 +126,30 @@ func (r *autoRedirect) Start() error {
listenAddr = netip.IPv4Unspecified()
}
server := newRedirectServer(r.ctx, r.handler, r.logger, listenAddr)
err := server.Start()
err = server.Start()
if err != nil {
return E.Cause(err, "start redirect server")
}
r.redirectServer = server
}
if r.useNFTables {
var handler *nfqueueHandler
handler, err = newNFQueueHandler(nfqueueOptions{
Context: r.ctx,
Handler: r.handler,
Logger: r.logger,
Queue: r.effectiveNFQueue(),
OutputMark: r.effectiveOutputMark(),
ResetMark: r.effectiveResetMark(),
})
if err != nil {
r.logger.Warn("nfqueue not available, pre-match disabled: ", err)
} else if err = handler.Start(); err != nil {
r.logger.Warn("nfqueue start failed, pre-match disabled: ", err)
} else {
r.nfqueueHandler = handler
r.nfqueueEnabled = true
}
r.cleanupNFTables()
err = r.setupNFTables()
} else {
@@ -142,6 +160,9 @@ func (r *autoRedirect) Start() error {
}
func (r *autoRedirect) Close() error {
if r.nfqueueHandler != nil {
r.nfqueueHandler.Close()
}
if r.useNFTables {
r.cleanupNFTables()
} else {
@@ -181,3 +202,28 @@ func (r *autoRedirect) redirectPort() uint16 {
}
return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port()
}
func (r *autoRedirect) effectiveOutputMark() uint32 {
if r.tunOptions.AutoRedirectOutputMark != 0 {
return r.tunOptions.AutoRedirectOutputMark
}
return DefaultAutoRedirectOutputMark
}
func (r *autoRedirect) effectiveResetMark() uint32 {
if r.tunOptions.AutoRedirectResetMark != 0 {
return r.tunOptions.AutoRedirectResetMark
}
return DefaultAutoRedirectResetMark
}
func (r *autoRedirect) effectiveNFQueue() uint16 {
if r.tunOptions.AutoRedirectNFQueue != 0 {
return r.tunOptions.AutoRedirectNFQueue
}
return DefaultAutoRedirectNFQueue
}
func (r *autoRedirect) shouldSkipOutputChain() bool {
return len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
}
+133 -4
View File
@@ -51,13 +51,23 @@ func (r *autoRedirect) setupNFTables() error {
return err
}
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
if !skipOutput {
if r.nfqueueEnabled {
err = r.nftablesCreatePreMatchChains(nft, table)
if err != nil {
return err
}
}
if !r.shouldSkipOutputChain() {
outputNATPriority := nftables.ChainPriorityMangle
if r.nfqueueEnabled {
outputNATPriority = nftables.ChainPriorityRef(*nftables.ChainPriorityMangle + 1)
}
chainOutput := nft.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Priority: outputNATPriority,
Type: nftables.ChainTypeNAT,
})
if r.tunOptions.AutoRedirectMarkMode {
@@ -267,7 +277,7 @@ func (r *autoRedirect) setupNFTables() error {
return nil
}
// TODO; test is this works
// TODO: test if this works
func (r *autoRedirect) nftablesUpdateLocalAddressSet() error {
newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
@@ -327,3 +337,122 @@ func (r *autoRedirect) cleanupNFTables() {
_ = nft.Flush()
_ = nft.CloseLasting()
}
func (r *autoRedirect) nftablesCreatePreMatchChains(nft *nftables.Conn, table *nftables.Table) error {
chainPreroutingPreMatch := nft.AddChain(&nftables.Chain{
Name: "prerouting_prematch",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest - 1),
Type: nftables.ChainTypeFilter,
})
r.nftablesAddPreMatchRules(nft, table, chainPreroutingPreMatch, true)
if !r.shouldSkipOutputChain() {
chainOutputPreMatch := nft.AddChain(&nftables.Chain{
Name: "output_prematch",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityMangle - 1),
Type: nftables.ChainTypeFilter,
})
r.nftablesAddPreMatchRules(nft, table, chainOutputPreMatch, false)
}
return nil
}
func (r *autoRedirect) nftablesAddPreMatchRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, isPrerouting bool) {
ifnameKey := expr.MetaKeyOIFNAME
if isPrerouting {
ifnameKey = expr.MetaKeyIIFNAME
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: ifnameKey, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: nftablesIfname(r.tunOptions.Name)},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{Key: expr.CtKeyMARK, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 1,
Mask: []byte{0x12},
Xor: []byte{0x00},
},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{0x02}},
&expr.Counter{},
&expr.Queue{
Num: r.effectiveNFQueue(),
Flag: expr.QueueFlagBypass,
},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveResetMark())},
&expr.Counter{},
&expr.Reject{Type: unix.NFT_REJECT_TCP_RST},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())},
&expr.Ct{Key: expr.CtKeyMARK, Register: 1, SourceRegister: true},
&expr.Counter{},
},
})
}
+42
View File
@@ -213,6 +213,48 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
})
}
}
if r.nfqueueEnabled && chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
if r.nfqueueEnabled && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
if chain.Hooknum == nftables.ChainHookPrerouting {
nft.AddRule(&nftables.Rule{
Table: table,
+3 -2
View File
@@ -13,8 +13,9 @@ import (
)
var (
ErrDrop = E.New("drop by rule")
ErrReset = E.New("reset by rule")
ErrDrop = E.New("drop by rule")
ErrReset = E.New("reset by rule")
ErrBypass = E.New("bypass by rule")
)
type Stack interface {
+2
View File
@@ -83,6 +83,8 @@ type Options struct {
AutoRedirectMarkMode bool
AutoRedirectInputMark uint32
AutoRedirectOutputMark uint32
AutoRedirectResetMark uint32
AutoRedirectNFQueue uint16
ExcludeMPTCP bool
Inet4LoopbackAddress []netip.Addr
Inet6LoopbackAddress []netip.Addr