diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index 59fcb35..f5e2e6e 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -4,6 +4,7 @@ package tun import ( "context" + "errors" "net" "os" "sync" @@ -79,7 +80,7 @@ func (c *gLazyConn) HandshakeFailure(err error) error { if c.handshakeDone { return os.ErrInvalid } - c.request.Complete(err != ErrDrop) + c.request.Complete(!errors.Is(err, ErrDrop)) c.handshakeDone = true c.handshakeErr = err return nil diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 33cf40e..0a12933 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -4,6 +4,7 @@ package tun import ( "context" + "errors" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" @@ -37,7 +38,7 @@ func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination) if pErr != nil { - r.Complete(pErr != ErrDrop) + r.Complete(!errors.Is(pErr, ErrDrop)) return } conn := &gLazyConn{ diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 3027798..473eec4 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -4,6 +4,7 @@ package tun import ( "context" + "errors" "math" "net/netip" "os" @@ -59,7 +60,7 @@ func rangeIterate(r stack.Range, fn func(*buffer.View)) func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { - if pErr != ErrDrop { + if !errors.Is(pErr, ErrDrop) { gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer)) } return false, nil, nil, nil diff --git a/stack_system.go b/stack_system.go index 5a301d5..eaf8314 100644 --- a/stack_system.go +++ b/stack_system.go @@ -2,6 +2,7 @@ package tun import ( "context" + "errors" "net" "net/netip" "syscall" @@ -354,7 +355,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - if err == ErrDrop { + if errors.Is(err, ErrDrop) { return false, nil } else { return false, s.resetIPv4TCP(ipHdr, tcpHdr) @@ -441,7 +442,7 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - if err == ErrDrop { + if errors.Is(err, ErrDrop) { return false, nil } else { return false, s.resetIPv6TCP(ipHdr, tcpHdr) @@ -536,7 +537,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { - if pErr != ErrDrop { + if !errors.Is(pErr, ErrDrop) { if source.IsIPv4() { ipHdr := userData.(header.IPv4) s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)