mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2026-04-22 15:57:28 +08:00
Feature: support custom ICMP message handling (#521)
Co-authored-by: Jason Lyu <xjasonlyu@gmail.com>
This commit is contained in:
+16
-2
@@ -11,7 +11,7 @@ import (
|
||||
type TCPConn interface {
|
||||
net.Conn
|
||||
|
||||
// ID returns the transport endpoint id.
|
||||
// ID returns the transport endpoint ID.
|
||||
ID() stack.TransportEndpointID
|
||||
}
|
||||
|
||||
@@ -21,6 +21,20 @@ type UDPConn interface {
|
||||
net.Conn
|
||||
net.PacketConn
|
||||
|
||||
// ID returns the transport endpoint id.
|
||||
// ID returns the transport endpoint ID.
|
||||
ID() stack.TransportEndpointID
|
||||
}
|
||||
|
||||
// Packet represents a generic network packet delivered to a network
|
||||
// handler. It provides access to the underlying packet buffer, the
|
||||
// owning network stack, and the associated stack.TransportEndpointID.
|
||||
type Packet interface {
|
||||
// Buffer returns the packet buffer containing the data and headers.
|
||||
Buffer() *stack.PacketBuffer
|
||||
|
||||
// Stack returns the network stack responsible for handling this packet.
|
||||
Stack() *stack.Stack
|
||||
|
||||
// ID returns the transport endpoint ID.
|
||||
ID() stack.TransportEndpointID
|
||||
}
|
||||
|
||||
@@ -6,3 +6,9 @@ type TransportHandler interface {
|
||||
HandleTCP(TCPConn)
|
||||
HandleUDP(UDPConn)
|
||||
}
|
||||
|
||||
// NetworkHandler is a L3/network packet handler that implements
|
||||
// HandlePacket method.
|
||||
type NetworkHandler interface {
|
||||
HandlePacket(Packet) bool
|
||||
}
|
||||
|
||||
+30
-4
@@ -10,26 +10,38 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/option"
|
||||
)
|
||||
|
||||
func withICMPHandler() option.Option {
|
||||
func withICMPHandler(h adapter.NetworkHandler) option.Option {
|
||||
return func(s *stack.Stack) error {
|
||||
f := newICMPForwarder(s)
|
||||
f := newICMPForwarder(s, func(r *icmpForwarderRequest) bool {
|
||||
if h != nil {
|
||||
return h.HandlePacket(r)
|
||||
}
|
||||
return false
|
||||
})
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type icmpForwarderHandler func(*icmpForwarderRequest) bool
|
||||
|
||||
type icmpForwarder struct {
|
||||
s *stack.Stack
|
||||
h icmpForwarderHandler
|
||||
}
|
||||
|
||||
func newICMPForwarder(s *stack.Stack) *icmpForwarder {
|
||||
return &icmpForwarder{s: s}
|
||||
func newICMPForwarder(s *stack.Stack, h icmpForwarderHandler) *icmpForwarder {
|
||||
return &icmpForwarder{s: s, h: h}
|
||||
}
|
||||
|
||||
func (f *icmpForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
if f.h(&icmpForwarderRequest{pkt: pkt.Clone(), id: id, stack: f.s}) {
|
||||
return true /* handled */
|
||||
}
|
||||
switch pkt.NetworkProtocolNumber {
|
||||
case ipv4.ProtocolNumber:
|
||||
return f.handlePacket4(id, pkt)
|
||||
@@ -95,3 +107,17 @@ func (f *icmpForwarder) handlePacket4(_ stack.TransportEndpointID, pkt *stack.Pa
|
||||
func (f *icmpForwarder) handlePacket6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
return false // not implemented
|
||||
}
|
||||
|
||||
type icmpForwarderRequest struct {
|
||||
stack *stack.Stack
|
||||
id stack.TransportEndpointID
|
||||
pkt *stack.PacketBuffer
|
||||
}
|
||||
|
||||
func (r *icmpForwarderRequest) Stack() *stack.Stack { return r.stack }
|
||||
|
||||
func (r *icmpForwarderRequest) ID() stack.TransportEndpointID { return r.id }
|
||||
|
||||
func (r *icmpForwarderRequest) Buffer() *stack.PacketBuffer { return r.pkt }
|
||||
|
||||
var _ adapter.Packet = (*icmpForwarderRequest)(nil)
|
||||
|
||||
+7
-3
@@ -24,6 +24,10 @@ type Config struct {
|
||||
// stack to set transport handlers.
|
||||
TransportHandler adapter.TransportHandler
|
||||
|
||||
// ICMPHandler is used to customize ICMP packet handling.
|
||||
// If nil, the default icmpForwarder is used.
|
||||
ICMPHandler adapter.NetworkHandler
|
||||
|
||||
// MulticastGroups is used by internal stack to add
|
||||
// nic to given groups.
|
||||
MulticastGroups []netip.Addr
|
||||
@@ -61,8 +65,8 @@ func CreateStack(cfg *Config) (*stack.Stack, error) {
|
||||
// before creating NIC, otherwise NIC would dispatch packets
|
||||
// to stack and cause race condition.
|
||||
// Initiate transport protocol (TCP/UDP) with given handler.
|
||||
withTCPHandler(cfg.TransportHandler.HandleTCP),
|
||||
withUDPHandler(cfg.TransportHandler.HandleUDP),
|
||||
withTCPHandler(cfg.TransportHandler),
|
||||
withUDPHandler(cfg.TransportHandler),
|
||||
|
||||
// gVisor added NetworkPacketInfo.LocalAddressTemporary to
|
||||
// identify packets received with temporary addresses due
|
||||
@@ -73,7 +77,7 @@ func CreateStack(cfg *Config) (*stack.Stack, error) {
|
||||
// Ref:
|
||||
// - https://github.com/google/gvisor/issues/8657
|
||||
// - https://github.com/google/gvisor/pull/11681
|
||||
withICMPHandler(),
|
||||
withICMPHandler(cfg.ICMPHandler),
|
||||
|
||||
// Create stack NIC and then bind link endpoint to it.
|
||||
withCreatingNIC(nicID, cfg.LinkEndpoint),
|
||||
|
||||
+4
-4
@@ -41,9 +41,9 @@ const (
|
||||
tcpKeepaliveInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
func withTCPHandler(handle func(adapter.TCPConn)) option.Option {
|
||||
func withTCPHandler(h adapter.TransportHandler) option.Option {
|
||||
return func(s *stack.Stack) error {
|
||||
tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||
f := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
ep tcpip.Endpoint
|
||||
@@ -73,9 +73,9 @@ func withTCPHandler(handle func(adapter.TCPConn)) option.Option {
|
||||
TCPConn: gonet.NewTCPConn(&wq, ep),
|
||||
id: id,
|
||||
}
|
||||
handle(conn)
|
||||
h.HandleTCP(conn)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
+4
-4
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/option"
|
||||
)
|
||||
|
||||
func withUDPHandler(handle func(adapter.UDPConn)) option.Option {
|
||||
func withUDPHandler(h adapter.TransportHandler) option.Option {
|
||||
return func(s *stack.Stack) error {
|
||||
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) bool {
|
||||
f := udp.NewForwarder(s, func(r *udp.ForwarderRequest) bool {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
@@ -29,10 +29,10 @@ func withUDPHandler(handle func(adapter.UDPConn)) option.Option {
|
||||
UDPConn: gonet.NewUDPConn(&wq, ep),
|
||||
id: id,
|
||||
}
|
||||
handle(conn)
|
||||
h.HandleUDP(conn)
|
||||
return true
|
||||
})
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, f.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
"github.com/xjasonlyu/tun2socks/v2/core"
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/device"
|
||||
"github.com/xjasonlyu/tun2socks/v2/core/option"
|
||||
"github.com/xjasonlyu/tun2socks/v2/dialer"
|
||||
@@ -36,6 +37,9 @@ var (
|
||||
|
||||
// _defaultStack holds the default stack for the engine.
|
||||
_defaultStack *stack.Stack
|
||||
|
||||
// _icmpHandler holds the custom ICMP handler for the engine.
|
||||
_icmpHandler adapter.NetworkHandler
|
||||
)
|
||||
|
||||
// Start starts the default engine up.
|
||||
@@ -59,6 +63,13 @@ func Insert(k *Key) {
|
||||
_engineMu.Unlock()
|
||||
}
|
||||
|
||||
// SetICMPHandler sets the custom ICMP handler for the default engine.
|
||||
func SetICMPHandler(h adapter.NetworkHandler) {
|
||||
_engineMu.Lock()
|
||||
_icmpHandler = h
|
||||
_engineMu.Unlock()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
_engineMu.Lock()
|
||||
defer _engineMu.Unlock()
|
||||
@@ -227,6 +238,7 @@ func netstack(k *Key) (err error) {
|
||||
if _defaultStack, err = core.CreateStack(&core.Config{
|
||||
LinkEndpoint: _defaultDevice,
|
||||
TransportHandler: tunnel.T(),
|
||||
ICMPHandler: _icmpHandler,
|
||||
MulticastGroups: multicastGroups,
|
||||
Options: opts,
|
||||
}); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user