Honor interface/IP filters for srflx and relay STUN requests (#899)

## Summary

This PR fixes a bug where server-reflexive (srflx) STUN requests and
TURN relay UDP sockets could be sent from interfaces that were filtered
out by `InterfaceFilter` / `IPFilter`.

Previously:

- srflx gathering used wildcard local binds (`0.0.0.0` / `::`)
- relay TURN/UDP gathering used wildcard local bind (`0.0.0.0:0`)

That allowed kernel routing to choose source interfaces outside
configured filters, which could expose public IPs from non-allowed
adapters.

With this change, when interface/IP filters are configured:

- srflx gathering binds STUN sockets to filtered local interface
addresses only
- relay (TURN/UDP) gathering binds local sockets to filtered local
interface addresses only

## Root cause

- Wildcard local binds do not constrain source interface selection.
- `localInterfaces(...)` filtering was applied to host gathering but not
consistently to srflx and relay local bind paths.

## What changed

- `gather.go`
  - Refactored `gatherCandidatesSrflx` to support two modes:
    - **Filtered mode** (`interfaceFilter` or `ipFilter` set):
      - Resolve filtered local interfaces via `localInterfaces(...)`.
      - Bind srflx sockets per matching local address/family.
    - **Default mode** (no filters):
      - Preserve existing wildcard-bind behavior.
- Updated `gatherCandidatesRelay` (TURN/UDP path) with the same two-mode
behavior:
- **Filtered mode** binds relay UDP sockets per filtered local IPv4
address.
    - **Default mode** preserves legacy `0.0.0.0:0` bind behavior.

- `gather_test.go`
- Added `srflxListenCaptureNet` and
`TestGatherCandidatesSrflxRespectsInterfaceFilter`.
- Added `relayListenCaptureNet` and
`TestGatherCandidatesRelayRespectsInterfaceFilter`.
- Kept existing relay behavior checks via
`TestGatherCandidatesRelayCallsAddRelayCandidates`.

## Backward compatibility

- No behavior change for users who do not configure `InterfaceFilter` /
`IPFilter`.
- When filters are configured, srflx/relay local bind behavior now
consistently honors those filters.

## Reference issue
Fixes #727
This commit is contained in:
sirzooro
2026-03-15 07:04:56 +01:00
committed by GitHub
parent 398ac7c807
commit 8f6a882190
3 changed files with 654 additions and 259 deletions
+3 -3
View File
@@ -513,7 +513,7 @@ func TestMultipleConfigOptions(t *testing.T) {
func TestWithInterfaceFilter(t *testing.T) {
t.Run("sets interface filter", func(t *testing.T) {
filter := func(interfaceName string) bool {
return interfaceName == "eth0"
return interfaceName == "eth0" // nolint:goconst
}
agent, err := NewAgentWithOptions(WithInterfaceFilter(filter))
@@ -521,7 +521,7 @@ func TestWithInterfaceFilter(t *testing.T) {
defer agent.Close() //nolint:errcheck
assert.NotNil(t, agent.interfaceFilter)
assert.True(t, agent.interfaceFilter("eth0"))
assert.True(t, agent.interfaceFilter("eth0")) // nolint:goconst
assert.False(t, agent.interfaceFilter("wlan0"))
})
@@ -549,7 +549,7 @@ func TestWithInterfaceFilter(t *testing.T) {
assert.NotNil(t, agent.interfaceFilter)
assert.True(t, agent.interfaceFilter("lo"))
assert.False(t, agent.interfaceFilter("eth0"))
assert.False(t, agent.interfaceFilter("eth0")) // nolint:goconst
})
}
+314 -256
View File
@@ -736,89 +736,123 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
var wg sync.WaitGroup
defer wg.Wait()
useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil
localAddrs := []ifaceAddr{}
if useFilteredLocalAddrs {
_, addrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil {
a.log.Warnf("Failed to iterate local interfaces, srflx candidates will not be gathered %s", err)
return
}
localAddrs = addrs
}
gatherForURL := func(url stun.URI, network string, listenAddr *net.UDPAddr) {
defer wg.Done()
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
return
}
if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return
}
conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
listenAddr,
)
if err != nil {
closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)
return
}
// If the agent closes midway through the connection
// we end it early to prevent close delay.
cancelCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
go func() {
select {
case <-cancelCtx.Done():
return
case <-a.loop.Done():
_ = conn.Close()
}
}()
xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
if err != nil {
closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)
return
}
ip := xorAddr.IP
port := xorAddr.Port
lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
srflxConfig := CandidateServerReflexiveConfig{
Network: network,
Address: ip.String(),
Port: port,
Component: ComponentRTP,
RelAddr: lAddr.IP.String(),
RelPort: lAddr.Port,
}
c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return
}
if err := a.addCandidate(ctx, c, conn); err != nil {
if closeErr := c.close(); closeErr != nil {
a.log.Warnf("Failed to close candidate: %v", closeErr)
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
}
}
for _, networkType := range networkTypes {
if networkType.IsTCP() {
continue
}
for i := range urls {
wg.Add(1)
go func(url stun.URI, network string) {
defer wg.Done()
if !useFilteredLocalAddrs {
wg.Add(1)
go gatherForURL(*urls[i], networkType.String(), &net.UDPAddr{IP: nil, Port: 0})
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
continue
}
return
for j := range localAddrs {
if networkType.IsIPv4() && localAddrs[j].addr.Is6() {
continue
}
if networkType.IsIPv6() && !localAddrs[j].addr.Is6() {
continue
}
if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return
}
conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
&net.UDPAddr{IP: nil, Port: 0},
wg.Add(1)
go gatherForURL(
*urls[i],
networkType.String(),
&net.UDPAddr{IP: localAddrs[j].addr.AsSlice(), Zone: localAddrs[j].addr.Zone(), Port: 0},
)
if err != nil {
closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)
return
}
// If the agent closes midway through the connection
// we end it early to prevent close delay.
cancelCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
go func() {
select {
case <-cancelCtx.Done():
return
case <-a.loop.Done():
_ = conn.Close()
}
}()
xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
if err != nil {
closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)
return
}
ip := xorAddr.IP
port := xorAddr.Port
lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
srflxConfig := CandidateServerReflexiveConfig{
Network: network,
Address: ip.String(),
Port: port,
Component: ComponentRTP,
RelAddr: lAddr.IP.String(),
RelPort: lAddr.Port,
}
c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return
}
if err := a.addCandidate(ctx, c, conn); err != nil {
if closeErr := c.close(); closeErr != nil {
a.log.Warnf("Failed to close candidate: %v", closeErr)
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
}
}(*urls[i], networkType.String())
}
}
}
}
@@ -831,6 +865,18 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
network := NetworkTypeUDP4.String()
_, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback)
useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil
localAddrs := []ifaceAddr{}
if useFilteredLocalAddrs {
for i := range ifaces {
if ifaces[i].addr.Is6() {
continue
}
localAddrs = append(localAddrs, ifaces[i])
}
}
for _, url := range urls {
switch {
case url.Scheme != stun.SchemeTypeTURN && url.Scheme != stun.SchemeTypeTURNS:
@@ -845,204 +891,216 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
return
}
wg.Add(1)
go func(url stun.URI) {
defer wg.Done()
turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port)
var (
locConn net.PacketConn
err error
relAddr string
relPort int
relayProtocol string
)
switch {
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN:
if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err)
return
}
relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert
relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
relayProtocol = udp
case a.proxyDialer != nil && url.Proto == stun.ProtoTypeTCP &&
(url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS):
conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
if url.Scheme == stun.SchemeTypeTURN {
relayProtocol = tcp
} else if url.Scheme == stun.SchemeTypeTURNS {
relayProtocol = "tls"
}
locConn = turn.NewSTUNConn(conn)
case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN:
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr)
return
}
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
relayProtocol = tcp
locConn = turn.NewSTUNConn(conn)
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS:
udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr)
return
}
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
if dialErr != nil {
a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr)
return
}
conn, connectErr := dtls.ClientWithOptions(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(),
dtls.WithServerName(url.Host),
dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec
dtls.WithLoggerFactory(a.loggerFactory),
gatherForURL := func(url stun.URI, localBindAddr string) {
wg.Add(1)
go func(url stun.URI, localBindAddr string) {
defer wg.Done()
turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port)
var (
locConn net.PacketConn
err error
relAddr string
relPort int
relayProtocol string
)
if connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
switch {
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN:
if locConn, err = a.net.ListenPacket(network, localBindAddr); err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err)
return
}
relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert
relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
relayProtocol = udp
case a.proxyDialer != nil && url.Proto == stun.ProtoTypeTCP &&
(url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS):
conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
if url.Scheme == stun.SchemeTypeTURN {
relayProtocol = tcp
} else if url.Scheme == stun.SchemeTypeTURNS {
relayProtocol = "tls"
}
locConn = turn.NewSTUNConn(conn)
case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN:
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr)
return
}
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
relayProtocol = tcp
locConn = turn.NewSTUNConn(conn)
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS:
udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr)
return
}
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
if dialErr != nil {
a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr)
return
}
conn, connectErr := dtls.ClientWithOptions(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(),
dtls.WithServerName(url.Host),
dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec
dtls.WithLoggerFactory(a.loggerFactory),
)
if connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return
}
if connectErr = conn.HandshakeContext(ctx); connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return
}
relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
relayProtocol = relayProtocolDTLS
locConn = &fakenet.PacketConn{Conn: conn}
case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS:
tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if resolvErr != nil {
a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr)
return
}
tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if dialErr != nil {
a.log.Warnf("Failed to connect to relay: %v", dialErr)
return
}
conn := tls.Client(tcpConn, &tls.Config{
ServerName: url.Host,
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
})
if hsErr := conn.HandshakeContext(ctx); hsErr != nil {
if closeErr := tcpConn.Close(); closeErr != nil {
a.log.Errorf("Failed to close relay connection: %v", closeErr)
}
a.log.Warnf("Failed to connect to relay: %v", hsErr)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
relayProtocol = relayProtocolTLS
locConn = turn.NewSTUNConn(conn)
default:
a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url)
return
}
if connectErr = conn.HandshakeContext(ctx); connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return
factory := a.turnClientFactory
if factory == nil {
factory = defaultTurnClient
}
relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
relayProtocol = relayProtocolDTLS
locConn = &fakenet.PacketConn{Conn: conn}
case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS:
tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if resolvErr != nil {
a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr)
return
}
tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if dialErr != nil {
a.log.Warnf("Failed to connect to relay: %v", dialErr)
return
}
conn := tls.Client(tcpConn, &tls.Config{
ServerName: url.Host,
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
client, err := factory(&turn.ClientConfig{
TURNServerAddr: turnServerAddr,
Conn: locConn,
Username: url.Username,
Password: url.Password,
LoggerFactory: a.loggerFactory,
Net: a.net,
})
if hsErr := conn.HandshakeContext(ctx); hsErr != nil {
if closeErr := tcpConn.Close(); closeErr != nil {
a.log.Errorf("Failed to close relay connection: %v", closeErr)
}
a.log.Warnf("Failed to connect to relay: %v", hsErr)
if err != nil {
closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err)
return
}
relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
relayProtocol = relayProtocolTLS
locConn = turn.NewSTUNConn(conn)
default:
a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url)
return
}
factory := a.turnClientFactory
if factory == nil {
factory = defaultTurnClient
}
client, err := factory(&turn.ClientConfig{
TURNServerAddr: turnServerAddr,
Conn: locConn,
Username: url.Username,
Password: url.Password,
LoggerFactory: a.loggerFactory,
Net: a.net,
})
if err != nil {
closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err)
return
}
if err = client.Listen(); err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err)
return
}
relayConn, err := client.Allocate()
if err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err)
return
}
rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
if shouldFilterLocationTracked(rAddr.IP) {
a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP)
return
}
a.addRelayCandidates(ctx, relayEndpoint{
network: network,
address: rAddr.IP,
port: rAddr.Port,
relAddr: relAddr,
relPort: relPort,
iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)),
protocol: relayProtocol,
conn: relayConn,
onClose: func() error {
if err = client.Listen(); err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err)
return locConn.Close()
},
closeConn: func() {
if relayConErr := relayConn.Close(); relayConErr != nil {
a.log.Warnf("Failed to close relay %v", relayConErr)
}
},
})
}(*url)
return
}
relayConn, err := client.Allocate()
if err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err)
return
}
rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
if shouldFilterLocationTracked(rAddr.IP) {
a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP)
return
}
a.addRelayCandidates(ctx, relayEndpoint{
network: network,
address: rAddr.IP,
port: rAddr.Port,
relAddr: relAddr,
relPort: relPort,
iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)),
protocol: relayProtocol,
conn: relayConn,
onClose: func() error {
client.Close()
return locConn.Close()
},
closeConn: func() {
if relayConErr := relayConn.Close(); relayConErr != nil {
a.log.Warnf("Failed to close relay %v", relayConErr)
}
},
})
}(url, localBindAddr)
}
if !useFilteredLocalAddrs {
gatherForURL(*url, "0.0.0.0:0")
continue
}
for i := range localAddrs {
gatherForURL(*url, net.JoinHostPort(localAddrs[i].addr.String(), "0"))
}
}
}
+337
View File
@@ -773,6 +773,128 @@ func (n *relayGatherNet) CreateListenConfig(*net.ListenConfig) transport.ListenC
return nil
}
type relayListenCaptureNet struct {
listenPacketAddresses []string
mu sync.Mutex
}
func newRelayListenCaptureNet() *relayListenCaptureNet {
return &relayListenCaptureNet{}
}
func (n *relayListenCaptureNet) ListenPacket(_ string, address string) (net.PacketConn, error) {
n.mu.Lock()
n.listenPacketAddresses = append(n.listenPacketAddresses, address)
n.mu.Unlock()
udpAddr, err := net.ResolveUDPAddr("udp4", address)
if err != nil {
return nil, err
}
return newStubPacketConn(udpAddr), nil
}
func (n *relayListenCaptureNet) ListenUDP(string, *net.UDPAddr) (transport.UDPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *relayListenCaptureNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) {
return nil, transport.ErrNotSupported
}
func (n *relayListenCaptureNet) Dial(string, string) (net.Conn, error) {
return nil, transport.ErrNotSupported
}
func (n *relayListenCaptureNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *relayListenCaptureNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *relayListenCaptureNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
return net.ResolveIPAddr(network, address)
}
func (n *relayListenCaptureNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr(network, address)
}
func (n *relayListenCaptureNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr(network, address)
}
func (n *relayListenCaptureNet) Interfaces() ([]*transport.Interface, error) {
iface0 := transport.NewInterface(net.Interface{
Index: 1,
MTU: 1500,
Name: "eth0",
Flags: net.FlagUp,
})
iface0.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)})
iface1 := transport.NewInterface(net.Interface{
Index: 2,
MTU: 1500,
Name: "wlan0",
Flags: net.FlagUp,
})
iface1.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(8, 32)})
return []*transport.Interface{iface0, iface1}, nil
}
func (n *relayListenCaptureNet) InterfaceByIndex(index int) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Index == index {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *relayListenCaptureNet) InterfaceByName(name string) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Name == name {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *relayListenCaptureNet) CreateDialer(*net.Dialer) transport.Dialer {
return nil
}
func (n *relayListenCaptureNet) CreateListenConfig(*net.ListenConfig) transport.ListenConfig {
return nil
}
func (n *relayListenCaptureNet) listenAddresses() []string {
n.mu.Lock()
defer n.mu.Unlock()
addrs := make([]string, len(n.listenPacketAddresses))
copy(addrs, n.listenPacketAddresses)
return addrs
}
type hostGatherNet struct {
addr *net.UDPAddr
}
@@ -875,6 +997,138 @@ func (n *hostGatherNet) CreateListenConfig(*net.ListenConfig) transport.ListenCo
return nil
}
type srflxListenCaptureNet struct {
listenCalls []*net.UDPAddr
mu sync.Mutex
}
func newSrflxListenCaptureNet() *srflxListenCaptureNet {
return &srflxListenCaptureNet{}
}
func (n *srflxListenCaptureNet) ListenPacket(string, string) (net.PacketConn, error) {
return newStubPacketConn(&net.UDPAddr{IP: net.IPv4zero, Port: 0}), nil
}
func (n *srflxListenCaptureNet) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
n.mu.Lock()
if laddr != nil {
n.listenCalls = append(n.listenCalls, &net.UDPAddr{
IP: append(net.IP{}, laddr.IP...),
Port: laddr.Port,
Zone: laddr.Zone,
})
} else {
n.listenCalls = append(n.listenCalls, nil)
}
n.mu.Unlock()
return net.ListenUDP(network, &net.UDPAddr{IP: net.IPv4zero, Port: 0}) //nolint:wrapcheck
}
func (n *srflxListenCaptureNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) {
return nil, transport.ErrNotSupported
}
func (n *srflxListenCaptureNet) Dial(string, string) (net.Conn, error) {
return nil, transport.ErrNotSupported
}
func (n *srflxListenCaptureNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *srflxListenCaptureNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *srflxListenCaptureNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
return net.ResolveIPAddr(network, address)
}
func (n *srflxListenCaptureNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr(network, address)
}
func (n *srflxListenCaptureNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr(network, address)
}
func (n *srflxListenCaptureNet) Interfaces() ([]*transport.Interface, error) {
iface0 := transport.NewInterface(net.Interface{
Index: 1,
MTU: 1500,
Name: "eth0",
Flags: net.FlagUp,
})
iface0.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)})
iface1 := transport.NewInterface(net.Interface{
Index: 2,
MTU: 1500,
Name: "wlan0",
Flags: net.FlagUp,
})
iface1.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(8, 32)})
return []*transport.Interface{iface0, iface1}, nil
}
func (n *srflxListenCaptureNet) InterfaceByIndex(index int) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Index == index {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *srflxListenCaptureNet) InterfaceByName(name string) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Name == name {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *srflxListenCaptureNet) CreateDialer(*net.Dialer) transport.Dialer {
return nil
}
func (n *srflxListenCaptureNet) CreateListenConfig(*net.ListenConfig) transport.ListenConfig {
return nil
}
func (n *srflxListenCaptureNet) listenCallIPs() []string {
n.mu.Lock()
defer n.mu.Unlock()
ipStrings := make([]string, 0, len(n.listenCalls))
for _, addr := range n.listenCalls {
if addr == nil || addr.IP == nil {
ipStrings = append(ipStrings, "")
continue
}
ipStrings = append(ipStrings, addr.IP.String())
}
return ipStrings
}
type errorPacketConn struct {
addr net.Addr
closed bool
@@ -1164,6 +1418,54 @@ func TestGatherCandidatesRelayUsesTurnNet(t *testing.T) {
}
}
func TestGatherCandidatesRelayRespectsInterfaceFilter(t *testing.T) {
defer test.CheckRoutines(t)()
netCapture := newRelayListenCaptureNet()
stubClient := &stubTurnClient{}
agent, err := NewAgentWithOptions(
WithNet(netCapture),
WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
WithCandidateTypes([]CandidateType{CandidateTypeRelay}),
WithMulticastDNSMode(MulticastDNSModeDisabled),
WithUrls([]*stun.URI{
{
Scheme: stun.SchemeTypeTURN,
Host: "example.com",
Port: 3478,
Username: "username",
Password: "password",
Proto: stun.ProtoTypeUDP,
},
}),
WithInterfaceFilter(func(iface string) bool {
return iface == "eth0"
}),
WithIncludeLoopback(),
)
require.NoError(t, err)
defer func() {
require.NoError(t, agent.Close())
}()
agent.turnClientFactory = func(cfg *turn.ClientConfig) (turnClient, error) {
stubClient.cfgConn = cfg.Conn
return stubClient, nil
}
require.NoError(t, agent.OnCandidate(func(Candidate) {}))
agent.gatherCandidatesRelay(context.Background(), agent.urls)
listenAddrs := netCapture.listenAddresses()
require.NotEmpty(t, listenAddrs)
for _, addr := range listenAddrs {
require.Equal(t, "127.0.0.1:0", addr)
}
}
func TestGatherCandidatesRelayDefaultClientError(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -1427,6 +1729,41 @@ func TestGatherCandidatesSrflxUDPMux(t *testing.T) {
require.Equal(t, 1, udpMuxSrflx.connCount(), "expected mux to be asked for one connection")
}
func TestGatherCandidatesSrflxRespectsInterfaceFilter(t *testing.T) {
netCapture := newSrflxListenCaptureNet()
agent, err := NewAgentWithOptions(
WithNet(netCapture),
WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}),
WithMulticastDNSMode(MulticastDNSModeDisabled),
WithUrls([]*stun.URI{{
Scheme: stun.SchemeTypeSTUN,
Host: localhostIPStr,
Port: 9,
}}),
WithInterfaceFilter(func(iface string) bool {
return iface == "eth0"
}),
WithIncludeLoopback(),
WithSTUNGatherTimeout(5*time.Millisecond),
)
require.NoError(t, err)
defer func() {
require.NoError(t, agent.Close())
}()
require.NoError(t, agent.OnCandidate(func(Candidate) {}))
agent.gatherCandidatesSrflx(context.Background(), agent.urls, []NetworkType{NetworkTypeUDP4})
listenIPs := netCapture.listenCallIPs()
require.NotEmpty(t, listenIPs)
for _, ip := range listenIPs {
require.Equal(t, "127.0.0.1", ip)
}
}
// TestUDPMuxDefaultWithNAT1To1IPsUsage requires that candidates
// are given and connections are valid when using UDPMuxDefault and NAT1To1IPs.
func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {