Align mdns and TCPMux for active TCP connectivity

This commit is contained in:
Jo Turk
2026-02-19 16:52:55 +02:00
parent da652e14c2
commit 74a026ba4c
3 changed files with 155 additions and 1 deletions
+79 -1
View File
@@ -487,6 +487,8 @@ func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) {
return nil, fmt.Errorf("error getting local interfaces: %w", err)
}
mDNSLocalAddress := mDNSLocalAddressFromTCPMux(agent.tcpMux, agent.networkTypes)
// Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it.
if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS(
@@ -494,6 +496,7 @@ func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) {
agent.networkTypes,
localIfcs,
agent.includeLoopback,
mDNSLocalAddress,
agent.mDNSMode,
agent.mDNSName,
agent.log,
@@ -558,6 +561,67 @@ func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) {
return agent, nil
}
func mDNSLocalAddressFromTCPMux(tcpMux TCPMux, networkTypes []NetworkType) net.IP {
if tcpMux == nil || !allNetworkTypesTCP(networkTypes) {
return nil
}
tcpAddr, ok := localTCPAddrFromMux(tcpMux)
if !ok {
return nil
}
localAddr, ok := mDNSLocalAddressFromIP(tcpAddr.IP)
if !ok {
return nil
}
return localAddr
}
func allNetworkTypesTCP(networkTypes []NetworkType) bool {
if len(networkTypes) == 0 {
return false
}
for _, networkType := range networkTypes {
if !networkType.IsTCP() {
return false
}
}
return true
}
func localTCPAddrFromMux(tcpMux TCPMux) (*net.TCPAddr, bool) {
addrProvider, ok := tcpMux.(interface{ LocalAddr() net.Addr })
if !ok {
return nil, false
}
tcpAddr, ok := addrProvider.LocalAddr().(*net.TCPAddr)
if !ok || tcpAddr.IP == nil || tcpAddr.IP.IsUnspecified() {
return nil, false
}
return tcpAddr, true
}
func mDNSLocalAddressFromIP(ip net.IP) (net.IP, bool) {
parsed, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, false
}
parsed = parsed.Unmap()
if parsed.Is6() && (parsed.IsLinkLocalUnicast() || parsed.IsLinkLocalMulticast()) {
// mdns.Config.LocalAddress has no zone support for link-local IPv6.
return nil, false
}
return parsed.AsSlice(), true
}
func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error {
a.muHaveStarted.Lock()
defer a.muHaveStarted.Unlock()
@@ -987,10 +1051,12 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
continue
}
dialIP := remoteDialIPForLocalInterface(ip, localIPs[i].addr)
conn := newActiveTCPConn(
a.loop,
net.JoinHostPort(localIPs[i].addr.String(), "0"),
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port
netip.AddrPortFrom(dialIP, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port
a.log,
)
@@ -1025,6 +1091,18 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
}
}
func remoteDialIPForLocalInterface(remoteIP, localIP netip.Addr) netip.Addr {
if remoteIP.Is6() &&
remoteIP.Zone() == "" &&
(remoteIP.IsLinkLocalUnicast() || remoteIP.IsLinkLocalMulticast()) {
if zone := localIP.Zone(); zone != "" {
return remoteIP.WithZone(zone)
}
}
return remoteIP
}
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run).
func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
set := a.remoteCandidates[cand.NetworkType()]
+73
View File
@@ -9,6 +9,7 @@ package ice
import (
"context"
"net"
"net/netip"
"strconv"
"sync"
"testing"
@@ -2594,6 +2595,40 @@ func TestAgentUpdateOptions(t *testing.T) {
})
}
func TestRemoteDialIPForLocalInterface(t *testing.T) {
t.Run("adds local zone for zone-less link-local IPv6", func(t *testing.T) {
remote := netip.MustParseAddr("fe80::1234")
local := netip.MustParseAddr("fe80::1%eth0")
got := remoteDialIPForLocalInterface(remote, local)
require.Equal(t, netip.MustParseAddr("fe80::1234%eth0"), got)
})
t.Run("keeps existing remote zone", func(t *testing.T) {
remote := netip.MustParseAddr("fe80::1234%eth9")
local := netip.MustParseAddr("fe80::1%eth0")
got := remoteDialIPForLocalInterface(remote, local)
require.Equal(t, remote, got)
})
t.Run("does not modify global IPv6", func(t *testing.T) {
remote := netip.MustParseAddr("2001:db8::1234")
local := netip.MustParseAddr("fe80::1%eth0")
got := remoteDialIPForLocalInterface(remote, local)
require.Equal(t, remote, got)
})
t.Run("does not modify zone-less link-local when local has no zone", func(t *testing.T) {
remote := netip.MustParseAddr("fe80::1234")
local := netip.MustParseAddr("2001:db8::1")
got := remoteDialIPForLocalInterface(remote, local)
require.Equal(t, remote, got)
})
}
func TestMDNSQueryTimeout(t *testing.T) {
t.Run("falls back to default when unset", func(t *testing.T) {
agent := &Agent{}
@@ -2605,3 +2640,41 @@ func TestMDNSQueryTimeout(t *testing.T) {
require.Equal(t, 3*time.Second, agent.mDNSQueryTimeout())
})
}
type localAddrTCPMux struct {
addr net.Addr
}
func (m *localAddrTCPMux) Close() error { return nil }
func (m *localAddrTCPMux) GetConnByUfrag(string, bool, net.IP) (net.PacketConn, error) {
return nil, nil //nolint:nilnil
}
func (m *localAddrTCPMux) RemoveConnByUfrag(string) {}
func (m *localAddrTCPMux) LocalAddr() net.Addr { return m.addr }
func TestMDNSLocalAddressFromTCPMux(t *testing.T) {
t.Run("nil for mixed network types", func(t *testing.T) {
mux := &localAddrTCPMux{addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1")}}
require.Nil(t, mDNSLocalAddressFromTCPMux(mux, []NetworkType{NetworkTypeTCP6, NetworkTypeUDP6}))
})
t.Run("nil for link-local IPv6 listener", func(t *testing.T) {
mux := &localAddrTCPMux{addr: &net.TCPAddr{IP: net.ParseIP("fe80::1"), Zone: "eth0"}}
require.Nil(t, mDNSLocalAddressFromTCPMux(mux, []NetworkType{NetworkTypeTCP6}))
})
t.Run("uses listener IP for TCP-only global IPv6", func(t *testing.T) {
want := net.ParseIP("2001:db8::1")
mux := &localAddrTCPMux{addr: &net.TCPAddr{IP: want, Zone: "wg0"}}
got := mDNSLocalAddressFromTCPMux(mux, []NetworkType{NetworkTypeTCP6})
require.Equal(t, want.To16(), got)
})
t.Run("nil for TCP mux without LocalAddr", func(t *testing.T) {
require.Nil(t, mDNSLocalAddressFromTCPMux(&stubTCPMux{}, []NetworkType{NetworkTypeTCP4}))
})
}
+3
View File
@@ -44,6 +44,7 @@ func createMulticastDNS(
networkTypes []NetworkType,
interfaces []*transport.Interface,
includeLoopback bool,
localAddress net.IP,
mDNSMode MulticastDNSMode,
mDNSName string,
log logging.LeveledLogger,
@@ -125,6 +126,7 @@ func createMulticastDNS(
conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{
Interfaces: ifcs,
IncludeLoopback: includeLoopback,
LocalAddress: localAddress,
LoggerFactory: loggerFactory,
})
@@ -133,6 +135,7 @@ func createMulticastDNS(
conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{
Interfaces: ifcs,
IncludeLoopback: includeLoopback,
LocalAddress: localAddress,
LocalNames: []string{mDNSName},
LoggerFactory: loggerFactory,
})