diff --git a/core/adapter/adapter.go b/core/adapter/adapter.go index f959c88..6c52fbe 100644 --- a/core/adapter/adapter.go +++ b/core/adapter/adapter.go @@ -11,7 +11,7 @@ import ( type TCPConn interface { net.Conn - // ID returns the transport endpoint id. + // ID returns the transport endpoint ID. ID() stack.TransportEndpointID } @@ -21,6 +21,20 @@ type UDPConn interface { net.Conn net.PacketConn - // ID returns the transport endpoint id. + // ID returns the transport endpoint ID. + ID() stack.TransportEndpointID +} + +// Packet represents a generic network packet delivered to a network +// handler. It provides access to the underlying packet buffer, the +// owning network stack, and the associated stack.TransportEndpointID. +type Packet interface { + // Buffer returns the packet buffer containing the data and headers. + Buffer() *stack.PacketBuffer + + // Stack returns the network stack responsible for handling this packet. + Stack() *stack.Stack + + // ID returns the transport endpoint ID. ID() stack.TransportEndpointID } diff --git a/core/adapter/handler.go b/core/adapter/handler.go index 22fc591..5ddadb4 100644 --- a/core/adapter/handler.go +++ b/core/adapter/handler.go @@ -6,3 +6,9 @@ type TransportHandler interface { HandleTCP(TCPConn) HandleUDP(UDPConn) } + +// NetworkHandler is a L3/network packet handler that implements +// HandlePacket method. +type NetworkHandler interface { + HandlePacket(Packet) bool +} diff --git a/core/icmp.go b/core/icmp.go index 28f7700..ddbc0bb 100644 --- a/core/icmp.go +++ b/core/icmp.go @@ -10,26 +10,38 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/core/option" ) -func withICMPHandler() option.Option { +func withICMPHandler(h adapter.NetworkHandler) option.Option { return func(s *stack.Stack) error { - f := newICMPForwarder(s) + f := newICMPForwarder(s, func(r *icmpForwarderRequest) bool { + if h != nil { + return h.HandlePacket(r) + } + return false + }) s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.HandlePacket) return nil } } +type icmpForwarderHandler func(*icmpForwarderRequest) bool + type icmpForwarder struct { s *stack.Stack + h icmpForwarderHandler } -func newICMPForwarder(s *stack.Stack) *icmpForwarder { - return &icmpForwarder{s: s} +func newICMPForwarder(s *stack.Stack, h icmpForwarderHandler) *icmpForwarder { + return &icmpForwarder{s: s, h: h} } func (f *icmpForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + if f.h(&icmpForwarderRequest{pkt: pkt.Clone(), id: id, stack: f.s}) { + return true /* handled */ + } switch pkt.NetworkProtocolNumber { case ipv4.ProtocolNumber: return f.handlePacket4(id, pkt) @@ -95,3 +107,17 @@ func (f *icmpForwarder) handlePacket4(_ stack.TransportEndpointID, pkt *stack.Pa func (f *icmpForwarder) handlePacket6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { return false // not implemented } + +type icmpForwarderRequest struct { + stack *stack.Stack + id stack.TransportEndpointID + pkt *stack.PacketBuffer +} + +func (r *icmpForwarderRequest) Stack() *stack.Stack { return r.stack } + +func (r *icmpForwarderRequest) ID() stack.TransportEndpointID { return r.id } + +func (r *icmpForwarderRequest) Buffer() *stack.PacketBuffer { return r.pkt } + +var _ adapter.Packet = (*icmpForwarderRequest)(nil) diff --git a/core/stack.go b/core/stack.go index e0faa03..d59b647 100644 --- a/core/stack.go +++ b/core/stack.go @@ -24,6 +24,10 @@ type Config struct { // stack to set transport handlers. TransportHandler adapter.TransportHandler + // ICMPHandler is used to customize ICMP packet handling. + // If nil, the default icmpForwarder is used. + ICMPHandler adapter.NetworkHandler + // MulticastGroups is used by internal stack to add // nic to given groups. MulticastGroups []netip.Addr @@ -61,8 +65,8 @@ func CreateStack(cfg *Config) (*stack.Stack, error) { // before creating NIC, otherwise NIC would dispatch packets // to stack and cause race condition. // Initiate transport protocol (TCP/UDP) with given handler. - withTCPHandler(cfg.TransportHandler.HandleTCP), - withUDPHandler(cfg.TransportHandler.HandleUDP), + withTCPHandler(cfg.TransportHandler), + withUDPHandler(cfg.TransportHandler), // gVisor added NetworkPacketInfo.LocalAddressTemporary to // identify packets received with temporary addresses due @@ -73,7 +77,7 @@ func CreateStack(cfg *Config) (*stack.Stack, error) { // Ref: // - https://github.com/google/gvisor/issues/8657 // - https://github.com/google/gvisor/pull/11681 - withICMPHandler(), + withICMPHandler(cfg.ICMPHandler), // Create stack NIC and then bind link endpoint to it. withCreatingNIC(nicID, cfg.LinkEndpoint), diff --git a/core/tcp.go b/core/tcp.go index b9a0ab7..92d60be 100644 --- a/core/tcp.go +++ b/core/tcp.go @@ -41,9 +41,9 @@ const ( tcpKeepaliveInterval = 30 * time.Second ) -func withTCPHandler(handle func(adapter.TCPConn)) option.Option { +func withTCPHandler(h adapter.TransportHandler) option.Option { return func(s *stack.Stack) error { - tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { + f := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var ( wq waiter.Queue ep tcpip.Endpoint @@ -73,9 +73,9 @@ func withTCPHandler(handle func(adapter.TCPConn)) option.Option { TCPConn: gonet.NewTCPConn(&wq, ep), id: id, } - handle(conn) + h.HandleTCP(conn) }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) return nil } } diff --git a/core/udp.go b/core/udp.go index eabca96..2e226e1 100644 --- a/core/udp.go +++ b/core/udp.go @@ -11,9 +11,9 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/option" ) -func withUDPHandler(handle func(adapter.UDPConn)) option.Option { +func withUDPHandler(h adapter.TransportHandler) option.Option { return func(s *stack.Stack) error { - udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) bool { + f := udp.NewForwarder(s, func(r *udp.ForwarderRequest) bool { var ( wq waiter.Queue id = r.ID() @@ -29,10 +29,10 @@ func withUDPHandler(handle func(adapter.UDPConn)) option.Option { UDPConn: gonet.NewUDPConn(&wq, ep), id: id, } - handle(conn) + h.HandleUDP(conn) return true }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + s.SetTransportProtocolHandler(udp.ProtocolNumber, f.HandlePacket) return nil } } diff --git a/engine/engine.go b/engine/engine.go index d14162c..57bb165 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -13,6 +13,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "github.com/xjasonlyu/tun2socks/v2/core" + "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/option" "github.com/xjasonlyu/tun2socks/v2/dialer" @@ -36,6 +37,9 @@ var ( // _defaultStack holds the default stack for the engine. _defaultStack *stack.Stack + + // _icmpHandler holds the custom ICMP handler for the engine. + _icmpHandler adapter.NetworkHandler ) // Start starts the default engine up. @@ -59,6 +63,13 @@ func Insert(k *Key) { _engineMu.Unlock() } +// SetICMPHandler sets the custom ICMP handler for the default engine. +func SetICMPHandler(h adapter.NetworkHandler) { + _engineMu.Lock() + _icmpHandler = h + _engineMu.Unlock() +} + func start() error { _engineMu.Lock() defer _engineMu.Unlock() @@ -227,6 +238,7 @@ func netstack(k *Key) (err error) { if _defaultStack, err = core.CreateStack(&core.Config{ LinkEndpoint: _defaultDevice, TransportHandler: tunnel.T(), + ICMPHandler: _icmpHandler, MulticastGroups: multicastGroups, Options: opts, }); err != nil {