diff --git a/go.mod b/go.mod index ff61d7c..86d0da1 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/sagernet/sing-tun go 1.24.7 require ( + github.com/florianl/go-nfqueue/v2 v2.0.2 github.com/go-ole/go-ole v1.3.0 github.com/google/btree v1.1.3 + github.com/mdlayher/netlink v1.7.2 github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a @@ -22,7 +24,6 @@ require ( github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index e3c486b..d9664e2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/florianl/go-nfqueue/v2 v2.0.2 h1:FL5lQTeetgpCvac1TRwSfgaXUn0YSO7WzGvWNIp3JPE= +github.com/florianl/go-nfqueue/v2 v2.0.2/go.mod h1:VA09+iPOT43OMoCKNfXHyzujQUty2xmzyCRkBOlmabc= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= diff --git a/nfqueue_linux.go b/nfqueue_linux.go new file mode 100644 index 0000000..10f253f --- /dev/null +++ b/nfqueue_linux.go @@ -0,0 +1,244 @@ +//go:build linux + +package tun + +import ( + "context" + "errors" + "sync" + "sync/atomic" + + "github.com/sagernet/sing-tun/internal/gtcpip/header" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/florianl/go-nfqueue/v2" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const nfqueueMaxPacketLen = 512 + +type nfqueueHandler struct { + ctx context.Context + cancel context.CancelFunc + handler Handler + logger logger.Logger + nfq *nfqueue.Nfqueue + queue uint16 + outputMark uint32 + resetMark uint32 + wg sync.WaitGroup + closed atomic.Bool +} + +type nfqueueOptions struct { + Context context.Context + Handler Handler + Logger logger.Logger + Queue uint16 + OutputMark uint32 + ResetMark uint32 +} + +func newNFQueueHandler(options nfqueueOptions) (*nfqueueHandler, error) { + ctx, cancel := context.WithCancel(options.Context) + return &nfqueueHandler{ + ctx: ctx, + cancel: cancel, + handler: options.Handler, + logger: options.Logger, + queue: options.Queue, + outputMark: options.OutputMark, + resetMark: options.ResetMark, + }, nil +} + +func (h *nfqueueHandler) setVerdict(packetID uint32, verdict int, mark uint32) { + var err error + if mark != 0 { + err = h.nfq.SetVerdictWithOption(packetID, verdict, nfqueue.WithMark(mark)) + } else { + err = h.nfq.SetVerdict(packetID, verdict) + } + if err != nil && !h.closed.Load() && h.ctx.Err() == nil { + h.logger.Trace(E.Cause(err, "set verdict")) + } +} + +func (h *nfqueueHandler) Start() error { + config := nfqueue.Config{ + NfQueue: h.queue, + MaxPacketLen: nfqueueMaxPacketLen, + MaxQueueLen: 4096, + Copymode: nfqueue.NfQnlCopyPacket, + AfFamily: unix.AF_UNSPEC, + Flags: nfqueue.NfQaCfgFlagFailOpen, + } + + nfq, err := nfqueue.Open(&config) + if err != nil { + return E.Cause(err, "open nfqueue") + } + h.nfq = nfq + + if err = nfq.SetOption(netlink.NoENOBUFS, true); err != nil { + h.nfq.Close() + return E.Cause(err, "set nfqueue option") + } + + h.wg.Add(1) + go func() { + defer h.wg.Done() + err := nfq.RegisterWithErrorFunc(h.ctx, h.handlePacket, func(e error) int { + if h.ctx.Err() != nil { + return 1 + } + h.logger.Error("nfqueue error: ", e) + return 0 + }) + if err != nil && h.ctx.Err() == nil { + h.logger.Error("nfqueue register error: ", err) + } + }() + + return nil +} + +func parseIPv6TransportHeader(payload []byte) (transportProto uint8, transportOffset int, ok bool) { + if len(payload) < header.IPv6MinimumSize { + return 0, 0, false + } + + ipv6 := header.IPv6(payload) + nextHeader := ipv6.NextHeader() + offset := header.IPv6MinimumSize + + for { + switch nextHeader { + case unix.IPPROTO_HOPOPTS, + unix.IPPROTO_ROUTING, + unix.IPPROTO_DSTOPTS: + if len(payload) < offset+2 { + return 0, 0, false + } + nextHeader = payload[offset] + extLen := int(payload[offset+1]+1) * 8 + if len(payload) < offset+extLen { + return 0, 0, false + } + offset += extLen + + case unix.IPPROTO_FRAGMENT: + if len(payload) < offset+8 { + return 0, 0, false + } + nextHeader = payload[offset] + offset += 8 + + case unix.IPPROTO_AH: + if len(payload) < offset+2 { + return 0, 0, false + } + nextHeader = payload[offset] + extLen := int(payload[offset+1]+2) * 4 + if len(payload) < offset+extLen { + return 0, 0, false + } + offset += extLen + + case unix.IPPROTO_NONE: + return 0, 0, false + + default: + return nextHeader, offset, true + } + } +} + +func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { + if h.closed.Load() { + return 0 + } + if attr.PacketID == nil || attr.Payload == nil { + return 0 + } + + packetID := *attr.PacketID + payload := *attr.Payload + + if len(payload) < header.IPv4MinimumSize { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + + var srcAddr, dstAddr M.Socksaddr + var tcpOffset int + + version := payload[0] >> 4 + if version == 4 { + ipv4 := header.IPv4(payload) + if !ipv4.IsValid(len(payload)) || ipv4.Protocol() != uint8(unix.IPPROTO_TCP) { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + srcAddr = M.SocksaddrFrom(ipv4.SourceAddr(), 0) + dstAddr = M.SocksaddrFrom(ipv4.DestinationAddr(), 0) + tcpOffset = int(ipv4.HeaderLength()) + } else if version == 6 { + transportProto, transportOffset, ok := parseIPv6TransportHeader(payload) + if !ok || transportProto != unix.IPPROTO_TCP { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + ipv6 := header.IPv6(payload) + srcAddr = M.SocksaddrFrom(ipv6.SourceAddr(), 0) + dstAddr = M.SocksaddrFrom(ipv6.DestinationAddr(), 0) + tcpOffset = transportOffset + } else { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + + if len(payload) < tcpOffset+header.TCPMinimumSize { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + + tcp := header.TCP(payload[tcpOffset:]) + srcAddr = M.SocksaddrFrom(srcAddr.Addr, tcp.SourcePort()) + dstAddr = M.SocksaddrFrom(dstAddr.Addr, tcp.DestinationPort()) + + flags := tcp.Flags() + if !flags.Contains(header.TCPFlagSyn) || flags.Contains(header.TCPFlagAck) { + h.setVerdict(packetID, nfqueue.NfAccept, 0) + return 0 + } + + _, pErr := h.handler.PrepareConnection(N.NetworkTCP, srcAddr, dstAddr, nil, 0) + + switch { + case errors.Is(pErr, ErrBypass): + h.setVerdict(packetID, nfqueue.NfAccept, h.outputMark) + case errors.Is(pErr, ErrReset): + h.setVerdict(packetID, nfqueue.NfAccept, h.resetMark) + case errors.Is(pErr, ErrDrop): + h.setVerdict(packetID, nfqueue.NfDrop, 0) + default: + h.setVerdict(packetID, nfqueue.NfAccept, 0) + } + + return 0 +} + +func (h *nfqueueHandler) Close() error { + h.closed.Store(true) + h.cancel() + if h.nfq != nil { + h.nfq.Close() + } + h.wg.Wait() + return nil +} diff --git a/redirect.go b/redirect.go index 0569eb3..dcf3e72 100644 --- a/redirect.go +++ b/redirect.go @@ -5,7 +5,6 @@ import ( "github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/logger" - N "github.com/sagernet/sing/common/network" "go4.org/netipx" ) @@ -13,6 +12,8 @@ import ( const ( DefaultAutoRedirectInputMark = 0x2023 DefaultAutoRedirectOutputMark = 0x2024 + DefaultAutoRedirectResetMark = 0x2025 + DefaultAutoRedirectNFQueue = 100 ) type AutoRedirect interface { @@ -24,7 +25,7 @@ type AutoRedirect interface { type AutoRedirectOptions struct { TunOptions *Options Context context.Context - Handler N.TCPConnectionHandlerEx + Handler Handler Logger logger.Logger NetworkMonitor NetworkUpdateMonitor InterfaceFinder control.InterfaceFinder diff --git a/redirect_linux.go b/redirect_linux.go index 8c97d07..2a638ee 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -13,7 +13,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" "go4.org/netipx" @@ -22,7 +21,7 @@ import ( type autoRedirect struct { tunOptions *Options ctx context.Context - handler N.TCPConnectionHandlerEx + handler Handler logger logger.Logger tableName string networkMonitor NetworkUpdateMonitor @@ -41,6 +40,8 @@ type autoRedirect struct { suPath string routeAddressSet *[]*netipx.IPSet routeExcludeAddressSet *[]*netipx.IPSet + nfqueueHandler *nfqueueHandler + nfqueueEnabled bool } func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { @@ -125,13 +126,30 @@ func (r *autoRedirect) Start() error { listenAddr = netip.IPv4Unspecified() } server := newRedirectServer(r.ctx, r.handler, r.logger, listenAddr) - err := server.Start() + err = server.Start() if err != nil { return E.Cause(err, "start redirect server") } r.redirectServer = server } if r.useNFTables { + var handler *nfqueueHandler + handler, err = newNFQueueHandler(nfqueueOptions{ + Context: r.ctx, + Handler: r.handler, + Logger: r.logger, + Queue: r.effectiveNFQueue(), + OutputMark: r.effectiveOutputMark(), + ResetMark: r.effectiveResetMark(), + }) + if err != nil { + r.logger.Warn("nfqueue not available, pre-match disabled: ", err) + } else if err = handler.Start(); err != nil { + r.logger.Warn("nfqueue start failed, pre-match disabled: ", err) + } else { + r.nfqueueHandler = handler + r.nfqueueEnabled = true + } r.cleanupNFTables() err = r.setupNFTables() } else { @@ -142,6 +160,9 @@ func (r *autoRedirect) Start() error { } func (r *autoRedirect) Close() error { + if r.nfqueueHandler != nil { + r.nfqueueHandler.Close() + } if r.useNFTables { r.cleanupNFTables() } else { @@ -181,3 +202,28 @@ func (r *autoRedirect) redirectPort() uint16 { } return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port() } + +func (r *autoRedirect) effectiveOutputMark() uint32 { + if r.tunOptions.AutoRedirectOutputMark != 0 { + return r.tunOptions.AutoRedirectOutputMark + } + return DefaultAutoRedirectOutputMark +} + +func (r *autoRedirect) effectiveResetMark() uint32 { + if r.tunOptions.AutoRedirectResetMark != 0 { + return r.tunOptions.AutoRedirectResetMark + } + return DefaultAutoRedirectResetMark +} + +func (r *autoRedirect) effectiveNFQueue() uint16 { + if r.tunOptions.AutoRedirectNFQueue != 0 { + return r.tunOptions.AutoRedirectNFQueue + } + return DefaultAutoRedirectNFQueue +} + +func (r *autoRedirect) shouldSkipOutputChain() bool { + return len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo") +} diff --git a/redirect_nftables.go b/redirect_nftables.go index 0123953..4dea1c1 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -51,13 +51,23 @@ func (r *autoRedirect) setupNFTables() error { return err } - skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo") - if !skipOutput { + if r.nfqueueEnabled { + err = r.nftablesCreatePreMatchChains(nft, table) + if err != nil { + return err + } + } + + if !r.shouldSkipOutputChain() { + outputNATPriority := nftables.ChainPriorityMangle + if r.nfqueueEnabled { + outputNATPriority = nftables.ChainPriorityRef(*nftables.ChainPriorityMangle + 1) + } chainOutput := nft.AddChain(&nftables.Chain{ Name: "output", Table: table, Hooknum: nftables.ChainHookOutput, - Priority: nftables.ChainPriorityMangle, + Priority: outputNATPriority, Type: nftables.ChainTypeNAT, }) if r.tunOptions.AutoRedirectMarkMode { @@ -267,7 +277,7 @@ func (r *autoRedirect) setupNFTables() error { return nil } -// TODO; test is this works +// TODO: test if this works func (r *autoRedirect) nftablesUpdateLocalAddressSet() error { newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix { return common.Filter(it.Addresses, func(prefix netip.Prefix) bool { @@ -327,3 +337,122 @@ func (r *autoRedirect) cleanupNFTables() { _ = nft.Flush() _ = nft.CloseLasting() } + +func (r *autoRedirect) nftablesCreatePreMatchChains(nft *nftables.Conn, table *nftables.Table) error { + chainPreroutingPreMatch := nft.AddChain(&nftables.Chain{ + Name: "prerouting_prematch", + Table: table, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest - 1), + Type: nftables.ChainTypeFilter, + }) + r.nftablesAddPreMatchRules(nft, table, chainPreroutingPreMatch, true) + + if !r.shouldSkipOutputChain() { + chainOutputPreMatch := nft.AddChain(&nftables.Chain{ + Name: "output_prematch", + Table: table, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityMangle - 1), + Type: nftables.ChainTypeFilter, + }) + r.nftablesAddPreMatchRules(nft, table, chainOutputPreMatch, false) + } + + return nil +} + +func (r *autoRedirect) nftablesAddPreMatchRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, isPrerouting bool) { + ifnameKey := expr.MetaKeyOIFNAME + if isPrerouting { + ifnameKey = expr.MetaKeyIIFNAME + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: ifnameKey, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: nftablesIfname(r.tunOptions.Name)}, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte{unix.IPPROTO_TCP}}, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())}, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{Key: expr.CtKeyMARK, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())}, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 1, + Mask: []byte{0x12}, + Xor: []byte{0x00}, + }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{0x02}}, + &expr.Counter{}, + &expr.Queue{ + Num: r.effectiveNFQueue(), + Flag: expr.QueueFlagBypass, + }, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveResetMark())}, + &expr.Counter{}, + &expr.Reject{Type: unix.NFT_REJECT_TCP_RST}, + }, + }) + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())}, + &expr.Ct{Key: expr.CtKeyMARK, Register: 1, SourceRegister: true}, + &expr.Counter{}, + }, + }) +} diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index 2d71fff..ea03d40 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -213,6 +213,48 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft }) } } + if r.nfqueueEnabled && chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{ + Key: expr.CtKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + if r.nfqueueEnabled && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{ + Key: expr.CtKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } if chain.Hooknum == nftables.ChainHookPrerouting { nft.AddRule(&nftables.Rule{ Table: table, diff --git a/stack.go b/stack.go index 7014c6d..7c34c79 100644 --- a/stack.go +++ b/stack.go @@ -13,8 +13,9 @@ import ( ) var ( - ErrDrop = E.New("drop by rule") - ErrReset = E.New("reset by rule") + ErrDrop = E.New("drop by rule") + ErrReset = E.New("reset by rule") + ErrBypass = E.New("bypass by rule") ) type Stack interface { diff --git a/tun.go b/tun.go index c32cd8a..f4b60c1 100644 --- a/tun.go +++ b/tun.go @@ -83,6 +83,8 @@ type Options struct { AutoRedirectMarkMode bool AutoRedirectInputMark uint32 AutoRedirectOutputMark uint32 + AutoRedirectResetMark uint32 + AutoRedirectNFQueue uint16 ExcludeMPTCP bool Inet4LoopbackAddress []netip.Addr Inet6LoopbackAddress []netip.Addr