diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index 5cbe698dd..203f822d6 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -14,8 +14,6 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" - libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" - libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" ) @@ -328,22 +326,22 @@ func (ac *client) handleDialBack(s network.Stream) { } // normalizeMultiaddr returns a multiaddr suitable for equality checks. -// If the multiaddr is a webtransport component, it removes the certhashes. +// it removes trailing certhashes. func normalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { - ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr) - if !ok { - ok, n = libp2pwebrtc.IsWebRTCDirectMultiaddr(addr) - } - if ok && n > 0 { - out := addr - for i := 0; i < n; i++ { - out, _ = ma.SplitLast(out) - } - return out - } + addr = removeTrailing(addr, ma.P_P2P) + addr = removeTrailing(addr, ma.P_CERTHASH) return addr } +func removeTrailing(addr ma.Multiaddr, protocolCode int) ma.Multiaddr { + for i := len(addr) - 1; i >= 0; i-- { + if addr[i].Code() != protocolCode { + return addr[0 : i+1] + } + } + return nil +} + func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool { if len(connLocalAddr) == 0 || len(dialedAddr) == 0 { return false