fix: swarm: refactor address resolution (#2990)

* Remove unused resolver in basic host

* Refactor Swarm.resolveAddrs

Refactors how DNS Address resolution works.

* lint

* Move MultiaddrDNSResolver interface to core

* Reserve output space for addresses left to resolve

* feat: core/transport: Add SkipResolver interface (#2989)

* Rebase on top of resolveAddrs refactor

* Add comments

* Sanitize address inputs when returning a reservation message (#3006)
This commit is contained in:
Marco Munizaga
2024-10-16 12:20:21 -07:00
committed by GitHub
parent c79cf3653d
commit e8b6685edb
19 changed files with 653 additions and 137 deletions
+1 -3
View File
@@ -40,7 +40,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
"go.uber.org/fx"
@@ -114,7 +113,7 @@ type Config struct {
Peerstore peerstore.Peerstore
Reporter metrics.Reporter
MultiaddrResolver *madns.Resolver
MultiaddrResolver network.MultiaddrDNSResolver
DisablePing bool
@@ -286,7 +285,6 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
quicAddrPorts := map[string]struct{}{}
+8
View File
@@ -161,6 +161,14 @@ type Network interface {
ResourceManager() ResourceManager
}
// MultiaddrDNSResolver resolves DNS components in multiaddrs.
type MultiaddrDNSResolver interface {
	// ResolveDNSAddr resolves the first /dnsaddr component in a multiaddr.
	// Recursively resolves DNSADDRs up to the recursion limit.
	// The returned slice contains at most outputLimit addresses.
	ResolveDNSAddr(ctx context.Context, expectedPeerID peer.ID, maddr ma.Multiaddr, recursionLimit, outputLimit int) ([]ma.Multiaddr, error)
	// ResolveDNSComponent resolves the first /{dns,dns4,dns6} component in a multiaddr.
	// The returned slice contains at most outputLimit addresses.
	ResolveDNSComponent(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error)
}
// Dialer represents a service that can dial out to peers
// (this is usually just a Network, but other services may not need the whole
// stack, and thus it becomes easier to mock)
+18
View File
@@ -61,6 +61,24 @@ func SplitAddr(m ma.Multiaddr) (transport ma.Multiaddr, id ID) {
return transport, id
}
// IDFromP2PAddr extracts the peer ID from a p2p Multiaddr.
// It returns ErrInvalidAddr if m is nil or if the last component of the
// multiaddr is not a /p2p component.
func IDFromP2PAddr(m ma.Multiaddr) (ID, error) {
	if m == nil {
		return "", ErrInvalidAddr
	}
	var lastComponent ma.Component
	// Walk every component; lastComponent ends up holding the final one.
	ma.ForEach(m, func(c ma.Component) bool {
		lastComponent = c
		return true
	})
	if lastComponent.Protocol().Code != ma.P_P2P {
		return "", ErrInvalidAddr
	}

	id := ID(lastComponent.RawValue()) // already validated by the multiaddr library.
	return id, nil
}
// AddrInfoFromString builds an AddrInfo from the string representation of a Multiaddr
func AddrInfoFromString(s string) (*AddrInfo, error) {
a, err := ma.NewMultiaddr(s)
+14
View File
@@ -4,6 +4,7 @@ import (
"testing"
. "github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
ma "github.com/multiformats/go-multiaddr"
)
@@ -50,6 +51,19 @@ func TestSplitAddr(t *testing.T) {
}
}
// TestIDFromP2PAddr checks that IDFromP2PAddr extracts the trailing /p2p peer
// ID from both a full (transport + /p2p) multiaddr and a /p2p-only multiaddr,
// and that a transport-only multiaddr is rejected with ErrInvalidAddr.
func TestIDFromP2PAddr(t *testing.T) {
	id, err := IDFromP2PAddr(maddrFull)
	require.NoError(t, err)
	require.Equal(t, testID, id)

	id, err = IDFromP2PAddr(maddrPeer)
	require.NoError(t, err)
	require.Equal(t, testID, id)

	// No /p2p component at all.
	_, err = IDFromP2PAddr(maddrTpt)
	require.ErrorIs(t, err, ErrInvalidAddr)
}
func TestAddrInfoFromP2pAddr(t *testing.T) {
ai, err := AddrInfoFromP2pAddr(maddrFull)
if err != nil {
+12
View File
@@ -50,6 +50,10 @@ type CapableConn interface {
// shutdown. NOTE: `Dial` and `Listen` may be called after or concurrently with
// `Close`.
//
// In addition to the Transport interface, transports may implement
// Resolver or SkipResolver interface. When wrapping/embedding a transport, you should
// ensure that the Resolver/SkipResolver interface is handled correctly.
//
// For a conceptual overview, see https://docs.libp2p.io/concepts/transport/
type Transport interface {
// Dial dials a remote peer. It should try to reuse local listener
@@ -85,6 +89,14 @@ type Resolver interface {
Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error)
}
// SkipResolver can be optionally implemented by transports that don't want to
// resolve or transform the multiaddr. Useful for transports that indirectly
// wrap other transports (e.g. p2p-circuit). This lets the inner transport
// specify how a multiaddr is resolved later.
type SkipResolver interface {
	// SkipResolve reports whether resolution of maddr should be skipped,
	// deferring it to the wrapped (inner) transport.
	SkipResolve(ctx context.Context, maddr ma.Multiaddr) bool
}
// Listener is an interface closely resembling the net.Listener interface. The
// only real difference is that Accept() returns Conn's of the type in this
// package, and also exposes a Multiaddr method as opposed to a regular Addr
-10
View File
@@ -21,7 +21,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
)
// DefaultSecurity is the default security option.
@@ -128,11 +127,6 @@ var DefaultConnectionManager = func(cfg *Config) error {
return cfg.Apply(ConnectionManager(mgr))
}
// DefaultMultiaddrResolver creates a default multiaddr DNS resolver
var DefaultMultiaddrResolver = func(cfg *Config) error {
return cfg.Apply(MultiaddrResolver(madns.DefaultResolver))
}
// DefaultPrometheusRegisterer configures libp2p to use the default registerer
var DefaultPrometheusRegisterer = func(cfg *Config) error {
return cfg.Apply(PrometheusRegisterer(prometheus.DefaultRegisterer))
@@ -198,10 +192,6 @@ var defaults = []struct {
fallback: func(cfg *Config) bool { return cfg.ConnManager == nil },
opt: DefaultConnectionManager,
},
{
fallback: func(cfg *Config) bool { return cfg.MultiaddrResolver == nil },
opt: DefaultMultiaddrResolver,
},
{
fallback: func(cfg *Config) bool { return !cfg.DisableMetrics && cfg.PrometheusRegisterer == nil },
opt: DefaultPrometheusRegisterer,
+98 -2
View File
@@ -2,10 +2,16 @@ package libp2p
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/netip"
"regexp"
@@ -26,11 +32,12 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
sectls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"go.uber.org/goleak"
@@ -256,7 +263,7 @@ func TestSecurityConstructor(t *testing.T) {
h, err := New(
Transport(tcp.NewTCPTransport),
Security("/noisy", noise.New),
Security("/tls", tls.New),
Security("/tls", sectls.New),
DefaultListenAddrs,
DisableRelay(),
)
@@ -655,3 +662,92 @@ func TestUseCorrectTransportForDialOut(t *testing.T) {
}
}
}
// TestCircuitBehindWSS checks that a client dialing through a relay listening
// on WSS sends the DNS name from the relay's multiaddr as the TLS SNI, both
// when the peer behind the relay reserves a slot and when the dialer connects
// through the relay.
func TestCircuitBehindWSS(t *testing.T) {
	relayTLSConf := getTLSConf(t, net.IPv4(127, 0, 0, 1), time.Now(), time.Now().Add(time.Hour))
	serverNameChan := make(chan string, 2) // Channel that returns what server names the client hello specified
	relayTLSConf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
		serverNameChan <- chi.ServerName
		return relayTLSConf, nil
	}

	relay, err := New(
		EnableRelayService(),
		ForceReachabilityPublic(),
		Transport(websocket.New, websocket.WithTLSConfig(relayTLSConf)),
		ListenAddrStrings("/ip4/127.0.0.1/tcp/0/wss"),
	)
	require.NoError(t, err)
	defer relay.Close()

	// Fix: check the ValueForProtocol error instead of silently ignoring it.
	relayAddrPort, err := relay.Addrs()[0].ValueForProtocol(ma.P_TCP)
	require.NoError(t, err)
	relayAddrWithSNIString := fmt.Sprintf(
		"/dns4/localhost/tcp/%s/wss", relayAddrPort,
	)
	relayAddrWithSNI := []ma.Multiaddr{ma.StringCast(relayAddrWithSNIString)}

	h, err := New(
		NoListenAddrs,
		EnableRelay(),
		Transport(websocket.New, websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})),
		ForceReachabilityPrivate())
	require.NoError(t, err)
	defer h.Close()

	peerBehindRelay, err := New(
		NoListenAddrs,
		Transport(websocket.New, websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})),
		EnableRelay(),
		EnableAutoRelayWithStaticRelays([]peer.AddrInfo{{ID: relay.ID(), Addrs: relayAddrWithSNI}}),
		ForceReachabilityPrivate())
	require.NoError(t, err)
	defer peerBehindRelay.Close()

	require.Equal(t,
		"localhost",
		<-serverNameChan, // The server connects to the relay
	)

	// Connect to the peer behind the relay.
	// Fix: the original dropped Connect's return value and then re-asserted a
	// stale err from the previous call, so a failed dial went unnoticed.
	err = h.Connect(context.Background(), peer.AddrInfo{
		ID: peerBehindRelay.ID(),
		Addrs: []ma.Multiaddr{ma.StringCast(
			fmt.Sprintf("%s/p2p/%s/p2p-circuit", relayAddrWithSNIString, relay.ID()),
		)},
	})
	require.NoError(t, err)

	require.Equal(t,
		"localhost",
		<-serverNameChan, // The client connects to the relay and sends the SNI
	)
}
// getTLSConf is a test helper that builds a tls.Config backed by a fresh
// self-signed ECDSA (P-256) certificate for the given IP, valid from start
// until end.
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
	t.Helper()

	// Generate the key first; the certificate template references it below.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1234),
		Subject:               pkix.Name{Organization: []string{"websocket"}},
		NotBefore:             start,
		NotAfter:              end,
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{ip},
	}

	// Self-sign: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	require.NoError(t, err)
	leaf, err := x509.ParseCertificate(der)
	require.NoError(t, err)

	return &tls.Config{
		Certificates: []tls.Certificate{{
			Certificate: [][]byte{leaf.Raw},
			PrivateKey:  key,
			Leaf:        leaf,
		}},
	}
}
+1 -2
View File
@@ -31,7 +31,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"go.uber.org/fx"
)
@@ -495,7 +494,7 @@ func UserAgent(userAgent string) Option {
}
// MultiaddrResolver sets the libp2p dns resolver
func MultiaddrResolver(rslv *madns.Resolver) Option {
func MultiaddrResolver(rslv network.MultiaddrDNSResolver) Option {
return func(cfg *Config) error {
cfg.MultiaddrResolver = rslv
return nil
-11
View File
@@ -38,7 +38,6 @@ import (
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
msmux "github.com/multiformats/go-multistream"
)
@@ -82,7 +81,6 @@ type BasicHost struct {
hps *holepunch.Service
pings *ping.PingService
natmgr NATManager
maResolver *madns.Resolver
cmgr connmgr.ConnManager
eventbus event.Bus
relayManager *relaysvc.RelayManager
@@ -133,10 +131,6 @@ type HostOpts struct {
// If omitted, there's no override or filtering, and the results of Addrs and AllAddrs are the same.
AddrsFactory AddrsFactory
// MultiaddrResolver holds the go-multiaddr-dns.Resolver used for resolving
// /dns4, /dns6, and /dnsaddr addresses before trying to connect to a peer.
MultiaddrResolver *madns.Resolver
// NATManager takes care of setting NAT port mappings, and discovering external addresses.
// If omitted, this will simply be disabled.
NATManager func(network.Network) NATManager
@@ -197,7 +191,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
mux: msmux.NewMultistreamMuxer[protocol.ID](),
negtimeout: DefaultNegotiationTimeout,
AddrsFactory: DefaultAddrsFactory,
maResolver: madns.DefaultResolver,
eventbus: opts.EventBus,
addrChangeChan: make(chan struct{}, 1),
ctx: hostCtx,
@@ -306,10 +299,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
h.natmgr = opts.NATManager(n)
}
if opts.MultiaddrResolver != nil {
h.maResolver = opts.MultiaddrResolver
}
if opts.ConnManager == nil {
h.cmgr = &connmgr.NullConnMgr{}
} else {
+1 -1
View File
@@ -52,7 +52,7 @@ func TestBasicDialPeerWithResolver(t *testing.T) {
resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
require.NoError(t, err)
swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(resolver)))
swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(swarm.ResolverFromMaDNS{resolver})))
defer closeSwarms(swarms)
s1 := swarms[0]
s2 := swarms[1]
+120
View File
@@ -0,0 +1,120 @@
package swarm
import (
"context"
"net"
"strconv"
"testing"
"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"github.com/stretchr/testify/require"
)
// TestSwarmResolver exercises ResolverFromMaDNS against a mock DNS backend:
// plain /dns and /dnsaddr resolution, output limits, recursion limits, and
// resolving when the result exactly fills the output limit.
// NOTE: mockResolver's IP/TXT maps are mutated by each subtest, so the
// subtests depend on running in order against the shared resolver.
func TestSwarmResolver(t *testing.T) {
	mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)}
	ipaddr, err := net.ResolveIPAddr("ip4", "127.0.0.1")
	require.NoError(t, err)
	mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr}
	mockResolver.TXT = map[string][]string{
		"_dnsaddr.example.com": {"dnsaddr=/ip4/127.0.0.1"},
	}
	madnsResolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
	require.NoError(t, err)
	swarmResolver := ResolverFromMaDNS{madnsResolver}

	ctx := context.Background()
	// /dns resolves via the A record.
	res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
	require.NoError(t, err)
	require.Equal(t, 1, len(res))
	require.Equal(t, "/ip4/127.0.0.1", res[0].String())

	// /dnsaddr resolves via the TXT record.
	res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
	require.NoError(t, err)
	require.Equal(t, 1, len(res))
	require.Equal(t, "/ip4/127.0.0.1", res[0].String())

	t.Run("Test Limits", func(t *testing.T) {
		// 255 records behind a single name; only the first outputLimit (10)
		// may come back.
		var ipaddrs []net.IPAddr
		var manyDNSAddrs []string
		for i := 0; i < 255; i++ {
			ip := "1.2.3." + strconv.Itoa(i)
			ipaddrs = append(ipaddrs, net.IPAddr{IP: net.ParseIP(ip)})
			manyDNSAddrs = append(manyDNSAddrs, "dnsaddr=/ip4/"+ip)
		}

		mockResolver.IP = map[string][]net.IPAddr{
			"example.com": ipaddrs,
		}
		mockResolver.TXT = map[string][]string{
			"_dnsaddr.example.com": manyDNSAddrs,
		}

		res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
		require.NoError(t, err)
		require.Equal(t, 10, len(res))
		for i := 0; i < 10; i++ {
			require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
		}

		res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
		require.NoError(t, err)
		require.Equal(t, 10, len(res))
		for i := 0; i < 10; i++ {
			require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
		}
	})

	t.Run("Test Recursive Limits", func(t *testing.T) {
		// A 256-hop chain of dnsaddr records:
		// 0.example.com -> 1.example.com -> ... -> 255.example.com -> /ip4.
		recursiveDNSAddr := make(map[string][]string)
		for i := 0; i < 255; i++ {
			recursiveDNSAddr["_dnsaddr."+strconv.Itoa(i)+".example.com"] = []string{"dnsaddr=/dnsaddr/" + strconv.Itoa(i+1) + ".example.com"}
		}
		recursiveDNSAddr["_dnsaddr.255.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		mockResolver.TXT = recursiveDNSAddr

		// Enough recursion budget (256) to reach the terminal /ip4 record.
		res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 256, 10)
		require.NoError(t, err)
		require.Equal(t, 1, len(res))
		require.Equal(t, "/ip4/127.0.0.1", res[0].String())

		// One hop short (255): the final address comes back unresolved.
		res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 255, 10)
		require.NoError(t, err)
		require.Equal(t, 1, len(res))
		require.Equal(t, "/dnsaddr/255.example.com", res[0].String())
	})

	t.Run("Test Resolve at output limit", func(t *testing.T) {
		// Exactly 10 dnsaddr children, each resolving to one /ip4 address, so
		// the result fills outputLimit without exceeding it.
		recursiveDNSAddr := make(map[string][]string)
		recursiveDNSAddr["_dnsaddr.example.com"] = []string{
			"dnsaddr=/dnsaddr/0.example.com",
			"dnsaddr=/dnsaddr/1.example.com",
			"dnsaddr=/dnsaddr/2.example.com",
			"dnsaddr=/dnsaddr/3.example.com",
			"dnsaddr=/dnsaddr/4.example.com",
			"dnsaddr=/dnsaddr/5.example.com",
			"dnsaddr=/dnsaddr/6.example.com",
			"dnsaddr=/dnsaddr/7.example.com",
			"dnsaddr=/dnsaddr/8.example.com",
			"dnsaddr=/dnsaddr/9.example.com",
		}
		recursiveDNSAddr["_dnsaddr.0.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.1.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.2.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.3.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.4.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.5.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.6.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.7.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.8.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		recursiveDNSAddr["_dnsaddr.9.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
		mockResolver.TXT = recursiveDNSAddr

		res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 256, 10)
		require.NoError(t, err)
		require.Equal(t, 10, len(res))
		for _, r := range res {
			require.Equal(t, "/ip4/127.0.0.1", r.String())
		}
	})
}
+116 -15
View File
@@ -60,9 +60,9 @@ func WithConnectionGater(gater connmgr.ConnectionGater) Option {
}
// WithMultiaddrResolver sets a custom multiaddress resolver
func WithMultiaddrResolver(maResolver *madns.Resolver) Option {
func WithMultiaddrResolver(resolver network.MultiaddrDNSResolver) Option {
return func(s *Swarm) error {
s.maResolver = maResolver
s.multiaddrResolver = resolver
return nil
}
}
@@ -196,7 +196,7 @@ type Swarm struct {
m map[int]transport.Transport
}
maResolver *madns.Resolver
multiaddrResolver network.MultiaddrDNSResolver
// stream handlers
streamh atomic.Pointer[network.StreamHandler]
@@ -231,15 +231,15 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
}
ctx, cancel := context.WithCancel(context.Background())
s := &Swarm{
local: local,
peers: peers,
emitter: emitter,
ctx: ctx,
ctxCancel: cancel,
dialTimeout: defaultDialTimeout,
dialTimeoutLocal: defaultDialTimeoutLocal,
maResolver: madns.DefaultResolver,
dialRanker: DefaultDialRanker,
local: local,
peers: peers,
emitter: emitter,
ctx: ctx,
ctxCancel: cancel,
dialTimeout: defaultDialTimeout,
dialTimeoutLocal: defaultDialTimeoutLocal,
multiaddrResolver: ResolverFromMaDNS{madns.DefaultResolver},
dialRanker: DefaultDialRanker,
// A black hole is a binary property. On a network if UDP dials are blocked or there is
// no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials
@@ -624,7 +624,6 @@ func isBetterConn(a, b *Conn) bool {
// bestConnToPeer returns the best connection to peer.
func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {
// TODO: Prefer some transports over others.
// For now, prefers direct connections over Relayed connections.
// For tie-breaking, select the newest non-closed connection with the most streams.
@@ -813,8 +812,10 @@ func (s *Swarm) ResourceManager() network.ResourceManager {
}
// Swarm is a Network.
var _ network.Network = (*Swarm)(nil)
var _ transport.TransportNetwork = (*Swarm)(nil)
var (
_ network.Network = (*Swarm)(nil)
_ transport.TransportNetwork = (*Swarm)(nil)
)
type connWithMetrics struct {
transport.CapableConn
@@ -846,3 +847,103 @@ func (c connWithMetrics) Stat() network.ConnStats {
}
var _ network.ConnStat = connWithMetrics{}
// ResolverFromMaDNS wraps a *madns.Resolver so it satisfies the
// network.MultiaddrDNSResolver interface.
type ResolverFromMaDNS struct {
	*madns.Resolver
}

var _ network.MultiaddrDNSResolver = ResolverFromMaDNS{}
// startsWithDNSADDR reports whether the first component of m is /dnsaddr.
// A nil multiaddr is never a DNSADDR.
func startsWithDNSADDR(m ma.Multiaddr) bool {
	if m == nil {
		return false
	}

	isDNSADDR := false
	// ForEach lets us inspect only the first component without allocating.
	ma.ForEach(m, func(c ma.Component) bool {
		isDNSADDR = c.Protocol().Code == ma.P_DNSADDR
		return false // stop after the first component
	})
	return isDNSADDR
}
// ResolveDNSAddr implements MultiaddrDNSResolver. It recursively resolves
// /dnsaddr components in maddr up to recursionLimit levels deep and returns
// at most outputLimit addresses. If expectedPeerID is non-empty, resolved
// addresses that carry a different /p2p peer ID are dropped. Failures on
// individual nested dnsaddrs are logged and skipped rather than aborting.
func (r ResolverFromMaDNS) ResolveDNSAddr(ctx context.Context, expectedPeerID peer.ID, maddr ma.Multiaddr, recursionLimit int, outputLimit int) ([]ma.Multiaddr, error) {
	if outputLimit <= 0 {
		return nil, nil
	}
	if recursionLimit <= 0 {
		// Out of recursion budget: hand the address back unresolved.
		return []ma.Multiaddr{maddr}, nil
	}

	var resolved, toResolve []ma.Multiaddr
	addrs, err := r.Resolve(ctx, maddr)
	if err != nil {
		return nil, err
	}
	if len(addrs) > outputLimit {
		addrs = addrs[:outputLimit]
	}

	// Split the results into fully-resolved addresses and nested /dnsaddr
	// addresses that need another round of resolution.
	for _, addr := range addrs {
		if startsWithDNSADDR(addr) {
			toResolve = append(toResolve, addr)
		} else {
			resolved = append(resolved, addr)
		}
	}

	for i, addr := range toResolve {
		// Set the nextOutputLimit to:
		//   outputLimit
		//   - len(resolved)          // What we already have resolved
		//   - (len(toResolve) - i)   // How many addresses we have left to resolve
		//   + 1                      // The current address we are resolving
		// This assumes that each DNSADDR address will resolve to at least one multiaddr.
		// This assumption lets us bound the space we reserve for resolving.
		nextOutputLimit := outputLimit - len(resolved) - (len(toResolve) - i) + 1
		resolvedAddrs, err := r.ResolveDNSAddr(ctx, expectedPeerID, addr, recursionLimit-1, nextOutputLimit)
		if err != nil {
			// Fix: the original format string was "failed to resolve dnsaddr %v %s: ",
			// which put the colon after the error and left a trailing space.
			log.Warnf("failed to resolve dnsaddr %v: %s", addr, err)
			// Dropping this address
			continue
		}
		resolved = append(resolved, resolvedAddrs...)
	}

	if len(resolved) > outputLimit {
		resolved = resolved[:outputLimit]
	}

	// If the address contains a peer id, make sure it matches our expectedPeerID
	if expectedPeerID != "" {
		removeMismatchPeerID := func(a ma.Multiaddr) bool {
			id, err := peer.IDFromP2PAddr(a)
			if err == peer.ErrInvalidAddr {
				// This multiaddr didn't contain a peer id, assume it's for this peer.
				// Handshake will fail later if it's not.
				return false
			} else if err != nil {
				// This multiaddr is invalid, drop it.
				return true
			}
			return id != expectedPeerID
		}
		resolved = slices.DeleteFunc(resolved, removeMismatchPeerID)
	}

	return resolved, nil
}
// ResolveDNSComponent implements MultiaddrDNSResolver. It resolves the first
// /{dns,dns4,dns6} component of maddr and returns at most outputLimit of the
// resulting addresses.
func (r ResolverFromMaDNS) ResolveDNSComponent(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) {
	resolvedAddrs, err := r.Resolve(ctx, maddr)
	if err != nil {
		return nil, err
	}
	// Cap the result set at the caller's output limit.
	if len(resolvedAddrs) > outputLimit {
		resolvedAddrs = resolvedAddrs[:outputLimit]
	}
	return resolvedAddrs, nil
}
+135 -79
View File
@@ -16,14 +16,15 @@ import (
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
)
// The maximum number of address resolution steps we'll perform for a single
// peer (for all addresses).
const maxAddressResolution = 32
// The maximum number of addresses we'll return when resolving all of a peer's
// addresses
const maximumResolvedAddresses = 100
const maximumDNSADDRRecursion = 4
// Diagram of dial sync:
//
@@ -302,10 +303,7 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Mul
}
// Resolve dns or dnsaddrs
resolved, err := s.resolveAddrs(ctx, peer.AddrInfo{ID: p, Addrs: peerAddrs})
if err != nil {
return nil, nil, err
}
resolved := s.resolveAddrs(ctx, peer.AddrInfo{ID: p, Addrs: peerAddrs})
goodAddrs = ma.Unique(resolved)
goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs)
@@ -322,84 +320,142 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Mul
return goodAddrs, addrErrs, nil
}
func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) {
p2paddr, err := ma.NewMultiaddr("/" + ma.ProtocolWithCode(ma.P_P2P).Name + "/" + pi.ID.String())
if err != nil {
return nil, err
func startsWithDNSComponent(m ma.Multiaddr) bool {
if m == nil {
return false
}
var resolveSteps int
// Recursively resolve all addrs.
//
// While the toResolve list is non-empty:
// * Pop an address off.
// * If the address is fully resolved, add it to the resolved list.
// * Otherwise, resolve it and add the results to the "to resolve" list.
toResolve := append([]ma.Multiaddr{}, pi.Addrs...)
resolved := make([]ma.Multiaddr, 0, len(pi.Addrs))
for len(toResolve) > 0 {
// pop the last addr off.
addr := toResolve[len(toResolve)-1]
toResolve = toResolve[:len(toResolve)-1]
// if it's resolved, add it to the resolved list.
if !madns.Matches(addr) {
resolved = append(resolved, addr)
continue
startsWithDNS := false
// Using ForEach to avoid allocating
ma.ForEach(m, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_DNS, ma.P_DNS4, ma.P_DNS6:
startsWithDNS = true
}
resolveSteps++
return false
})
return startsWithDNS
}
// We've resolved too many addresses. We can keep all the fully
// resolved addresses but we'll need to skip the rest.
if resolveSteps >= maxAddressResolution {
log.Warnf(
"peer %s asked us to resolve too many addresses: %s/%s",
pi.ID,
resolveSteps,
maxAddressResolution,
)
continue
}
tpt := s.TransportForDialing(addr)
resolver, ok := tpt.(transport.Resolver)
if ok {
resolvedAddrs, err := resolver.Resolve(ctx, addr)
if err != nil {
log.Warnf("Failed to resolve multiaddr %s by transport %v: %v", addr, tpt, err)
continue
}
var added bool
for _, a := range resolvedAddrs {
if !addr.Equal(a) {
toResolve = append(toResolve, a)
added = true
}
}
if added {
continue
}
}
// otherwise, resolve it
reqaddr := addr.Encapsulate(p2paddr)
resaddrs, err := s.maResolver.Resolve(ctx, reqaddr)
if err != nil {
log.Infof("error resolving %s: %s", reqaddr, err)
}
// add the results to the toResolve list.
for _, res := range resaddrs {
pi, err := peer.AddrInfoFromP2pAddr(res)
if err != nil {
log.Infof("error parsing %s: %s", res, err)
}
toResolve = append(toResolve, pi.Addrs...)
// stripP2PComponent removes the trailing /p2p/<peer-id> component from every
// address that ends in one, leaving other addresses untouched.
// NOTE: mutates addrs in place and returns the same slice.
func stripP2PComponent(addrs []ma.Multiaddr) []ma.Multiaddr {
	for i, addr := range addrs {
		// Only strip when the last component is a valid /p2p component.
		if id, _ := peer.IDFromP2PAddr(addr); id != "" {
			addrs[i], _ = ma.SplitLast(addr)
		}
	}
	return addrs
}
return resolved, nil
// resolver is one step in an address-resolution chain: canResolve reports
// whether this step applies to a given address, and resolve maps the address
// to zero or more addresses, returning at most outputLimit of them.
type resolver struct {
	canResolve func(ma.Multiaddr) bool
	resolve    func(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error)
}

// resolveErr pairs an address with the error encountered while resolving it.
type resolveErr struct {
	addr ma.Multiaddr
	err  error
}
// chainResolvers runs each resolver, in order, over the whole address list.
// Addresses a resolver can't handle pass through unchanged; handled addresses
// are replaced by their resolution results. The output of each pass is capped
// at outputLimit. Per-address failures are collected into the returned
// resolveErr slice instead of aborting the chain.
func chainResolvers(ctx context.Context, addrs []ma.Multiaddr, outputLimit int, resolvers []resolver) ([]ma.Multiaddr, []resolveErr) {
	nextAddrs := make([]ma.Multiaddr, 0, len(addrs))
	errs := make([]resolveErr, 0)
	for _, r := range resolvers {
		for _, a := range addrs {
			if !r.canResolve(a) {
				nextAddrs = append(nextAddrs, a)
				continue
			}
			// Once the limit is reached, truncate and drop the remaining
			// addresses for this pass.
			if len(nextAddrs) >= outputLimit {
				nextAddrs = nextAddrs[:outputLimit]
				break
			}
			next, err := r.resolve(ctx, a, outputLimit-len(nextAddrs))
			if err != nil {
				errs = append(errs, resolveErr{addr: a, err: err})
				continue
			}
			nextAddrs = append(nextAddrs, next...)
		}
		// Swap the two slices so the next pass reuses their backing arrays.
		// NOTE(review): after the first swap, appends go into the caller's
		// input slice's backing array — confirm callers don't reuse addrs
		// after calling this function.
		addrs, nextAddrs = nextAddrs, addrs
		nextAddrs = nextAddrs[:0]
	}
	return addrs, errs
}
// resolveAddrs resolves DNS/DNSADDR components in the given peer's addresses.
// We want to resolve the DNS components to IP addresses because we want the
// swarm to manage ranking and dialing multiple connections, and a single DNS
// address can resolve to multiple IP addresses.
func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) []ma.Multiaddr {
	// Step 1: recursively resolve /dnsaddr components (bounded depth),
	// filtering out results whose /p2p component doesn't match pi.ID.
	dnsAddrResolver := resolver{
		canResolve: startsWithDNSADDR,
		resolve: func(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) {
			return s.multiaddrResolver.ResolveDNSAddr(ctx, pi.ID, maddr, maximumDNSADDRRecursion, outputLimit)
		},
	}

	var skipped []ma.Multiaddr
	// Step 2: set aside addresses whose transport implements SkipResolver and
	// asks to skip resolution; they are appended back, unresolved, at the end.
	skipResolver := resolver{
		canResolve: func(addr ma.Multiaddr) bool {
			tpt := s.TransportForDialing(addr)
			if tpt == nil {
				return false
			}
			_, ok := tpt.(transport.SkipResolver)
			return ok
		},
		resolve: func(ctx context.Context, addr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) {
			tpt := s.TransportForDialing(addr)
			resolver, ok := tpt.(transport.SkipResolver)
			if !ok {
				return []ma.Multiaddr{addr}, nil
			}
			if resolver.SkipResolve(ctx, addr) {
				skipped = append(skipped, addr)
				return nil, nil
			}
			return []ma.Multiaddr{addr}, nil
		},
	}

	// Step 3: let transports that implement transport.Resolver resolve or
	// rewrite their own addresses.
	tptResolver := resolver{
		canResolve: func(addr ma.Multiaddr) bool {
			tpt := s.TransportForDialing(addr)
			if tpt == nil {
				return false
			}
			_, ok := tpt.(transport.Resolver)
			return ok
		},
		resolve: func(ctx context.Context, addr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) {
			tpt := s.TransportForDialing(addr)
			resolver, ok := tpt.(transport.Resolver)
			if !ok {
				return []ma.Multiaddr{addr}, nil
			}
			addrs, err := resolver.Resolve(ctx, addr)
			if err != nil {
				return nil, err
			}
			if len(addrs) > outputLimit {
				addrs = addrs[:outputLimit]
			}
			return addrs, nil
		},
	}

	// Step 4: resolve plain /{dns,dns4,dns6} components.
	dnsResolver := resolver{
		canResolve: startsWithDNSComponent,
		resolve:    s.multiaddrResolver.ResolveDNSComponent,
	}
	addrs, errs := chainResolvers(ctx, pi.Addrs, maximumResolvedAddresses, []resolver{dnsAddrResolver, skipResolver, tptResolver, dnsResolver})
	// Resolution failures are logged, not fatal: the remaining addresses may
	// still be dialable.
	for _, err := range errs {
		log.Warnf("Failed to resolve addr %s: %v", err.addr, err.err)
	}
	// Add skipped addresses back to the resolved addresses
	addrs = append(addrs, skipped...)
	return stripP2PComponent(addrs)
}
func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan transport.DialUpdate) error {
+18 -3
View File
@@ -55,7 +55,7 @@ func TestAddrsForDial(t *testing.T) {
tpt, err := websocket.New(nil, &network.NullResourceManager{})
require.NoError(t, err)
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver))
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver}))
require.NoError(t, err)
defer s.Close()
err = s.AddTransport(tpt)
@@ -96,7 +96,7 @@ func TestDedupAddrsForDial(t *testing.T) {
ps.AddPrivKey(id, priv)
t.Cleanup(func() { ps.Close() })
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver))
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver}))
require.NoError(t, err)
defer s.Close()
@@ -127,7 +127,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
ps.AddPubKey(id, priv.GetPublic())
ps.AddPrivKey(id, priv)
t.Cleanup(func() { ps.Close() })
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver))
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver}))
require.NoError(t, err)
t.Cleanup(func() {
s.Close()
@@ -398,3 +398,18 @@ func TestBlackHoledAddrBlocked(t *testing.T) {
}
require.ErrorIs(t, err, ErrDialRefusedBlackHole)
}
// TestSkipDialingManyDNS checks that resolveAddrs bounds the number of
// addresses produced when a single multiaddr nests several DNS components
// behind p2p-circuit hops.
func TestSkipDialingManyDNS(t *testing.T) {
	resolver, err := madns.NewResolver()
	require.NoError(t, err)
	s := newTestSwarmWithResolver(t, resolver)
	defer s.Close()
	id := test.RandPeerIDFatal(t)

	addr := ma.StringCast("/dns/example.com/udp/1234/p2p-circuit/dns/example.com/p2p-circuit/dns/example.com")

	// Fix: resolveAddrs returns no error; the original re-asserted the stale
	// err from NewResolver here, which was a dead (always-passing) check.
	resolved := s.resolveAddrs(context.Background(), peer.AddrInfo{ID: id, Addrs: []ma.Multiaddr{addr}})
	require.Less(t, len(resolved), 3)
}
+5
View File
@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
@@ -30,6 +31,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport {
if isRelayAddr(a) {
return s.transports.m[ma.P_CIRCUIT]
}
if id, _ := peer.IDFromP2PAddr(a); id != "" {
// This addr has a p2p component. Drop it so we can check transport.
a, _ = ma.SplitLast(a)
}
for _, t := range s.transports.m {
if t.CanDial(a) {
return t
@@ -46,8 +46,20 @@ func AddTransport(h host.Host, upgrader transport.Upgrader) error {
// Transport interface
var _ transport.Transport = (*Client)(nil)
// p2p-circuit implements the SkipResolver interface so that the underlying
// transport can do the address resolution later. If you wrap this transport,
// make sure you also implement SkipResolver as well.
var _ transport.SkipResolver = (*Client)(nil)
var _ io.Closer = (*Client)(nil)

// SkipResolve returns true since we always defer to the inner transport for
// the actual connection. By skipping resolution here, we let the inner
// transport decide how to resolve the multiaddr. Both arguments are ignored.
func (c *Client) SkipResolve(ctx context.Context, maddr ma.Multiaddr) bool {
	return true
}
func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
connScope, err := c.host.Network().ResourceManager().OpenConnection(network.DirOutbound, false, a)
+40 -10
View File
@@ -9,6 +9,7 @@ import (
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
@@ -224,7 +225,13 @@ func (r *Relay) handleReserve(s network.Stream) pbv2.Status {
// Delivery of the reservation might fail for a number of reasons.
// For example, the stream might be reset or the connection might be closed before the reservation is received.
// In that case, the reservation will just be garbage collected later.
if err := r.writeResponse(s, pbv2.Status_OK, r.makeReservationMsg(p, expire), r.makeLimitMsg(p)); err != nil {
rsvp := makeReservationMsg(
r.host.Peerstore().PrivKey(r.host.ID()),
r.host.ID(),
r.host.Addrs(),
p,
expire)
if err := r.writeResponse(s, pbv2.Status_OK, rsvp, r.makeLimitMsg(p)); err != nil {
log.Debugf("error writing reservation response; retracting reservation for %s", p)
s.Reset()
return pbv2.Status_CONNECTION_FAILED
@@ -567,31 +574,54 @@ func (r *Relay) writeResponse(s network.Stream, status pbv2.Status, rsvp *pbv2.R
return wr.WriteMsg(&msg)
}
func (r *Relay) makeReservationMsg(p peer.ID, expire time.Time) *pbv2.Reservation {
func makeReservationMsg(
signingKey crypto.PrivKey,
selfID peer.ID,
selfAddrs []ma.Multiaddr,
p peer.ID,
expire time.Time,
) *pbv2.Reservation {
expireUnix := uint64(expire.Unix())
rsvp := &pbv2.Reservation{Expire: &expireUnix}
selfP2PAddr, err := ma.NewComponent("p2p", selfID.String())
if err != nil {
log.Errorf("error creating p2p component: %s", err)
return rsvp
}
var addrBytes [][]byte
for _, addr := range r.host.Addrs() {
for _, addr := range selfAddrs {
if !manet.IsPublicAddr(addr) {
continue
}
addr = addr.Encapsulate(r.selfAddr)
id, _ := peer.IDFromP2PAddr(addr)
switch {
case id == "":
// No ID, we'll add one to the address
addr = addr.Encapsulate(selfP2PAddr)
case id == selfID:
// This address already has our ID in it.
// Do nothing
case id != selfID:
// This address has a different ID in it. Skip it.
log.Warnf("skipping address %s: contains an unexpected ID", addr)
continue
}
addrBytes = append(addrBytes, addr.Bytes())
}
rsvp := &pbv2.Reservation{
Expire: &expireUnix,
Addrs: addrBytes,
}
rsvp.Addrs = addrBytes
voucher := &proto.ReservationVoucher{
Relay: r.host.ID(),
Relay: selfID,
Peer: p,
Expiration: expire,
}
envelope, err := record.Seal(voucher, r.host.Peerstore().PrivKey(r.host.ID()))
envelope, err := record.Seal(voucher, signingKey)
if err != nil {
log.Errorf("error sealing voucher for %s: %s", p, err)
return rsvp
@@ -0,0 +1,53 @@
package relay
import (
"crypto/rand"
"testing"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
ma "github.com/multiformats/go-multiaddr"
)
// genKeyAndID generates a fresh ed25519 private key and the peer ID derived
// from it, failing the test on any error.
func genKeyAndID(t *testing.T) (crypto.PrivKey, peer.ID) {
	t.Helper()
	priv, _, err := crypto.GenerateEd25519Key(rand.Reader)
	require.NoError(t, err)
	pid, err := peer.IDFromPrivateKey(priv)
	require.NoError(t, err)
	return priv, pid
}
// TestMakeReservationWithP2PAddrs verifies that makeReservationMsg sanitizes
// the addresses it is given: an address with no /p2p component gets ours
// appended, an address already carrying our ID is kept as-is, and an address
// carrying a foreign ID is dropped.
func TestMakeReservationWithP2PAddrs(t *testing.T) {
	selfKey, selfID := genKeyAndID(t)
	_, otherID := genKeyAndID(t)
	_, reserverID := genKeyAndID(t)

	inputs := []ma.Multiaddr{
		ma.StringCast("/ip4/1.2.3.4/tcp/1234"),                         // no p2p component
		ma.StringCast("/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String()),  // already carries our ID
		ma.StringCast("/ip4/1.2.3.4/tcp/1236/p2p/" + otherID.String()), // foreign ID; should be dropped
	}

	rsvp := makeReservationMsg(selfKey, selfID, inputs, reserverID, time.Now().Add(time.Minute))
	require.NotNil(t, rsvp)

	want := []string{
		"/ip4/1.2.3.4/tcp/1234/p2p/" + selfID.String(),
		"/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String(),
	}
	got := make([]string, 0, len(rsvp.GetAddrs()))
	for _, raw := range rsvp.GetAddrs() {
		addr, err := ma.NewMultiaddrBytes(raw)
		require.NoError(t, err)
		got = append(got, addr.String())
	}
	require.Equal(t, want, got)
}
+1 -1
View File
@@ -133,7 +133,7 @@ func (t *WebsocketTransport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]m
if parsed.sni == nil {
var err error
// We don't have an sni component, we'll use dns/dnsaddr
// We don't have an sni component, we'll use dns
ma.ForEach(parsed.restMultiaddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_DNS, ma.P_DNS4, ma.P_DNS6: