diff --git a/agent_options_test.go b/agent_options_test.go index ee56af3..1f32ec5 100644 --- a/agent_options_test.go +++ b/agent_options_test.go @@ -513,7 +513,7 @@ func TestMultipleConfigOptions(t *testing.T) { func TestWithInterfaceFilter(t *testing.T) { t.Run("sets interface filter", func(t *testing.T) { filter := func(interfaceName string) bool { - return interfaceName == "eth0" + return interfaceName == "eth0" // nolint:goconst } agent, err := NewAgentWithOptions(WithInterfaceFilter(filter)) @@ -521,7 +521,7 @@ func TestWithInterfaceFilter(t *testing.T) { defer agent.Close() //nolint:errcheck assert.NotNil(t, agent.interfaceFilter) - assert.True(t, agent.interfaceFilter("eth0")) + assert.True(t, agent.interfaceFilter("eth0")) // nolint:goconst assert.False(t, agent.interfaceFilter("wlan0")) }) @@ -549,7 +549,7 @@ func TestWithInterfaceFilter(t *testing.T) { assert.NotNil(t, agent.interfaceFilter) assert.True(t, agent.interfaceFilter("lo")) - assert.False(t, agent.interfaceFilter("eth0")) + assert.False(t, agent.interfaceFilter("eth0")) // nolint:goconst }) } diff --git a/gather.go b/gather.go index d025fe9..58f495c 100644 --- a/gather.go +++ b/gather.go @@ -736,89 +736,123 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net var wg sync.WaitGroup defer wg.Wait() + useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil + localAddrs := []ifaceAddr{} + if useFilteredLocalAddrs { + _, addrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + if err != nil { + a.log.Warnf("Failed to iterate local interfaces, srflx candidates will not be gathered %s", err) + + return + } + localAddrs = addrs + } + + gatherForURL := func(url stun.URI, network string, listenAddr *net.UDPAddr) { + defer wg.Done() + + hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) + + return + } + + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + + return + } + + conn, err := listenUDPInPortRange( + a.net, + a.log, + int(a.portMax), + int(a.portMin), + network, + listenAddr, + ) + if err != nil { + closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) + + return + } + // If the agent closes midway through the connection + // we end it early to prevent close delay. + cancelCtx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + go func() { + select { + case <-cancelCtx.Done(): + return + case <-a.loop.Done(): + _ = conn.Close() + } + }() + + xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout) + if err != nil { + closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err) + + return + } + + ip := xorAddr.IP + port := xorAddr.Port + + lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: lAddr.IP.String(), + RelPort: lAddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) + + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + } + } + for _, networkType := range networkTypes { if networkType.IsTCP() { continue } for i := range urls { - wg.Add(1) - go func(url stun.URI, network string) { - defer wg.Done() + if !useFilteredLocalAddrs { + wg.Add(1) + go gatherForURL(*urls[i], networkType.String(), &net.UDPAddr{IP: nil, Port: 0}) - hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) - serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) - if err != nil { - a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) + continue + } - return + for j := range localAddrs { + if networkType.IsIPv4() && localAddrs[j].addr.Is6() { + continue + } + if networkType.IsIPv6() && !localAddrs[j].addr.Is6() { + continue } - if shouldFilterLocationTracked(serverAddr.IP) { - a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) - - return - } - - conn, err := listenUDPInPortRange( - a.net, - a.log, - int(a.portMax), - int(a.portMin), - network, - &net.UDPAddr{IP: nil, Port: 0}, + wg.Add(1) + go gatherForURL( + *urls[i], + networkType.String(), + &net.UDPAddr{IP: localAddrs[j].addr.AsSlice(), Zone: localAddrs[j].addr.Zone(), Port: 0}, ) - if err != nil { - closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) - - return - } - // If the agent closes midway through the connection - // we end it early to prevent close delay. - cancelCtx, cancelFunc := context.WithCancel(ctx) - defer cancelFunc() - go func() { - select { - case <-cancelCtx.Done(): - return - case <-a.loop.Done(): - _ = conn.Close() - } - }() - - xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout) - if err != nil { - closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err) - - return - } - - ip := xorAddr.IP - port := xorAddr.Port - - lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert - srflxConfig := CandidateServerReflexiveConfig{ - Network: network, - Address: ip.String(), - Port: port, - Component: ComponentRTP, - RelAddr: lAddr.IP.String(), - RelPort: lAddr.Port, - } - c, err := NewCandidateServerReflexive(&srflxConfig) - if err != nil { - closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) - - return - } - - if err := a.addCandidate(ctx, c, conn); err != nil { - if closeErr := c.close(); closeErr != nil { - a.log.Warnf("Failed to close candidate: %v", closeErr) - } - a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) - } - }(*urls[i], networkType.String()) + } } } } @@ -831,6 +865,18 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { network := NetworkTypeUDP4.String() _, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback) + useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil + localAddrs := []ifaceAddr{} + if useFilteredLocalAddrs { + for i := range ifaces { + if ifaces[i].addr.Is6() { + continue + } + + localAddrs = append(localAddrs, ifaces[i]) + } + } + for _, url := range urls { switch { case url.Scheme != stun.SchemeTypeTURN && url.Scheme != stun.SchemeTypeTURNS: @@ -845,204 +891,216 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { return } - wg.Add(1) - go func(url stun.URI) { - defer wg.Done() - turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port) - var ( - locConn net.PacketConn - err error - relAddr string - relPort int - relayProtocol string - ) - - switch { - case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN: - if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil { - a.log.Warnf("Failed to listen %s: %v", network, err) - - return - } - - relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert - relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert - relayProtocol = udp - case a.proxyDialer != nil && url.Proto == stun.ProtoTypeTCP && - (url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS): - conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr) - if connectErr != nil { - a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr) - - return - } - - relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert - relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert - if url.Scheme == stun.SchemeTypeTURN { - relayProtocol = tcp - } else if url.Scheme == stun.SchemeTypeTURNS { - relayProtocol = "tls" - } - locConn = turn.NewSTUNConn(conn) - - case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN: - tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) - if connectErr != nil { - a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr) - - return - } - - conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) - if connectErr != nil { - a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr) - - return - } - - relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert - relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert - relayProtocol = tcp - locConn = turn.NewSTUNConn(conn) - case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS: - udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr) - if connectErr != nil { - a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr) - - return - } - - udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr) - if dialErr != nil { - a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr) - - return - } - - conn, connectErr := dtls.ClientWithOptions(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), - dtls.WithServerName(url.Host), - dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec - dtls.WithLoggerFactory(a.loggerFactory), + gatherForURL := func(url stun.URI, localBindAddr string) { + wg.Add(1) + go func(url stun.URI, localBindAddr string) { + defer wg.Done() + turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port) + var ( + locConn net.PacketConn + err error + relAddr string + relPort int + relayProtocol string ) - if connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + + switch { + case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN: + if locConn, err = a.net.ListenPacket(network, localBindAddr); err != nil { + a.log.Warnf("Failed to listen %s: %v", network, err) + + return + } + + relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert + relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + relayProtocol = udp + case a.proxyDialer != nil && url.Proto == stun.ProtoTypeTCP && + (url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS): + conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + if url.Scheme == stun.SchemeTypeTURN { + relayProtocol = tcp + } else if url.Scheme == stun.SchemeTypeTURNS { + relayProtocol = "tls" + } + locConn = turn.NewSTUNConn(conn) + + case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN: + tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr) + + return + } + + conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) + if connectErr != nil { + a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + relayProtocol = tcp + locConn = turn.NewSTUNConn(conn) + case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS: + udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr) + + return + } + + udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr) + if dialErr != nil { + a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr) + + return + } + + conn, connectErr := dtls.ClientWithOptions(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), + dtls.WithServerName(url.Host), + dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec + dtls.WithLoggerFactory(a.loggerFactory), + ) + if connectErr != nil { + a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + + return + } + + if connectErr = conn.HandshakeContext(ctx); connectErr != nil { + a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + + return + } + + relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + relayProtocol = relayProtocolDTLS + locConn = &fakenet.PacketConn{Conn: conn} + case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS: + tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) + if resolvErr != nil { + a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr) + + return + } + + tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) + if dialErr != nil { + a.log.Warnf("Failed to connect to relay: %v", dialErr) + + return + } + + conn := tls.Client(tcpConn, &tls.Config{ + ServerName: url.Host, + InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec + }) + + if hsErr := conn.HandshakeContext(ctx); hsErr != nil { + if closeErr := tcpConn.Close(); closeErr != nil { + a.log.Errorf("Failed to close relay connection: %v", closeErr) + } + a.log.Warnf("Failed to connect to relay: %v", hsErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + relayProtocol = relayProtocolTLS + locConn = turn.NewSTUNConn(conn) + default: + a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url) return } - if connectErr = conn.HandshakeContext(ctx); connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) - - return + factory := a.turnClientFactory + if factory == nil { + factory = defaultTurnClient } - relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert - relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert - relayProtocol = relayProtocolDTLS - locConn = &fakenet.PacketConn{Conn: conn} - case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS: - tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) - if resolvErr != nil { - a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr) - - return - } - - tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) - if dialErr != nil { - a.log.Warnf("Failed to connect to relay: %v", dialErr) - - return - } - - conn := tls.Client(tcpConn, &tls.Config{ - ServerName: url.Host, - InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec + client, err := factory(&turn.ClientConfig{ + TURNServerAddr: turnServerAddr, + Conn: locConn, + Username: url.Username, + Password: url.Password, + LoggerFactory: a.loggerFactory, + Net: a.net, }) - - if hsErr := conn.HandshakeContext(ctx); hsErr != nil { - if closeErr := tcpConn.Close(); closeErr != nil { - a.log.Errorf("Failed to close relay connection: %v", closeErr) - } - a.log.Warnf("Failed to connect to relay: %v", hsErr) + if err != nil { + closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err) return } - relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert - relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert - relayProtocol = relayProtocolTLS - locConn = turn.NewSTUNConn(conn) - default: - a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url) - - return - } - - factory := a.turnClientFactory - if factory == nil { - factory = defaultTurnClient - } - - client, err := factory(&turn.ClientConfig{ - TURNServerAddr: turnServerAddr, - Conn: locConn, - Username: url.Username, - Password: url.Password, - LoggerFactory: a.loggerFactory, - Net: a.net, - }) - if err != nil { - closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err) - - return - } - - if err = client.Listen(); err != nil { - client.Close() - closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err) - - return - } - - relayConn, err := client.Allocate() - if err != nil { - client.Close() - closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) - - return - } - - rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert - - if shouldFilterLocationTracked(rAddr.IP) { - a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) - - return - } - - a.addRelayCandidates(ctx, relayEndpoint{ - network: network, - address: rAddr.IP, - port: rAddr.Port, - relAddr: relAddr, - relPort: relPort, - iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)), - protocol: relayProtocol, - conn: relayConn, - onClose: func() error { + if err = client.Listen(); err != nil { client.Close() + closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err) - return locConn.Close() - }, - closeConn: func() { - if relayConErr := relayConn.Close(); relayConErr != nil { - a.log.Warnf("Failed to close relay %v", relayConErr) - } - }, - }) - }(*url) + return + } + + relayConn, err := client.Allocate() + if err != nil { + client.Close() + closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) + + return + } + + rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + + if shouldFilterLocationTracked(rAddr.IP) { + a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) + + return + } + + a.addRelayCandidates(ctx, relayEndpoint{ + network: network, + address: rAddr.IP, + port: rAddr.Port, + relAddr: relAddr, + relPort: relPort, + iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)), + protocol: relayProtocol, + conn: relayConn, + onClose: func() error { + client.Close() + + return locConn.Close() + }, + closeConn: func() { + if relayConErr := relayConn.Close(); relayConErr != nil { + a.log.Warnf("Failed to close relay %v", relayConErr) + } + }, + }) + }(url, localBindAddr) + } + + if !useFilteredLocalAddrs { + gatherForURL(*url, "0.0.0.0:0") + + continue + } + + for i := range localAddrs { + gatherForURL(*url, net.JoinHostPort(localAddrs[i].addr.String(), "0")) + } } } diff --git a/gather_test.go b/gather_test.go index c8766e7..8f6f104 100644 --- a/gather_test.go +++ b/gather_test.go @@ -773,6 +773,128 @@ func (n *relayGatherNet) CreateListenConfig(*net.ListenConfig) transport.ListenC return nil } +type relayListenCaptureNet struct { + listenPacketAddresses []string + mu sync.Mutex +} + +func newRelayListenCaptureNet() *relayListenCaptureNet { + return &relayListenCaptureNet{} +} + +func (n *relayListenCaptureNet) ListenPacket(_ string, address string) (net.PacketConn, error) { + n.mu.Lock() + n.listenPacketAddresses = append(n.listenPacketAddresses, address) + n.mu.Unlock() + + udpAddr, err := net.ResolveUDPAddr("udp4", address) + if err != nil { + return nil, err + } + + return newStubPacketConn(udpAddr), nil +} + +func (n *relayListenCaptureNet) ListenUDP(string, *net.UDPAddr) (transport.UDPConn, error) { + return nil, transport.ErrNotSupported +} + +func (n *relayListenCaptureNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) { + return nil, transport.ErrNotSupported +} + +func (n *relayListenCaptureNet) Dial(string, string) (net.Conn, error) { + return nil, transport.ErrNotSupported +} + +func (n *relayListenCaptureNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) { + return nil, transport.ErrNotSupported +} + +func (n *relayListenCaptureNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) { + return nil, transport.ErrNotSupported +} + +func (n *relayListenCaptureNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) { + return net.ResolveIPAddr(network, address) +} + +func (n *relayListenCaptureNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + return net.ResolveUDPAddr(network, address) +} + +func (n *relayListenCaptureNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + return net.ResolveTCPAddr(network, address) +} + +func (n *relayListenCaptureNet) Interfaces() ([]*transport.Interface, error) { + iface0 := transport.NewInterface(net.Interface{ + Index: 1, + MTU: 1500, + Name: "eth0", + Flags: net.FlagUp, + }) + iface0.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}) + + iface1 := transport.NewInterface(net.Interface{ + Index: 2, + MTU: 1500, + Name: "wlan0", + Flags: net.FlagUp, + }) + iface1.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(8, 32)}) + + return []*transport.Interface{iface0, iface1}, nil +} + +func (n *relayListenCaptureNet) InterfaceByIndex(index int) (*transport.Interface, error) { + ifaces, err := n.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range ifaces { + if iface.Index == index { + return iface, nil + } + } + + return nil, transport.ErrInterfaceNotFound +} + +func (n *relayListenCaptureNet) InterfaceByName(name string) (*transport.Interface, error) { + ifaces, err := n.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range ifaces { + if iface.Name == name { + return iface, nil + } + } + + return nil, transport.ErrInterfaceNotFound +} + +func (n *relayListenCaptureNet) CreateDialer(*net.Dialer) transport.Dialer { + return nil +} + +func (n *relayListenCaptureNet) CreateListenConfig(*net.ListenConfig) transport.ListenConfig { + return nil +} + +func (n *relayListenCaptureNet) listenAddresses() []string { + n.mu.Lock() + defer n.mu.Unlock() + + addrs := make([]string, len(n.listenPacketAddresses)) + copy(addrs, n.listenPacketAddresses) + + return addrs +} + type hostGatherNet struct { addr *net.UDPAddr } @@ -875,6 +997,138 @@ func (n *hostGatherNet) CreateListenConfig(*net.ListenConfig) transport.ListenCo return nil } +type srflxListenCaptureNet struct { + listenCalls []*net.UDPAddr + mu sync.Mutex +} + +func newSrflxListenCaptureNet() *srflxListenCaptureNet { + return &srflxListenCaptureNet{} +} + +func (n *srflxListenCaptureNet) ListenPacket(string, string) (net.PacketConn, error) { + return newStubPacketConn(&net.UDPAddr{IP: net.IPv4zero, Port: 0}), nil +} + +func (n *srflxListenCaptureNet) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + n.mu.Lock() + if laddr != nil { + n.listenCalls = append(n.listenCalls, &net.UDPAddr{ + IP: append(net.IP{}, laddr.IP...), + Port: laddr.Port, + Zone: laddr.Zone, + }) + } else { + n.listenCalls = append(n.listenCalls, nil) + } + n.mu.Unlock() + + return net.ListenUDP(network, &net.UDPAddr{IP: net.IPv4zero, Port: 0}) //nolint:wrapcheck +} + +func (n *srflxListenCaptureNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) { + return nil, transport.ErrNotSupported +} + +func (n *srflxListenCaptureNet) Dial(string, string) (net.Conn, error) { + return nil, transport.ErrNotSupported +} + +func (n *srflxListenCaptureNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) { + return nil, transport.ErrNotSupported +} + +func (n *srflxListenCaptureNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) { + return nil, transport.ErrNotSupported +} + +func (n *srflxListenCaptureNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) { + return net.ResolveIPAddr(network, address) +} + +func (n *srflxListenCaptureNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + return net.ResolveUDPAddr(network, address) +} + +func (n *srflxListenCaptureNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + return net.ResolveTCPAddr(network, address) +} + +func (n *srflxListenCaptureNet) Interfaces() ([]*transport.Interface, error) { + iface0 := transport.NewInterface(net.Interface{ + Index: 1, + MTU: 1500, + Name: "eth0", + Flags: net.FlagUp, + }) + iface0.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}) + + iface1 := transport.NewInterface(net.Interface{ + Index: 2, + MTU: 1500, + Name: "wlan0", + Flags: net.FlagUp, + }) + iface1.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(8, 32)}) + + return []*transport.Interface{iface0, iface1}, nil +} + +func (n *srflxListenCaptureNet) InterfaceByIndex(index int) (*transport.Interface, error) { + ifaces, err := n.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range ifaces { + if iface.Index == index { + return iface, nil + } + } + + return nil, transport.ErrInterfaceNotFound +} + +func (n *srflxListenCaptureNet) InterfaceByName(name string) (*transport.Interface, error) { + ifaces, err := n.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range ifaces { + if iface.Name == name { + return iface, nil + } + } + + return nil, transport.ErrInterfaceNotFound +} + +func (n *srflxListenCaptureNet) CreateDialer(*net.Dialer) transport.Dialer { + return nil +} + +func (n *srflxListenCaptureNet) CreateListenConfig(*net.ListenConfig) transport.ListenConfig { + return nil +} + +func (n *srflxListenCaptureNet) listenCallIPs() []string { + n.mu.Lock() + defer n.mu.Unlock() + + ipStrings := make([]string, 0, len(n.listenCalls)) + for _, addr := range n.listenCalls { + if addr == nil || addr.IP == nil { + ipStrings = append(ipStrings, "") + + continue + } + ipStrings = append(ipStrings, addr.IP.String()) + } + + return ipStrings +} + type errorPacketConn struct { addr net.Addr closed bool @@ -1164,6 +1418,54 @@ func TestGatherCandidatesRelayUsesTurnNet(t *testing.T) { } } +func TestGatherCandidatesRelayRespectsInterfaceFilter(t *testing.T) { + defer test.CheckRoutines(t)() + + netCapture := newRelayListenCaptureNet() + stubClient := &stubTurnClient{} + + agent, err := NewAgentWithOptions( + WithNet(netCapture), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithCandidateTypes([]CandidateType{CandidateTypeRelay}), + WithMulticastDNSMode(MulticastDNSModeDisabled), + WithUrls([]*stun.URI{ + { + Scheme: stun.SchemeTypeTURN, + Host: "example.com", + Port: 3478, + Username: "username", + Password: "password", + Proto: stun.ProtoTypeUDP, + }, + }), + WithInterfaceFilter(func(iface string) bool { + return iface == "eth0" + }), + WithIncludeLoopback(), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + + agent.turnClientFactory = func(cfg *turn.ClientConfig) (turnClient, error) { + stubClient.cfgConn = cfg.Conn + + return stubClient, nil + } + + require.NoError(t, agent.OnCandidate(func(Candidate) {})) + + agent.gatherCandidatesRelay(context.Background(), agent.urls) + + listenAddrs := netCapture.listenAddresses() + require.NotEmpty(t, listenAddrs) + for _, addr := range listenAddrs { + require.Equal(t, "127.0.0.1:0", addr) + } +} + func TestGatherCandidatesRelayDefaultClientError(t *testing.T) { defer test.CheckRoutines(t)() @@ -1427,6 +1729,41 @@ func TestGatherCandidatesSrflxUDPMux(t *testing.T) { require.Equal(t, 1, udpMuxSrflx.connCount(), "expected mux to be asked for one connection") } +func TestGatherCandidatesSrflxRespectsInterfaceFilter(t *testing.T) { + netCapture := newSrflxListenCaptureNet() + + agent, err := NewAgentWithOptions( + WithNet(netCapture), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), + WithMulticastDNSMode(MulticastDNSModeDisabled), + WithUrls([]*stun.URI{{ + Scheme: stun.SchemeTypeSTUN, + Host: localhostIPStr, + Port: 9, + }}), + WithInterfaceFilter(func(iface string) bool { + return iface == "eth0" + }), + WithIncludeLoopback(), + WithSTUNGatherTimeout(5*time.Millisecond), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + + require.NoError(t, agent.OnCandidate(func(Candidate) {})) + + agent.gatherCandidatesSrflx(context.Background(), agent.urls, []NetworkType{NetworkTypeUDP4}) + + listenIPs := netCapture.listenCallIPs() + require.NotEmpty(t, listenIPs) + for _, ip := range listenIPs { + require.Equal(t, "127.0.0.1", ip) + } +} + // TestUDPMuxDefaultWithNAT1To1IPsUsage requires that candidates // are given and connections are valid when using UDPMuxDefault and NAT1To1IPs. func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {