Files
ice/agent_test.go
sirzooro ceccc15191 Fix race and disconnection after candidate replace (#913)
This PR fixes two race conditions introduced by peer-reflexive candidate
replacement and addresses a connectivity regression where an agent could
transition from Connected to Disconnected/Failed shortly after
replacement.

## Problems

1. Data race in candidate pair access:

- `replaceRedundantPeerReflexiveCandidates` updated `pair.Remote` in
place.
- At the same time, hot-path reads (for example `CandidatePair.Write`)
could read `pair.Remote` without synchronization.

2. Data race in remote candidate cache map:

- `candidateBase.handleInboundPacket` (non-STUN path) read/wrote
`remoteCandidateCaches` from recv loop goroutines.
- Concurrent writes also happened from agent-loop code
(`replaceRemoteCandidateCacheValues`).

3. Connectivity regression after replacement:

- When replacing a selected pair's remote candidate (prflx -> signaled),
the new candidate could have zero `LastReceived`/`LastSent`.
- `validateSelectedPair` uses `selectedPair.Remote.LastReceived()`,
which could cause quick `Connected -> Disconnected -> Failed`
transitions.

## Root Cause

Replacement logic preserved pair priority but mutated shared objects in
place and did not preserve candidate activity timestamps across
candidate object replacement.

## Fixes

### 1) Replace candidate pair objects instead of mutating `pair.Remote`

- Added helper `replacePairRemote(pair, remote)` to clone pair state
into a new `CandidatePair`.
- Preserved pair fields and runtime stats:
  - `id`, role, state, nomination flags, binding counters.
  - RTT fields, packet/byte counters, request/response counters.
  - timestamp `atomic.Value` fields.
- Replaced references in:
  - `a.checklist[i]`
  - `a.pairsByID[pair.id]`
- If old pair was selected, published replacement via
`a.setSelectedPair(replacement)`.

### 2) Serialize remote cache access on agent loop

- Updated non-STUN path in `candidateBase.handleInboundPacket`:
- Cache validation, remote candidate lookup, `seen(false)`, and cache
insert now run inside `agent.loop.Run(...)`.
- This removes concurrent map read/write between recv loop and
agent-loop updates.

### 3) Preserve candidate activity on replacement

- Added `copyCandidateActivity(dst, src)` to transfer:
  - `LastReceived`
  - `LastSent`
- Applied before replacing references so selected-pair liveness checks
remain stable.
2026-04-21 06:27:53 +02:00

3389 lines
96 KiB
Go

// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
//go:build !js
package ice
import (
"context"
"net"
"net/netip"
"strconv"
"sync"
"testing"
"time"
"github.com/pion/ice/v4/internal/fakenet"
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/transport/v4/test"
"github.com/pion/transport/v4/vnet"
"github.com/stretchr/testify/require"
)
// BadAddr is a net.Addr implementation whose Network and String values do
// not correspond to any supported transport. Tests use it to exercise the
// agent's handling of unparseable remote addresses.
type BadAddr struct{}

// Network returns a network name that matches no supported network type.
func (ba *BadAddr) Network() string {
	return "xxx"
}

// String returns an address string that cannot be parsed as host:port.
func (ba *BadAddr) String() string {
	return "yyy"
}
// recordingSelector is a no-op selector stub that records whether
// HandleBindingRequest / HandleSuccessResponse were invoked, so tests can
// assert whether inbound STUN traffic reached the selector.
type recordingSelector struct {
	handledBindingRequest bool // set when HandleBindingRequest is called
	handledSuccess        bool // set when HandleSuccessResponse is called
}

func (r *recordingSelector) Start()                             {}
func (r *recordingSelector) ContactCandidates()                 {}
func (r *recordingSelector) PingCandidate(Candidate, Candidate) {}

// HandleSuccessResponse records that a success response was delivered.
func (r *recordingSelector) HandleSuccessResponse(*stun.Message, Candidate, Candidate, net.Addr) {
	r.handledSuccess = true
}

// HandleBindingRequest records that a binding request was delivered.
func (r *recordingSelector) HandleBindingRequest(*stun.Message, Candidate, Candidate) {
	r.handledBindingRequest = true
}
// TestHandlePeerReflexive covers discovery of remote peer-reflexive
// candidates from inbound Binding Requests, their priority handling, their
// replacement by later-signaled host/srflx/relay candidates (preserving
// pair priority, selection state and activity timestamps), and rejection
// of malformed or unexpected inbound traffic.
//
// Fix over the previous revision: every NewCandidateHost error is now
// checked *before* the returned candidate is dereferenced (local.conn = …);
// previously a constructor failure would have caused a nil-pointer panic
// instead of a clean test failure.
func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
	defer test.CheckRoutines(t)()

	// Limit runtime in case of deadlocks
	defer test.TimeOut(time.Second * 2).Stop()

	t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			sel := &controllingSelector{agent: agent, log: agent.log}
			agent.selector = sel
			hostConfig := CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			}
			local, err := NewCandidateHost(&hostConfig)
			// Check the constructor error before touching the candidate.
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(local.Priority()),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			// Length of remote candidate list must be one now
			require.Len(t, agent.remoteCandidates, 1)

			// Length of remote candidate list for a network type must be 1
			set := agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)

			c := set[0]
			require.Equal(t, CandidateTypePeerReflexive, c.Type())
			require.Equal(t, "172.17.0.3", c.Address())
			require.Equal(t, 999, c.Port())
		}))
	})

	t.Run("prflx candidate priority comes from inbound PRIORITY", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			sel := &controllingSelector{agent: agent, log: agent.log}
			agent.selector = sel
			local, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			})
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
			remotePriority := uint32(123456)

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(remotePriority),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			set := agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)

			// The prflx candidate's priority must come from the PRIORITY
			// attribute, not be recomputed locally.
			c := set[0]
			require.Equal(t, CandidateTypePeerReflexive, c.Type())
			require.Equal(t, remotePriority, c.Priority())
		}))
	})

	t.Run("Signaled host candidate replaces existing remote prflx candidate", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			sel := &controllingSelector{agent: agent, log: agent.log}
			agent.selector = sel
			local, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			})
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}
			agent.localCandidates[local.NetworkType()] = []Candidate{local}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(uint32(99999)),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			set := agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			prflx := set[0]
			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())

			require.Len(t, agent.checklist, 1)
			pair := agent.checklist[0]
			require.Equal(t, prflx, pair.Remote)

			// Record activity so replacement can be checked to preserve it.
			local.addRemoteCandidateCache(prflx, remote)
			prflx.seen(false)
			prflx.seen(true)
			oldLastReceived := prflx.LastReceived()
			oldLastSent := prflx.LastSent()
			oldPriority := pair.priority()

			agent.setSelectedPair(pair)
			sel.nominatedPair = pair

			// Trickle the signaled host candidate for the same address:port.
			host, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "172.17.0.3",
				Port:      999,
				Component: 1,
			})
			require.NoError(t, err)
			agent.addRemoteCandidate(host) // nolint:contextcheck

			set = agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			require.Equal(t, CandidateTypeHost, set[0].Type())
			require.Equal(t, host, set[0])

			// Replacement must be a new pair object that keeps the old
			// pair's priority and activity timestamps.
			updatedPair := agent.findPair(local, host)
			require.NotNil(t, updatedPair)
			require.NotSame(t, pair, updatedPair)
			require.Equal(t, host, updatedPair.Remote)
			require.Equal(t, oldPriority, updatedPair.priority())
			require.False(t, updatedPair.Remote.LastReceived().IsZero())
			require.WithinDuration(t, oldLastReceived, updatedPair.Remote.LastReceived(), 10*time.Millisecond)
			require.False(t, updatedPair.Remote.LastSent().IsZero())
			require.WithinDuration(t, oldLastSent, updatedPair.Remote.LastSent(), 10*time.Millisecond)
			require.Equal(t, prflx, pair.Remote)
			require.Same(t, updatedPair, agent.getSelectedPair())
			require.Same(t, updatedPair, sel.nominatedPair)

			cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
			require.True(t, ok)
			require.Equal(t, host, cached)
		}))
	})

	t.Run("Signaled srflx candidate replaces existing remote prflx candidate", func(t *testing.T) { // nolint:dupl
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			local, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			})
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}
			agent.localCandidates[local.NetworkType()] = []Candidate{local}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(uint32(99999)),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			set := agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			prflx := set[0]
			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())

			require.Len(t, agent.checklist, 1)
			pair := agent.checklist[0]
			require.Equal(t, prflx, pair.Remote)

			local.addRemoteCandidateCache(prflx, remote)
			oldPriority := pair.priority()

			srflx, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{
				Network:   "udp",
				Address:   "172.17.0.3",
				Port:      999,
				Component: 1,
				RelAddr:   "0.0.0.0",
				RelPort:   0,
			})
			require.NoError(t, err)
			agent.addRemoteCandidate(srflx) // nolint:contextcheck

			set = agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			require.Equal(t, CandidateTypeServerReflexive, set[0].Type())
			require.Equal(t, srflx, set[0])

			updatedPair := agent.findPair(local, srflx)
			require.NotNil(t, updatedPair)
			require.NotSame(t, pair, updatedPair)
			require.Equal(t, srflx, updatedPair.Remote)
			require.Equal(t, oldPriority, updatedPair.priority())
			require.Equal(t, prflx, pair.Remote)

			cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
			require.True(t, ok)
			require.Equal(t, srflx, cached)
		}))
	})

	t.Run("Signaled relay candidate replaces existing remote prflx candidate", func(t *testing.T) { // nolint:dupl
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			local, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			})
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}
			agent.localCandidates[local.NetworkType()] = []Candidate{local}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(uint32(99999)),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			set := agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			prflx := set[0]
			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())

			require.Len(t, agent.checklist, 1)
			pair := agent.checklist[0]
			require.Equal(t, prflx, pair.Remote)

			local.addRemoteCandidateCache(prflx, remote)
			oldPriority := pair.priority()

			relay, err := NewCandidateRelay(&CandidateRelayConfig{
				Network:   "udp",
				Address:   "172.17.0.3",
				Port:      999,
				Component: 1,
				RelAddr:   "0.0.0.0",
				RelPort:   0,
			})
			require.NoError(t, err)
			agent.addRemoteCandidate(relay) // nolint:contextcheck

			set = agent.remoteCandidates[local.NetworkType()]
			require.Len(t, set, 1)
			require.Equal(t, CandidateTypeRelay, set[0].Type())
			require.Equal(t, relay, set[0])

			updatedPair := agent.findPair(local, relay)
			require.NotNil(t, updatedPair)
			require.NotSame(t, pair, updatedPair)
			require.Equal(t, relay, updatedPair.Remote)
			require.Equal(t, oldPriority, updatedPair.priority())
			require.Equal(t, prflx, pair.Remote)

			cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
			require.True(t, ok)
			require.Equal(t, relay, cached)
		}))
	})

	t.Run("AcceptanceMinWait: prflx not accepted until replaced", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.isControlling.Store(true)
			agent.remoteUfrag = "remoteUfrag"
			agent.remotePwd = "remotePwd"
			agent.hostAcceptanceMinWait = 0
			agent.prflxAcceptanceMinWait = time.Hour

			local, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			})
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}
			agent.localCandidates[local.NetworkType()] = []Candidate{local}

			prflx, err := NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{
				Network:   "udp",
				Address:   "1.2.3.4",
				Port:      999,
				Component: 1,
			})
			require.NoError(t, err)
			agent.addRemoteCandidate(prflx) // nolint:contextcheck

			pair := agent.findPair(local, prflx)
			require.NotNil(t, pair)
			require.Equal(t, CandidateTypeHost, pair.Local.Type())
			require.Equal(t, CandidateTypePeerReflexive, pair.Remote.Type())
			pair.state = CandidatePairStateSucceeded

			sel := &controllingSelector{agent: agent, log: agent.log}
			sel.Start()

			// With prflxAcceptanceMinWait set high, remote prflx candidate should not be nominatable.
			sel.ContactCandidates()
			require.Nil(t, sel.nominatedPair)
			require.False(t, pair.nominated)

			// Trickle the signaled candidate for the same transport address.
			signaled, err := NewCandidateHost(&CandidateHostConfig{
				Network:   "udp",
				Address:   "1.2.3.4",
				Port:      999,
				Component: 1,
			})
			require.NoError(t, err)
			agent.addRemoteCandidate(signaled) // nolint:contextcheck

			updatedPair := agent.findPair(local, signaled)
			require.NotNil(t, updatedPair)
			require.NotSame(t, pair, updatedPair)
			require.Equal(t, CandidateTypePeerReflexive, pair.Remote.Type())
			require.Equal(t, signaled, updatedPair.Remote)
			require.Equal(t, CandidateTypeHost, updatedPair.Remote.Type())

			// Now the (updated) pair should be nominatable and become nominated.
			sel.ContactCandidates()
			require.NotNil(t, sel.nominatedPair)
			require.Same(t, updatedPair, sel.nominatedPair)
			require.Equal(t, CandidateTypeHost, sel.nominatedPair.Local.Type())
			require.Equal(t, CandidateTypeHost, sel.nominatedPair.Remote.Type())
			require.True(t, updatedPair.nominated)
		}))
	})

	t.Run("Bad network type with handleInbound()", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			hostConfig := CandidateHostConfig{
				Network:   "tcp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			}
			local, err := NewCandidateHost(&hostConfig)
			require.NoError(t, err)

			remote := &BadAddr{}

			// nolint: contextcheck
			agent.handleInbound(nil, local, remote)

			// An unparseable remote address must not create candidates.
			require.Len(t, agent.remoteCandidates, 0)
		}))
	})

	t.Run("prflx candidate is stored even when network type is disabled", func(t *testing.T) {
		agent, err := NewAgentWithOptions(
			WithNetworkTypes([]NetworkType{NetworkTypeTCP4}),
		)
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &recordingSelector{}
			hostConfig := CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			}
			local, err := NewCandidateHost(&hostConfig)
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				UseCandidate(),
				AttrControlling(agent.tieBreaker),
				PriorityAttr(local.Priority()),
				stun.NewShortTermIntegrity(agent.localPwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			// UDP4 is not in the configured network types, but the prflx
			// candidate is still stored under its own network type.
			require.Len(t, agent.remoteCandidates, 1)
			require.Len(t, agent.remoteCandidates[NetworkTypeUDP4], 1)
		}))
	})

	t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			tID := [stun.TransactionIDSize]byte{}
			copy(tID[:], "ABC")
			agent.pendingBindingRequests = []bindingRequest{
				{time.Now(), tID, &net.UDPAddr{}, false, nil},
			}

			hostConfig := CandidateHostConfig{
				Network:   "udp",
				Address:   "192.168.0.2",
				Port:      777,
				Component: 1,
			}
			local, err := NewCandidateHost(&hostConfig)
			require.NoError(t, err)
			local.conn = &fakenet.MockPacketConn{}

			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

			msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
				stun.NewShortTermIntegrity(agent.remotePwd),
				stun.Fingerprint,
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)

			// A success response alone must never mint a prflx candidate.
			require.Len(t, agent.remoteCandidates, 0)
		}))
	})
}
// TestAddRemoteCandidateStoresCandidatesIndependentlyOfNetworkTypes checks
// that AddRemoteCandidate keeps both a TCP and a UDP remote candidate even
// though the agent itself is configured for UDP4 only.
func TestAddRemoteCandidateStoresCandidatesIndependentlyOfNetworkTypes(t *testing.T) {
	defer test.CheckRoutines(t)()

	agent, err := NewAgentWithOptions(
		WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
	)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	candTCP, err := UnmarshalCandidate("1052353102 1 tcp 1675624447 192.0.2.1 8080 typ host tcptype passive")
	require.NoError(t, err)
	require.NoError(t, agent.AddRemoteCandidate(candTCP))

	candUDP, err := UnmarshalCandidate("1052353102 1 udp 1675624447 192.0.2.2 8080 typ host")
	require.NoError(t, err)
	require.NoError(t, agent.AddRemoteCandidate(candUDP))

	// AddRemoteCandidate is asynchronous; poll until both candidates show up.
	require.Eventually(t, func() bool {
		remotes, getErr := agent.GetRemoteCandidates()
		if getErr != nil || len(remotes) != 2 {
			return false
		}
		seenUDP, seenTCP := false, false
		for _, rc := range remotes {
			switch rc.Address() {
			case candUDP.Address():
				seenUDP = true
			case candTCP.Address():
				seenTCP = true
			}
		}

		return seenUDP && seenTCP
	}, time.Second, 10*time.Millisecond)
}
// Assert that Agent on startup sends message, and doesn't wait for connectivityTicker to fire
// https://github.com/pion/ice/issues/15
// TestConnectivityOnStartup asserts that an Agent on startup sends a message
// and doesn't wait for connectivityTicker to fire.
// https://github.com/pion/ice/issues/15
//
// Fix over the previous revision: the restore of the original
// connection-state handler was written as
// `defer require.NoError(t, aAgent.OnConnectionStateChange(...))`, which
// evaluates its arguments immediately — i.e. it re-registered the original
// handler right away (before the test's wrapper handler was installed) and
// restored nothing at exit. It is now wrapped in a deferred closure.
func TestConnectivityOnStartup(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Create a network with two interfaces
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	net0, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net0))

	net1, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.2"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net1))

	require.NoError(t, wan.Start())

	aNotifier, aConnected := onConnected()
	bNotifier, bConnected := onConnected()

	// Keepalive/check intervals are set absurdly high so any connectivity
	// observed must come from the initial checks, not the ticker.
	KeepaliveInterval := time.Hour
	cfg0 := &AgentConfig{
		NetworkTypes:      supportedNetworkTypes(),
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net0,
		KeepaliveInterval: &KeepaliveInterval,
		CheckInterval:     &KeepaliveInterval,
	}

	aAgent, err := NewAgent(cfg0)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()
	require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

	cfg1 := &AgentConfig{
		NetworkTypes:      supportedNetworkTypes(),
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net1,
		KeepaliveInterval: &KeepaliveInterval,
		CheckInterval:     &KeepaliveInterval,
	}

	bAgent, err := NewAgent(cfg1)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()
	require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))

	func(aAgent, bAgent *Agent) (*Conn, *Conn) {
		// Manual signaling
		aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
		require.NoError(t, err)

		bUfrag, bPwd, err := bAgent.GetLocalUserCredentials()
		require.NoError(t, err)

		gatherAndExchangeCandidates(t, aAgent, bAgent)

		accepted := make(chan struct{})
		accepting := make(chan struct{})
		var aConn *Conn

		origHdlr := aAgent.onConnectionStateChangeHdlr.Load()
		if origHdlr != nil {
			// Restore the original handler when this helper returns. The
			// closure is required: deferred call arguments are evaluated at
			// the defer statement, so a bare defer would re-register the
			// handler immediately instead of at exit.
			defer func() {
				require.NoError(t, aAgent.OnConnectionStateChange(origHdlr.(func(ConnectionState)))) //nolint:forcetypeassert
			}()
		}
		require.NoError(t, aAgent.OnConnectionStateChange(func(s ConnectionState) {
			if s == ConnectionStateChecking {
				close(accepting)
			}
			if origHdlr != nil {
				origHdlr.(func(ConnectionState))(s) //nolint:forcetypeassert
			}
		}))

		go func() {
			var acceptErr error
			aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd)
			require.NoError(t, acceptErr)
			close(accepted)
		}()
		<-accepting

		bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd)
		require.NoError(t, err)

		// Ensure accepted
		<-accepted

		return aConn, bConn
	}(aAgent, bAgent)

	// Ensure pair selected
	// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
	<-aConnected
	<-bConnected

	require.NoError(t, wan.Stop())
}
// TestConnectivityLite verifies that a full agent and a lite (host-only)
// agent behind endpoint-independent NATs reach the Connected state.
func TestConnectivityLite(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	serverURL := &stun.URI{
		Scheme: SchemeTypeSTUN,
		Host:   "1.2.3.4",
		Port:   3478,
		Proto:  stun.ProtoTypeUDP,
	}

	// Symmetric endpoint-independent NAT on both sides.
	nat := &vnet.NATType{
		MappingBehavior:   vnet.EndpointIndependent,
		FilteringBehavior: vnet.EndpointIndependent,
	}
	virtualNet, err := buildVNet(nat, nat)
	require.NoError(t, err, "should succeed")
	defer virtualNet.close()

	notifierA, connectedA := onConnected()
	notifierB, connectedB := onConnected()

	// Full agent with a STUN server configured.
	agentA, err := NewAgent(&AgentConfig{
		Urls:             []*stun.URI{serverURL},
		NetworkTypes:     supportedNetworkTypes(),
		MulticastDNSMode: MulticastDNSModeDisabled,
		Net:              virtualNet.net0,
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agentA.Close())
	}()
	require.NoError(t, agentA.OnConnectionStateChange(notifierA))

	// Lite agent: host candidates only, no STUN.
	agentB, err := NewAgent(&AgentConfig{
		Urls:             []*stun.URI{},
		Lite:             true,
		CandidateTypes:   []CandidateType{CandidateTypeHost},
		NetworkTypes:     supportedNetworkTypes(),
		MulticastDNSMode: MulticastDNSModeDisabled,
		Net:              virtualNet.net1,
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agentB.Close())
	}()
	require.NoError(t, agentB.OnConnectionStateChange(notifierB))

	connectWithVNet(t, agentA, agentB)

	// Ensure pair selected
	// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
	<-connectedA
	<-connectedB
}
// TestInboundValidity checks that handleInbound only acts on authenticated,
// well-formed Binding messages and silently discards everything else.
//
// Fix over the previous revision: the NewCandidateHost error is now checked
// *before* assigning local.conn (in two places); previously a constructor
// failure would have nil-panicked instead of failing the test cleanly.
func TestInboundValidity(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()

	// buildMsg assembles a Binding message of the given class with the
	// supplied USERNAME and short-term-credential key.
	buildMsg := func(class stun.MessageClass, username, key string) *stun.Message {
		msg, err := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID,
			stun.NewUsername(username),
			stun.NewShortTermIntegrity(key),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		return msg
	}

	remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
	hostConfig := CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.0.2",
		Port:      777,
		Component: 1,
	}
	local, err := NewCandidateHost(&hostConfig)
	// Check the constructor error before dereferencing the candidate.
	require.NoError(t, err)
	local.conn = &fakenet.MockPacketConn{}

	t.Run("Invalid Binding requests should be discarded", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		// Wrong username.
		agent.handleInbound(buildMsg(stun.ClassRequest, "invalid", agent.localPwd), local, remote)
		require.Len(t, agent.remoteCandidates, 0)

		// Wrong integrity key.
		agent.handleInbound(buildMsg(stun.ClassRequest, agent.localUfrag+":"+agent.remoteUfrag, "Invalid"), local, remote)
		require.Len(t, agent.remoteCandidates, 0)
	})

	t.Run("Invalid Binding success responses should be discarded", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, a.Close())
		}()

		a.handleInbound(buildMsg(stun.ClassSuccessResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
		require.Len(t, a.remoteCandidates, 0)
	})

	t.Run("Discard non-binding messages", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, a.Close())
		}()

		a.handleInbound(buildMsg(stun.ClassErrorResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
		require.Len(t, a.remoteCandidates, 0)
	})

	t.Run("Valid bind request", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, a.Close())
		}()

		err = a.loop.Run(a.loop, func(_ context.Context) {
			a.selector = &controllingSelector{agent: a, log: a.log}
			// nolint: contextcheck
			a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
			require.Len(t, a.remoteCandidates, 1)
		})
		require.NoError(t, err)
	})

	t.Run("Valid bind without fingerprint", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			// FINGERPRINT is optional; the request must still be accepted.
			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
				stun.NewShortTermIntegrity(agent.localPwd),
			)
			require.NoError(t, err)

			// nolint: contextcheck
			agent.handleInbound(msg, local, remote)
			require.Len(t, agent.remoteCandidates, 1)
		}))
	})

	t.Run("Success with invalid TransactionID", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		hostConfig := CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      777,
			Component: 1,
		}
		local, err := NewCandidateHost(&hostConfig)
		require.NoError(t, err)
		local.conn = &fakenet.MockPacketConn{}

		remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}

		// A success response whose transaction ID matches no pending request
		// must be ignored.
		tID := [stun.TransactionIDSize]byte{}
		copy(tID[:], "ABC")
		msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
			stun.NewShortTermIntegrity(agent.remotePwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		agent.handleInbound(msg, local, remote)
		require.Len(t, agent.remoteCandidates, 0)
	})
}
// TestHandleInboundAdditionalCases exercises handleInbound paths beyond the
// basic validity checks: Binding Indications, role conflict, malformed or
// nil inputs, success responses for known transactions, and the remote IP
// filter applied to peer-reflexive discovery.
func TestHandleInboundAdditionalCases(t *testing.T) {
	defer test.CheckRoutines(t)()

	// newTestAgent returns a UDP4-only agent backed by the stub net helper,
	// so no real network I/O is performed.
	newTestAgent := func(t *testing.T) *Agent {
		t.Helper()
		agent, err := NewAgentWithOptions(
			WithNet(newStubNet(t)),
			WithMulticastDNSMode(MulticastDNSModeDisabled),
			WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
		)
		require.NoError(t, err)

		return agent
	}

	t.Run("Binding indication updates last received", func(t *testing.T) {
		agent := newTestAgent(t)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		remoteConfig := &CandidateHostConfig{
			Network:   "udp",
			Address:   "192.0.2.1",
			Port:      4242,
			Component: 1,
		}
		remoteCandidate, err := NewCandidateHost(remoteConfig)
		require.NoError(t, err)
		remoteAddr := &net.UDPAddr{IP: net.ParseIP(remoteCandidate.Address()), Port: remoteCandidate.Port()}

		// Register the remote candidate on the agent loop so handleInbound
		// can match the sender address to it.
		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.addRemoteCandidate(remoteCandidate) //nolint:contextcheck
		}))

		msg, err := stun.Build(stun.NewType(stun.MethodBinding, stun.ClassIndication), stun.TransactionID, stun.Fingerprint)
		require.NoError(t, err)

		// The indication should bump the candidate's receive timestamp.
		lastReceived := remoteCandidate.LastReceived()
		agent.handleInbound(msg, local, remoteAddr)
		require.False(t, remoteCandidate.LastReceived().IsZero())
		require.NotEqual(t, lastReceived, remoteCandidate.LastReceived())
	})

	t.Run("Role conflict prevents binding handling", func(t *testing.T) {
		agent := newTestAgent(t)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		local.conn = &fakenet.MockPacketConn{}
		remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
		selector := &recordingSelector{}
		agent.selector = selector
		// Both sides claim the controlling role: the inbound request carries
		// AttrControlling while this agent is already controlling.
		agent.isControlling.Store(true)

		msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
			stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
			AttrControlling(agent.tieBreaker),
			stun.NewShortTermIntegrity(agent.localPwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		agent.handleInbound(msg, local, remote)
		// The selector must not be invoked, and this agent keeps its role.
		require.False(t, selector.handledBindingRequest)
		require.True(t, agent.isControlling.Load())
	})

	t.Run("Invalid remote address is discarded", func(t *testing.T) {
		agent := newTestAgent(t)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
			stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
			stun.NewShortTermIntegrity(agent.localPwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		// BadAddr cannot be parsed, so no remote candidate may be created.
		agent.handleInbound(msg, local, &BadAddr{})
		require.Len(t, agent.remoteCandidates, 0)
	})

	t.Run("Nil local candidate is ignored", func(t *testing.T) {
		agent := newTestAgent(t)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
			stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
			stun.NewShortTermIntegrity(agent.localPwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
		agent.handleInbound(msg, nil, remote)
		require.Len(t, agent.remoteCandidates, 0)
	})

	t.Run("Success response for known transaction marks pair succeeded", func(t *testing.T) {
		agent := newTestAgent(t)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		remoteConfig := &CandidateHostConfig{
			Network:   "udp",
			Address:   "192.0.2.2",
			Port:      5555,
			Component: 1,
		}
		remoteCandidate, err := NewCandidateHost(remoteConfig)
		require.NoError(t, err)
		remoteAddr := &net.UDPAddr{IP: net.ParseIP(remoteCandidate.Address()), Port: remoteCandidate.Port()}
		transactionID := stun.NewTransactionID()
		remotePwd := "remotekey"

		// Seed the agent with the pair and a pending binding request whose
		// transaction ID the success response below will match.
		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			agent.selector = &controllingSelector{agent: agent, log: agent.log}
			agent.selector.Start()
			agent.localCandidates[local.NetworkType()] = append(agent.localCandidates[local.NetworkType()], local)
			agent.addRemoteCandidate(remoteCandidate) //nolint:contextcheck
			agent.pendingBindingRequests = []bindingRequest{{
				timestamp:     time.Now(),
				transactionID: transactionID,
				destination:   remoteAddr,
			}}
			agent.remotePwd = remotePwd
		}))

		msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(transactionID),
			stun.NewShortTermIntegrity(remotePwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		agent.handleInbound(msg, local, remoteAddr)

		// The matched pair transitions to Succeeded and the pending request
		// list is drained.
		pair := agent.findPair(local, remoteCandidate)
		require.NotNil(t, pair)
		require.Equal(t, CandidatePairStateSucceeded, pair.state)
		require.Empty(t, agent.pendingBindingRequests)
	})

	t.Run("Binding request from blocked prflx address is discarded", func(t *testing.T) {
		// The filter rejects exactly the source IP used below.
		agent, err := NewAgentWithOptions(
			WithNet(newStubNet(t)),
			WithMulticastDNSMode(MulticastDNSModeDisabled),
			WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
			WithRemoteIPFilter(func(ip net.IP) bool {
				return !ip.Equal(net.IPv4(172, 17, 0, 44))
			}),
		)
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		selector := &recordingSelector{}
		agent.selector = selector

		msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
			stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
			stun.NewShortTermIntegrity(agent.localPwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		remote := &net.UDPAddr{IP: net.IPv4(172, 17, 0, 44), Port: 9999}
		agent.handleInbound(msg, local, remote)
		require.False(t, selector.handledBindingRequest)
		require.Len(t, agent.remoteCandidates, 0)
	})

	t.Run("Binding request prflx filter is called once", func(t *testing.T) {
		// The filter rejects only its first invocation; a second call would
		// accept, so the assertion below also proves it ran exactly once.
		filterCalls := 0
		agent, err := NewAgentWithOptions(
			WithNet(newStubNet(t)),
			WithMulticastDNSMode(MulticastDNSModeDisabled),
			WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
			WithRemoteIPFilter(func(net.IP) bool {
				filterCalls++

				return filterCalls > 1
			}),
		)
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		local := newHostLocal(t)
		selector := &recordingSelector{}
		agent.selector = selector

		msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
			stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
			stun.NewShortTermIntegrity(agent.localPwd),
			stun.Fingerprint,
		)
		require.NoError(t, err)

		remote := &net.UDPAddr{IP: net.IPv4(172, 17, 0, 45), Port: 9999}
		agent.handleInbound(msg, local, remote)
		require.Equal(t, 1, filterCalls)
		require.False(t, selector.handledBindingRequest)
		require.Len(t, agent.remoteCandidates, 0)
	})
}
// TestInvalidAgentStarts asserts that Agent.Dial validates its inputs and
// refuses to start twice: empty remote ufrag/pwd are rejected up front, the
// first valid start is canceled by the expiring context, and a second start
// fails with ErrMultipleStart.
func TestInvalidAgentStarts(t *testing.T) {
	defer test.CheckRoutines(t)()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
	defer cancel()

	// NOTE: require.ErrorIs expects (t, err, target); the actual error
	// comes first. Previously the arguments were reversed, which only
	// passed because the errors compared equal directly.
	_, err = agent.Dial(ctx, "", "bar")
	require.ErrorIs(t, err, ErrRemoteUfragEmpty)

	_, err = agent.Dial(ctx, "foo", "")
	require.ErrorIs(t, err, ErrRemotePwdEmpty)

	// Valid credentials, but the 500ms context expires before connecting.
	_, err = agent.Dial(ctx, "foo", "bar")
	require.ErrorIs(t, err, ErrCanceledByCaller)

	// A second start attempt must be rejected.
	_, err = agent.Dial(ctx, "foo", "bar")
	require.ErrorIs(t, err, ErrMultipleStart)
}
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages.
func TestConnectionStateCallback(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	disconnectedDuration := time.Second
	failedDuration := time.Second
	// Zero keepalive interval disables keepalives so the connection is
	// allowed to time out and walk through Disconnected and Failed.
	KeepaliveInterval := time.Duration(0)
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &disconnectedDuration,
		FailedTimeout:       &failedDuration,
		KeepaliveInterval:   &KeepaliveInterval,
		InterfaceFilter:     problematicNetworkInterfaces,
	}

	isClosed := make(chan any)

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		// Skip the deferred Close once the Closed state was observed:
		// both agents were already closed explicitly at the end of the test.
		select {
		case <-isClosed:
			return
		default:
		}
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		select {
		case <-isClosed:
			return
		default:
		}
		require.NoError(t, bAgent.Close())
	}()

	// One channel per expected state; each channel is closed exactly once
	// when the corresponding transition fires.
	isChecking := make(chan any)
	isConnected := make(chan any)
	isDisconnected := make(chan any)
	isFailed := make(chan any)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		switch c {
		case ConnectionStateChecking:
			close(isChecking)
		case ConnectionStateConnected:
			close(isConnected)
		case ConnectionStateDisconnected:
			close(isDisconnected)
		case ConnectionStateFailed:
			close(isFailed)
		case ConnectionStateClosed:
			close(isClosed)
		default:
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)

	// The states must be observed in exactly this order.
	<-isChecking
	<-isConnected
	<-isDisconnected
	<-isFailed
	require.NoError(t, aAgent.Close())
	require.NoError(t, bAgent.Close())
	<-isClosed
}
// TestInvalidGather asserts that GatherCandidates fails when no OnCandidate
// handler has been registered.
func TestInvalidGather(t *testing.T) {
	t.Run("Gather with no OnCandidate should error", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, a.Close())
		}()

		err = a.GatherCandidates()
		// require.ErrorIs expects (t, err, target); pass the actual error
		// first (previously the arguments were reversed).
		require.ErrorIs(t, err, ErrNoOnCandidateHandler)
	})
}
// TestCandidatePairsStats exercises GetCandidatePairsStats: one local host
// candidate is paired with four remote candidates (relay, srflx, prflx,
// host), traffic and RTT are recorded on every pair, and the reported stats
// must mirror what was recorded.
func TestCandidatePairsStats(t *testing.T) { //nolint:cyclop,gocyclo
	defer test.CheckRoutines(t)()
	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	// The single local candidate shared by all four pairs.
	hostConfig := &CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      19216,
		Component: 1,
	}
	hostLocal, err := NewCandidateHost(hostConfig)
	require.NoError(t, err)

	relayConfig := &CandidateRelayConfig{
		Network:   "udp",
		Address:   "1.2.3.4",
		Port:      2340,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43210,
	}
	relayRemote, err := NewCandidateRelay(relayConfig)
	require.NoError(t, err)

	srflxConfig := &CandidateServerReflexiveConfig{
		Network:   "udp",
		Address:   "10.10.10.2",
		Port:      19218,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43212,
	}
	srflxRemote, err := NewCandidateServerReflexive(srflxConfig)
	require.NoError(t, err)

	prflxConfig := &CandidatePeerReflexiveConfig{
		Network:   "udp",
		Address:   "10.10.10.2",
		Port:      19217,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43211,
	}
	prflxRemote, err := NewCandidatePeerReflexive(prflxConfig)
	require.NoError(t, err)

	// hostConfig is reused for the remote host candidate.
	hostConfig = &CandidateHostConfig{
		Network:   "udp",
		Address:   "1.2.3.5",
		Port:      12350,
		Component: 1,
	}
	hostRemote, err := NewCandidateHost(hostConfig)
	require.NoError(t, err)

	// Record identical activity on every pair: one request/response in each
	// direction, a 1s RTT sample, and one packet sent (100B) / received (200B).
	for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} {
		p := agent.findPair(hostLocal, remote)
		if p == nil {
			p = agent.addPair(hostLocal, remote)
		}
		p.UpdateRequestReceived()
		p.UpdateRequestSent()
		p.UpdateResponseSent()
		p.UpdateRoundTripTime(time.Second)
		p.UpdatePacketSent(100)
		p.UpdatePacketReceived(200)
	}

	// Give the prflx pair distinguishing values: Failed state plus nine more
	// RTT samples (2s..10s), so its current RTT becomes 10s, the total
	// 1+2+...+10 = 55s, and responses received 10.
	p := agent.findPair(hostLocal, prflxRemote)
	p.state = CandidatePairStateFailed
	for i := 1; i < 10; i++ {
		p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
	}

	stats := agent.GetCandidatePairsStats()
	require.Len(t, stats, 4)

	var relayPairStat, srflxPairStat, prflxPairStat, hostPairStat CandidatePairStats
	for _, cps := range stats {
		require.Equal(t, cps.LocalCandidateID, hostLocal.ID())
		switch cps.RemoteCandidateID {
		case relayRemote.ID():
			relayPairStat = cps
		case srflxRemote.ID():
			srflxPairStat = cps
		case prflxRemote.ID():
			prflxPairStat = cps
		case hostRemote.ID():
			hostPairStat = cps
		default:
			t.Fatal("invalid remote candidate ID") //nolint
		}
		// Every pair saw activity above, so all timestamps must be set and
		// the traffic counters must mirror the recorded packet/byte counts.
		require.False(t, cps.FirstRequestTimestamp.IsZero())
		require.False(t, cps.LastRequestTimestamp.IsZero())
		require.False(t, cps.FirstResponseTimestamp.IsZero())
		require.False(t, cps.LastResponseTimestamp.IsZero())
		require.False(t, cps.FirstRequestReceivedTimestamp.IsZero())
		require.False(t, cps.LastRequestReceivedTimestamp.IsZero())
		require.Equal(t, uint32(1), cps.PacketsSent)
		require.Equal(t, uint32(1), cps.PacketsReceived)
		require.Equal(t, uint64(100), cps.BytesSent)
		require.Equal(t, uint64(200), cps.BytesReceived)
		require.False(t, cps.LastPacketSentTimestamp.IsZero())
		require.False(t, cps.LastPacketReceivedTimestamp.IsZero())
	}

	require.Equal(t, relayPairStat.RemoteCandidateID, relayRemote.ID())
	require.Equal(t, srflxPairStat.RemoteCandidateID, srflxRemote.ID())
	require.Equal(t, prflxPairStat.RemoteCandidateID, prflxRemote.ID())
	require.Equal(t, hostPairStat.RemoteCandidateID, hostRemote.ID())
	require.Equal(t, prflxPairStat.State, CandidatePairStateFailed)
	require.Equal(t, float64(10), prflxPairStat.CurrentRoundTripTime)
	require.Equal(t, float64(55), prflxPairStat.TotalRoundTripTime)
	require.Equal(t, uint64(10), prflxPairStat.ResponsesReceived)
}
// TestSelectedCandidatePairStats exercises GetSelectedCandidatePairStats: it
// must report "not available" before a pair is selected, and afterwards
// return the RTT and traffic counters recorded on the selected pair.
func TestSelectedCandidatePairStats(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	hostConfig := &CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      19216,
		Component: 1,
	}
	hostLocal, err := NewCandidateHost(hostConfig)
	require.NoError(t, err)

	srflxConfig := &CandidateServerReflexiveConfig{
		Network:   "udp",
		Address:   "10.10.10.2",
		Port:      19218,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43212,
	}
	srflxRemote, err := NewCandidateServerReflexive(srflxConfig)
	require.NoError(t, err)

	// no selected pair, should return not available
	_, ok := agent.GetSelectedCandidatePairStats()
	require.False(t, ok)

	// add pair and populate some RTT stats
	candidatePair := agent.findPair(hostLocal, srflxRemote)
	if candidatePair == nil {
		agent.addPair(hostLocal, srflxRemote)
		candidatePair = agent.findPair(hostLocal, srflxRemote)
	}
	// Ten RTT samples of 1s..10s: the current RTT is the last sample (10s)
	// and the total is 1+2+...+10 = 55s.
	for i := range 10 {
		candidatePair.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
	}
	candidatePair.UpdatePacketSent(150)
	candidatePair.UpdatePacketReceived(250)

	// set the pair as selected
	agent.setSelectedPair(candidatePair)

	stats, ok := agent.GetSelectedCandidatePairStats()
	require.True(t, ok)
	require.Equal(t, stats.LocalCandidateID, hostLocal.ID())
	require.Equal(t, stats.RemoteCandidateID, srflxRemote.ID())
	require.Equal(t, float64(10), stats.CurrentRoundTripTime)
	require.Equal(t, float64(55), stats.TotalRoundTripTime)
	require.Equal(t, uint64(10), stats.ResponsesReceived)
	// One packet in each direction carrying 150 / 250 bytes.
	require.Equal(t, uint32(1), stats.PacketsSent)
	require.Equal(t, uint32(1), stats.PacketsReceived)
	require.Equal(t, uint64(150), stats.BytesSent)
	require.Equal(t, uint64(250), stats.BytesReceived)
	require.False(t, stats.LastPacketSentTimestamp.IsZero())
	require.False(t, stats.LastPacketReceivedTimestamp.IsZero())
}
// TestLocalCandidateStats verifies that GetLocalCandidatesStats reports one
// entry per local candidate with matching type, priority, and address.
func TestLocalCandidateStats(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	host, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      19216,
		Component: 1,
	})
	require.NoError(t, err)

	srflx, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      19217,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43212,
	})
	require.NoError(t, err)

	// Register both candidates directly on the agent.
	agent.localCandidates[NetworkTypeUDP4] = []Candidate{host, srflx}

	localStats := agent.GetLocalCandidatesStats()
	require.Len(t, localStats, 2)

	var hostStat, srflxStat CandidateStats
	for _, stat := range localStats {
		var cand Candidate
		switch stat.ID {
		case host.ID():
			hostStat = stat
			cand = host
		case srflx.ID():
			srflxStat = stat
			cand = srflx
		default:
			t.Fatal("invalid local candidate ID") // nolint
		}
		// Each stats entry must mirror its source candidate.
		require.Equal(t, stat.CandidateType, cand.Type())
		require.Equal(t, stat.Priority, cand.Priority())
		require.Equal(t, stat.IP, cand.Address())
	}
	require.Equal(t, hostStat.ID, host.ID())
	require.Equal(t, srflxStat.ID, srflx.ID())
}
// TestRemoteCandidateStats verifies that GetRemoteCandidatesStats reports one
// entry per remote candidate with matching type, priority, and address.
func TestRemoteCandidateStats(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	relay, err := NewCandidateRelay(&CandidateRelayConfig{
		Network:   "udp",
		Address:   "1.2.3.4",
		Port:      12340,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43210,
	})
	require.NoError(t, err)

	srflx, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{
		Network:   "udp",
		Address:   "10.10.10.2",
		Port:      19218,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43212,
	})
	require.NoError(t, err)

	prflx, err := NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{
		Network:   "udp",
		Address:   "10.10.10.2",
		Port:      19217,
		Component: 1,
		RelAddr:   "4.3.2.1",
		RelPort:   43211,
	})
	require.NoError(t, err)

	host, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "1.2.3.5",
		Port:      12350,
		Component: 1,
	})
	require.NoError(t, err)

	// Register all four candidates directly on the agent.
	agent.remoteCandidates[NetworkTypeUDP4] = []Candidate{relay, srflx, prflx, host}

	remoteStats := agent.GetRemoteCandidatesStats()
	require.Len(t, remoteStats, 4)

	var relayStat, srflxStat, prflxStat, hostStat CandidateStats
	for _, stat := range remoteStats {
		var cand Candidate
		switch stat.ID {
		case relay.ID():
			relayStat = stat
			cand = relay
		case srflx.ID():
			srflxStat = stat
			cand = srflx
		case prflx.ID():
			prflxStat = stat
			cand = prflx
		case host.ID():
			hostStat = stat
			cand = host
		default:
			t.Fatal("invalid remote candidate ID") // nolint
		}
		// Each stats entry must mirror its source candidate.
		require.Equal(t, stat.CandidateType, cand.Type())
		require.Equal(t, stat.Priority, cand.Priority())
		require.Equal(t, stat.IP, cand.Address())
	}
	require.Equal(t, relayStat.ID, relay.ID())
	require.Equal(t, srflxStat.ID, srflx.ID())
	require.Equal(t, prflxStat.ID, prflx.ID())
	require.Equal(t, hostStat.ID, host.ID())
}
// TestInitExtIPMapping covers the validation of the NAT 1:1 IP mapping
// configuration performed by NewAgent.
func TestInitExtIPMapping(t *testing.T) {
	defer test.CheckRoutines(t)()

	// agent.addressRewriteMapper should be nil by default
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	require.Nil(t, agent.addressRewriteMapper)
	require.NoError(t, agent.Close())

	// a.addressRewriteMapper should be nil when NAT1To1IPs is a non-nil empty array
	agent, err = NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{},
		NAT1To1IPCandidateType: CandidateTypeHost,
	})
	require.NoError(t, err)
	require.Nil(t, agent.addressRewriteMapper)
	require.NoError(t, agent.Close())

	// NewAgent should return an error when 1:1 NAT for host candidate is enabled
	// but the candidate type does not appear in the CandidateTypes.
	// NOTE: require.ErrorIs expects (t, err, target); previously the
	// arguments were reversed throughout this test.
	_, err = NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{"1.2.3.4"},
		NAT1To1IPCandidateType: CandidateTypeHost,
		CandidateTypes:         []CandidateType{CandidateTypeRelay},
	})
	require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingHost)

	// NewAgent should return an error when 1:1 NAT for srflx candidate is enabled
	// but the candidate type does not appear in the CandidateTypes.
	_, err = NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{"1.2.3.4"},
		NAT1To1IPCandidateType: CandidateTypeServerReflexive,
		CandidateTypes:         []CandidateType{CandidateTypeRelay},
	})
	require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingSrflx)

	// NewAgent should return an error when 1:1 NAT for host candidate is enabled
	// along with mDNS with MulticastDNSModeQueryAndGather
	_, err = NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{"1.2.3.4"},
		NAT1To1IPCandidateType: CandidateTypeHost,
		MulticastDNSMode:       MulticastDNSModeQueryAndGather,
	})
	require.ErrorIs(t, err, ErrMulticastDNSWithNAT1To1IPMapping)

	// NewAgent should return if newAddressRewriteMapper() returns an error.
	_, err = NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{"bad.2.3.4"}, // Bad IP
		NAT1To1IPCandidateType: CandidateTypeHost,
	})
	require.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping)
}
// TestBindingRequestTimeout checks that invalidatePendingBindingRequests
// discards only the requests that are older than the timeout window.
func TestBindingRequestTimeout(t *testing.T) {
	defer test.CheckRoutines(t)()

	const expectedRemovalCount = 2

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	now := time.Now()
	// Two requests inside the timeout window (valid) and two beyond it
	// (invalid); the invalid ones must be removed.
	offsets := []time.Duration{
		0,                        // Valid
		-3900 * time.Millisecond, // Valid
		-4100 * time.Millisecond, // Invalid
		-75 * time.Hour,          // Invalid
	}
	for _, offset := range offsets {
		agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
			timestamp: now.Add(offset),
		})
	}

	agent.invalidatePendingBindingRequests(now)
	require.Equal(
		t,
		expectedRemovalCount,
		len(agent.pendingBindingRequests),
		"Binding invalidation due to timeout did not remove the correct number of binding requests",
	)
}
// TestAgentCredentials checks that generated local username fragments and
// passwords meet the RFC minimum entropy requirements, and that explicitly
// supplied credentials below those minimums are rejected.
func TestAgentCredentials(t *testing.T) {
	defer test.CheckRoutines(t)()

	// Disable agent logging to keep CI output clean.
	loggerFactory := logging.NewDefaultLoggerFactory()
	loggerFactory.DefaultLogLevel = logging.LogLevelDisabled

	// When not configured, ufrag/pwd are generated; they must carry at
	// least 24 and 128 bits of randomness respectively.
	agent, err := NewAgent(&AgentConfig{LoggerFactory: loggerFactory})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()
	require.GreaterOrEqual(t, 8*len([]rune(agent.localUfrag)), 24)
	require.GreaterOrEqual(t, 8*len([]rune(agent.localPwd)), 128)

	// Explicitly supplied credentials that are too short must be rejected.
	_, err = NewAgent(&AgentConfig{LocalUfrag: "xx", LoggerFactory: loggerFactory})
	require.EqualError(t, err, ErrLocalUfragInsufficientBits.Error())
	_, err = NewAgent(&AgentConfig{LocalPwd: "xxxxxx", LoggerFactory: loggerFactory})
	require.EqualError(t, err, ErrLocalPwdInsufficientBits.Error())
}
// Assert that Agent on Failure deletes all existing candidates
// User can then do an ICE Restart to bring agent back.
func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	// Disable keepalives so the connection times out and fails quickly.
	KeepaliveInterval := time.Duration(0)
	cfg := &AgentConfig{
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isFailed := make(chan any)
	require.NoError(t, aAgent.OnConnectionStateChange(func(c ConnectionState) {
		if c == ConnectionStateFailed {
			close(isFailed)
		}
	}))

	connect(t, aAgent, bAgent)
	<-isFailed

	// Inspect the candidate lists on the agent's own loop to avoid racing
	// with internal state changes.
	done := make(chan struct{})
	require.NoError(t, aAgent.loop.Run(context.Background(), func(context.Context) {
		require.Equal(t, len(aAgent.remoteCandidates), 0)
		require.Equal(t, len(aAgent.localCandidates), 0)
		close(done)
	}))
	<-done
}
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides.
func TestConnectionStateConnectingToFailed(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	cfg := &AgentConfig{
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	// Both agents must report Checking and later Failed; Completed must
	// never be reached because the credentials below can never authenticate.
	var isFailed sync.WaitGroup
	var isChecking sync.WaitGroup
	isFailed.Add(2)
	isChecking.Add(2)
	connectionStateCheck := func(c ConnectionState) {
		switch c {
		case ConnectionStateFailed:
			isFailed.Done()
		case ConnectionStateChecking:
			isChecking.Done()
		case ConnectionStateCompleted:
			t.Errorf("Unexpected ConnectionState: %v", c) //nolint
		default:
		}
	}
	require.NoError(t, aAgent.OnConnectionStateChange(connectionStateCheck))
	require.NoError(t, bAgent.OnConnectionStateChange(connectionStateCheck))

	// Start both sides with credentials that can never match, so
	// connectivity checks are sent but never succeed.
	go func() {
		_, err := aAgent.Accept(context.TODO(), "InvalidFrag", "InvalidPwd")
		require.Error(t, err)
	}()
	go func() {
		_, err := bAgent.Dial(context.TODO(), "InvalidFrag", "InvalidPwd")
		require.Error(t, err)
	}()
	isChecking.Wait()
	isFailed.Wait()
}
// TestAgentRestart covers ICE restart behaviour: restarting while gathering,
// restarting an already-closed agent, restarting only one side, and a full
// restart with re-signaling on both sides.
func TestAgentRestart(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	oneSecond := time.Second

	t.Run("Restart During Gather", func(t *testing.T) {
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)

		// The peer must observe the connection drop after the restart.
		ctx, cancel := context.WithCancel(context.Background())
		require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateFailed || c == ConnectionStateDisconnected {
				cancel()
			}
		}))

		// Force the gathering state so Restart takes the mid-gather path.
		connA.agent.gatheringState = GatheringStateGathering
		require.NoError(t, connA.agent.Restart("", ""))
		<-ctx.Done()
	})

	t.Run("Restart When Closed", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		require.NoError(t, agent.Close())
		// Restart on a closed agent must fail with ErrClosed.
		require.Equal(t, ErrClosed, agent.Restart("", ""))
	})

	t.Run("Restart One Side", func(t *testing.T) {
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)

		// Without re-signaling, the non-restarted peer must eventually see
		// the connection drop.
		ctx, cancel := context.WithCancel(context.Background())
		require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateFailed || c == ConnectionStateDisconnected {
				cancel()
			}
		}))
		require.NoError(t, connA.agent.Restart("", ""))
		<-ctx.Done()
	})

	t.Run("Restart Both Sides", func(t *testing.T) {
		// Get all addresses of candidates concatenated
		generateCandidateAddressStrings := func(candidates []Candidate, err error) (out string) {
			require.NoError(t, err)
			for _, c := range candidates {
				out += c.Address() + ":"
				out += strconv.Itoa(c.Port())
			}
			return
		}

		// Store the original candidates, confirm that after we reconnect we have new pairs
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)
		connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates())
		connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates())

		aNotifier, aConnected := onConnected()
		require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
		bNotifier, bConnected := onConnected()
		require.NoError(t, connB.agent.OnConnectionStateChange(bNotifier))

		// Restart and Re-Signal
		require.NoError(t, connA.agent.Restart("", ""))
		require.NoError(t, connB.agent.Restart("", ""))

		// Exchange Candidates and Credentials
		ufrag, pwd, err := connB.agent.GetLocalUserCredentials()
		require.NoError(t, err)
		require.NoError(t, connA.agent.SetRemoteCredentials(ufrag, pwd))
		ufrag, pwd, err = connA.agent.GetLocalUserCredentials()
		require.NoError(t, err)
		require.NoError(t, connB.agent.SetRemoteCredentials(ufrag, pwd))
		gatherAndExchangeCandidates(t, connA.agent, connB.agent)

		// Wait until both have gone back to connected
		<-aConnected
		<-bConnected

		// Assert that we have new candidates each time
		require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates()))
		require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates()))
	})
}
// TestGetRemoteCredentials verifies that GetRemoteUserCredentials returns the
// ufrag/pwd previously stored on the agent.
func TestGetRemoteCredentials(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	agent.remoteUfrag = "remoteUfrag"
	agent.remotePwd = "remotePwd"

	gotUfrag, gotPwd, err := agent.GetRemoteUserCredentials()
	require.NoError(t, err)
	require.Equal(t, gotUfrag, agent.remoteUfrag)
	require.Equal(t, gotPwd, agent.remotePwd)
}
// TestGetRemoteCandidates verifies that all candidates added through
// addRemoteCandidate are reported back by GetRemoteCandidates.
func TestGetRemoteCandidates(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	var expected []Candidate
	for port := 1000; port < 1005; port++ {
		cand, candErr := NewCandidateHost(&CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      port,
			Component: 1,
		})
		require.NoError(t, candErr)
		agent.addRemoteCandidate(cand)
		expected = append(expected, cand)
	}

	actual, err := agent.GetRemoteCandidates()
	require.NoError(t, err)
	require.ElementsMatch(t, expected, actual)
}
// TestRemoteIPFilterInAddRemoteCandidate verifies that addRemoteCandidate
// drops candidates whose IP is rejected by RemoteIPFilter and keeps the rest.
func TestRemoteIPFilterInAddRemoteCandidate(t *testing.T) {
	agent, err := NewAgent(&AgentConfig{
		RemoteIPFilter: func(ip net.IP) (keep bool) {
			return !ip.Equal(net.IPv4(203, 0, 113, 9))
		},
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	newHost := func(addr string, port int) Candidate {
		cand, candErr := NewCandidateHost(&CandidateHostConfig{
			Network:   "udp",
			Address:   addr,
			Port:      port,
			Component: 1,
		})
		require.NoError(t, candErr)

		return cand
	}
	denied := newHost("203.0.113.9", 40000)
	permitted := newHost("198.51.100.10", 40001)

	agent.addRemoteCandidate(denied)
	agent.addRemoteCandidate(permitted)

	// Only the permitted candidate may survive the filter.
	actual, err := agent.GetRemoteCandidates()
	require.NoError(t, err)
	require.Len(t, actual, 1)
	require.Equal(t, permitted.Address(), actual[0].Address())
}
// TestAddRemoteCandidateHonorsRemoteIPFilter verifies that the public
// AddRemoteCandidate API also applies RemoteIPFilter; since it is
// asynchronous, the result is polled.
func TestAddRemoteCandidateHonorsRemoteIPFilter(t *testing.T) {
	agent, err := NewAgent(&AgentConfig{
		RemoteIPFilter: func(ip net.IP) (keep bool) {
			return !ip.Equal(net.IPv4(203, 0, 113, 11))
		},
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	newHost := func(addr string, port int) Candidate {
		cand, candErr := NewCandidateHost(&CandidateHostConfig{
			Network:   "udp",
			Address:   addr,
			Port:      port,
			Component: 1,
		})
		require.NoError(t, candErr)

		return cand
	}
	denied := newHost("203.0.113.11", 41000)
	permitted := newHost("198.51.100.12", 41001)

	require.NoError(t, agent.AddRemoteCandidate(denied))
	require.NoError(t, agent.AddRemoteCandidate(permitted))

	// AddRemoteCandidate processes candidates asynchronously, so poll until
	// exactly the permitted candidate is visible.
	require.Eventually(t, func() bool {
		actual, getErr := agent.GetRemoteCandidates()
		if getErr != nil {
			return false
		}

		return len(actual) == 1 && actual[0].Address() == permitted.Address()
	}, time.Second, 10*time.Millisecond)
}
// TestGetLocalCandidates verifies that all candidates registered through
// addCandidate are reported back by GetLocalCandidates.
func TestGetLocalCandidates(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	// A placeholder connection; the test never sends traffic over it.
	dummyConn := &net.UDPConn{}
	var expected []Candidate
	for port := 1000; port < 1005; port++ {
		cand, candErr := NewCandidateHost(&CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      port,
			Component: 1,
		})
		require.NoError(t, candErr)
		require.NoError(t, agent.addCandidate(context.Background(), cand, dummyConn))
		expected = append(expected, cand)
	}

	actual, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	require.ElementsMatch(t, expected, actual)
}
// TestCloseInConnectionStateCallback asserts that calling Agent.Close from
// inside the OnConnectionStateChange callback does not deadlock and still
// delivers the Closed state.
func TestCloseInConnectionStateCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	disconnectedDuration := time.Second
	failedDuration := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 500 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &disconnectedDuration,
		FailedTimeout:       &failedDuration,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	// aAgent is closed inside the state callback; the deferred Close only
	// runs if that never happened (e.g. the test bailed out early).
	var aAgentClosed bool
	defer func() {
		if aAgentClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isClosed := make(chan any)
	isConnected := make(chan any)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		switch c {
		case ConnectionStateConnected:
			// Block until the main goroutine has finished connect(), then
			// close the agent from within the callback itself.
			<-isConnected
			require.NoError(t, aAgent.Close())
			aAgentClosed = true
		case ConnectionStateClosed:
			close(isClosed)
		default:
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)
	close(isConnected)

	<-isClosed
}
// TestRunTaskInConnectionStateCallback asserts that agent API calls
// (GetLocalUserCredentials, Restart) can be made from inside the
// OnConnectionStateChange callback without deadlocking.
func TestRunTaskInConnectionStateCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 50 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isComplete := make(chan any)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		if c == ConnectionStateConnected {
			// Call back into the agent from within the callback; both
			// calls must complete without deadlocking.
			_, _, errCred := aAgent.GetLocalUserCredentials()
			require.NoError(t, errCred)
			require.NoError(t, aAgent.Restart("", ""))
			close(isComplete)
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)

	<-isComplete
}
// TestRunTaskInSelectedCandidatePairChangeCallback asserts that agent API
// calls can be made from a goroutine spawned by the
// OnSelectedCandidatePairChange callback without deadlocking.
func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 50 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isComplete := make(chan any)
	isTested := make(chan any)
	err = aAgent.OnSelectedCandidatePairChange(func(Candidate, Candidate) {
		// Exercise an agent API call from a goroutine started inside the
		// callback; it must complete without deadlocking.
		go func() {
			_, _, errCred := aAgent.GetLocalUserCredentials()
			require.NoError(t, errCred)
			close(isTested)
		}()
	})
	require.NoError(t, err)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		if c == ConnectionStateConnected {
			close(isComplete)
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)

	<-isComplete
	<-isTested
}
// Assert that a Lite agent goes to disconnected and failed.
func TestLiteLifecycle(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	aNotifier, aConnected := onConnected()

	// Full agent acting as the peer of the lite agent below.
	aAgent, err := NewAgent(&AgentConfig{
		NetworkTypes:     supportedNetworkTypes(),
		MulticastDNSMode: MulticastDNSModeDisabled,
	})
	require.NoError(t, err)
	// aAgent is closed mid-test to trigger the lite agent's timeouts; the
	// deferred Close only runs if the test bails out before that point.
	var aClosed bool
	defer func() {
		if aClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()
	require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

	disconnectedDuration := time.Second
	failedDuration := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 500 * time.Millisecond
	// Lite agent: host candidates only, with short timeouts so the test
	// observes Disconnected and Failed quickly after the peer goes away.
	bAgent, err := NewAgent(&AgentConfig{
		Lite:                true,
		CandidateTypes:      []CandidateType{CandidateTypeHost},
		NetworkTypes:        supportedNetworkTypes(),
		MulticastDNSMode:    MulticastDNSModeDisabled,
		DisconnectedTimeout: &disconnectedDuration,
		FailedTimeout:       &failedDuration,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	})
	require.NoError(t, err)
	var bClosed bool
	defer func() {
		if bClosed {
			return
		}
		require.NoError(t, bAgent.Close())
	}()

	bConnected := make(chan any)
	bDisconnected := make(chan any)
	bFailed := make(chan any)
	require.NoError(t, bAgent.OnConnectionStateChange(func(c ConnectionState) {
		switch c {
		case ConnectionStateConnected:
			close(bConnected)
		case ConnectionStateDisconnected:
			close(bDisconnected)
		case ConnectionStateFailed:
			close(bFailed)
		default:
		}
	}))

	connectWithVNet(t, bAgent, aAgent)

	<-aConnected
	<-bConnected

	// Closing the full agent must drive the lite agent through Disconnected
	// and then Failed.
	require.NoError(t, aAgent.Close())
	aClosed = true

	<-bDisconnected
	<-bFailed
	require.NoError(t, bAgent.Close())
	bClosed = true
}
// TestValidateSelectedPairTransitions checks that validateSelectedPair walks
// the agent from Connected to Disconnected and then Failed once the selected
// pair's remote candidate has been silent longer than the configured
// timeouts.
func TestValidateSelectedPairTransitions(t *testing.T) {
	// Minimal hand-built Agent: only the fields consulted by
	// validateSelectedPair and its state-change plumbing are populated.
	agent := &Agent{
		disconnectedTimeout: time.Second,
		failedTimeout:       time.Second,
		connectionState:     ConnectionStateConnected,
		connectionStateNotifier: &handlerNotifier{
			connectionStateFunc: func(ConnectionState) {},
			done:                make(chan struct{}),
		},
		log: logging.NewDefaultLoggerFactory().NewLogger("test"),
	}

	local, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "1.1.1.1",
		Port:      1000,
		Component: ComponentRTP,
	})
	require.NoError(t, err)

	remote, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "2.2.2.2",
		Port:      2000,
		Component: ComponentRTP,
	})
	require.NoError(t, err)

	// Last traffic 3s ago exceeds both the 1s disconnected threshold and
	// the combined 1s+1s failed threshold.
	remote.setLastReceived(time.Now().Add(-3 * time.Second))

	agent.selectedPair.Store(newCandidatePair(local, remote, true))

	// First validation trips the disconnected timeout…
	require.True(t, agent.validateSelectedPair())
	require.Equal(t, ConnectionStateDisconnected, agent.connectionState)
	// …and the second trips the failed timeout.
	require.True(t, agent.validateSelectedPair())
	require.Equal(t, ConnectionStateFailed, agent.connectionState)
}
// TestNilCandidate asserts that adding a nil remote candidate is a no-op
// rather than an error or a panic.
func TestNilCandidate(t *testing.T) {
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)

	require.NoError(t, agent.AddRemoteCandidate(nil))
	require.NoError(t, agent.Close())
}
// TestNilCandidatePair asserts that setting a nil selected pair is safe.
func TestNilCandidatePair(t *testing.T) {
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	agent.setSelectedPair(nil)
}
// TestGetSelectedCandidatePair asserts that GetSelectedCandidatePair returns
// nil before connecting and, afterwards, mirrored pairs on both agents.
func TestGetSelectedCandidatePair(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Both agents share one virtual network with a single static IP.
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	net, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net))
	require.NoError(t, wan.Start())

	cfg := &AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
		Net:          net,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	// No pair has been selected yet on either side.
	aAgentPair, err := aAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.Nil(t, aAgentPair)

	bAgentPair, err := bAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.Nil(t, bAgentPair)

	connect(t, aAgent, bAgent)

	// After connecting, each side's local candidate must equal the other
	// side's remote candidate.
	aAgentPair, err = aAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.NotNil(t, aAgentPair)

	bAgentPair, err = bAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.NotNil(t, bAgentPair)

	require.True(t, bAgentPair.Local.Equal(aAgentPair.Remote))
	require.True(t, bAgentPair.Remote.Equal(aAgentPair.Local))

	require.NoError(t, wan.Stop())
}
// TestAcceptAggressiveNomination exercises aggressive nomination: after a
// pair is selected, the remote peer sends USE-CANDIDATE binding requests
// carrying a PRIORITY attribute for its other candidates, and the test
// checks whether the selected pair switches depending on agent mode
// (full vs lite), the EnableUseCandidateCheckPriority option, and whether
// the advertised priority is higher or lower than the current one.
func TestAcceptAggressiveNomination(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Create a network with two interfaces
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	net0, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net0))

	// bAgent gets three addresses so it offers several host candidates to
	// nominate between.
	net1, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.2", "192.168.0.3", "192.168.0.4"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net1))
	require.NoError(t, wan.Start())

	testCases := []struct {
		name                            string
		isLite                          bool
		enableUseCandidateCheckPriority bool
		useHigherPriority               bool
		isExpectedToSwitch              bool
	}{
		{"should accept higher priority - full agent", false, false, true, true},
		{"should not accept lower priority - full agent", false, false, false, false},
		{"should accept higher priority - no use-candidate priority check - lite agent", true, false, true, true},
		{"should accept lower priority - no use-candidate priority check - lite agent", true, false, false, true},
		{"should accept higher priority - use-candidate priority check - lite agent", true, true, true, true},
		{"should not accept lower priority - use-candidate priority check - lite agent", true, true, false, false},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			aNotifier, aConnected := onConnected()
			bNotifier, bConnected := onConnected()

			// Hour-long keepalive/check intervals so the only binding traffic
			// after connecting is what this test sends explicitly.
			KeepaliveInterval := time.Hour
			cfg0 := &AgentConfig{
				NetworkTypes:                    []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
				MulticastDNSMode:                MulticastDNSModeDisabled,
				Net:                             net0,
				KeepaliveInterval:               &KeepaliveInterval,
				CheckInterval:                   &KeepaliveInterval,
				Lite:                            tc.isLite,
				EnableUseCandidateCheckPriority: tc.enableUseCandidateCheckPriority,
			}
			if tc.isLite {
				cfg0.CandidateTypes = []CandidateType{CandidateTypeHost}
			}

			var aAgent, bAgent *Agent
			aAgent, err = NewAgent(cfg0)
			require.NoError(t, err)
			defer func() {
				require.NoError(t, aAgent.Close())
			}()
			require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

			cfg1 := &AgentConfig{
				NetworkTypes:      []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
				MulticastDNSMode:  MulticastDNSModeDisabled,
				Net:               net1,
				KeepaliveInterval: &KeepaliveInterval,
				CheckInterval:     &KeepaliveInterval,
			}

			bAgent, err = NewAgent(cfg1)
			require.NoError(t, err)
			defer func() {
				require.NoError(t, bAgent.Close())
			}()
			require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))

			connect(t, aAgent, bAgent)

			// Ensure pair selected
			// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
			<-aConnected
			<-bConnected

			// Send new USE-CANDIDATE message with priority to update the selected pair
			buildMsg := func(class stun.MessageClass, username, key string, priority uint32) *stun.Message {
				msg, err1 := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID,
					stun.NewUsername(username),
					stun.NewShortTermIntegrity(key),
					UseCandidate(),
					PriorityAttr(priority),
					stun.Fingerprint,
				)
				require.NoError(t, err1)

				return msg
			}

			selectedCh := make(chan Candidate, 1)
			var expectNewSelectedCandidate Candidate
			err = aAgent.OnSelectedCandidatePairChange(func(_, remote Candidate) {
				selectedCh <- remote
			})
			require.NoError(t, err)

			var bcandidates []Candidate
			bcandidates, err = bAgent.GetLocalCandidates()
			require.NoError(t, err)

			for _, cand := range bcandidates {
				if cand != bAgent.getSelectedPair().Local { //nolint:nestif
					if expectNewSelectedCandidate == nil {
						// First non-selected candidate: bias its priority on
						// aAgent's remote-candidate view (by +/-1000) so the
						// switch / non-switch outcome is deterministic.
					expected_change_priority:
						for _, candidates := range aAgent.remoteCandidates {
							for _, candidate := range candidates {
								if candidate.Equal(cand) {
									if tc.useHigherPriority {
										candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert
									} else {
										candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert
									}

									break expected_change_priority
								}
							}
						}
						if tc.isExpectedToSwitch {
							expectNewSelectedCandidate = cand
						} else {
							// No switch expected: the current remote stays selected.
							expectNewSelectedCandidate = aAgent.getSelectedPair().Remote
						}
					} else {
						// a smaller change for other candidates other the new expected one
					change_priority:
						for _, candidates := range aAgent.remoteCandidates {
							for _, candidate := range candidates {
								if candidate.Equal(cand) {
									if tc.useHigherPriority {
										candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert
									} else {
										candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert
									}

									break change_priority
								}
							}
						}
					}

					// Fire the aggressive-nomination binding request from this
					// candidate directly at aAgent's currently selected remote.
					_, err = cand.writeTo(
						buildMsg(
							stun.ClassRequest,
							aAgent.localUfrag+":"+aAgent.remoteUfrag,
							aAgent.localPwd,
							cand.Priority(),
						).Raw,
						bAgent.getSelectedPair().Remote,
					)
					require.NoError(t, err)
				}
			}

			// Wait until either we observe the expected switch or the timeout elapses,
			// Ugly but makes the tests less flaky, especially on Windows.
			timeout := 3 * time.Second
			deadline := time.Now().Add(timeout)
			observedExpected := false

		waitLoop:
			for time.Now().Before(deadline) {
				select {
				case selected := <-selectedCh:
					if tc.isExpectedToSwitch {
						if selected.Equal(expectNewSelectedCandidate) {
							observedExpected = true

							break waitLoop
						}
					}
				default:
					time.Sleep(10 * time.Millisecond)
				}
			}

			if tc.isExpectedToSwitch {
				if !observedExpected {
					// Verify the agent's final selected pair if we didn't observe the event directly.
					require.True(t, aAgent.getSelectedPair().Remote.Equal(expectNewSelectedCandidate))
				}
			} else {
				// Ensure no switch happened by checking the agent's final selected pair.
				require.True(t, aAgent.getSelectedPair().Remote.Equal(expectNewSelectedCandidate))
			}
		})
	}

	require.NoError(t, wan.Stop())
}
// Close can deadlock but GracefulClose must not.
//
// Both agents trigger GracefulClose from inside their own
// OnConnectionStateChange callback (via a goroutine), which is exactly the
// re-entrant situation where a plain Close could deadlock.
func TestAgentGracefulCloseDeadlock(t *testing.T) {
	defer test.CheckRoutinesStrict(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	config := &AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
	}

	aAgent, err := NewAgent(config)
	require.NoError(t, err)
	var aAgentClosed bool
	defer func() {
		// Fallback cleanup in case the callback never closed the agent.
		if aAgentClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(config)
	require.NoError(t, err)
	var bAgentClosed bool
	defer func() {
		if bAgentClosed {
			return
		}
		require.NoError(t, bAgent.Close())
	}()

	// connected: both sides have reached Connected.
	// closeNow: gate that releases the callbacks to close their agents.
	// closed: both GracefulClose calls have finished.
	var connected, closeNow, closed sync.WaitGroup
	connected.Add(2)
	closeNow.Add(1)
	closed.Add(2)

	closeHdlr := func(agent *Agent, agentClosed *bool) {
		require.NoError(t, agent.OnConnectionStateChange(func(cs ConnectionState) {
			if cs == ConnectionStateConnected {
				connected.Done()
				closeNow.Wait()

				go func() {
					require.NoError(t, agent.GracefulClose())
					*agentClosed = true
					closed.Done()
				}()
			}
		}))
	}
	closeHdlr(aAgent, &aAgentClosed)
	closeHdlr(bAgent, &bAgentClosed)

	t.Log("connecting agents")
	_, _ = connect(t, aAgent, bAgent)

	t.Log("waiting for them to confirm connection in callback")
	connected.Wait()

	t.Log("tell them to close themselves in the same callback and wait")
	closeNow.Done()
	closed.Wait()
}
// TestSetCandidatesUfrag checks that every local candidate added to an agent
// carries the agent's local ufrag as a "ufrag" extension.
func TestSetCandidatesUfrag(t *testing.T) {
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	conn := &net.UDPConn{}
	// Add five host candidates on consecutive ports.
	for i := range 5 {
		hostConfig := CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      1000 + i,
			Component: 1,
		}
		candidate, candErr := NewCandidateHost(&hostConfig)
		require.NoError(t, candErr)
		require.NoError(t, agent.addCandidate(context.Background(), candidate, conn))
	}

	locals, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	for _, local := range locals {
		ext, found := local.GetExtension("ufrag")
		require.True(t, found)
		require.Equal(t, agent.localUfrag, ext.Value)
	}
}
// TestAlwaysSentKeepAlive checks that checkKeepalive sends traffic on the
// selected pair (observable via the local candidate's LastSent timestamp)
// even right after the remote candidate has been marked as seen.
func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	log := logging.NewDefaultLoggerFactory().NewLogger("agent")
	agent.selector = &controllingSelector{agent: agent, log: log}

	// Hand-built pair with a mock conn so writes succeed without a network.
	pair := makeCandidatePair(t)
	s, ok := pair.Local.(*CandidateHost)
	require.True(t, ok)
	s.conn = &fakenet.MockPacketConn{}
	agent.setSelectedPair(pair)

	// Refresh the remote's activity timestamp; the keepalive must still go out.
	pair.Remote.seen(false)
	lastSent := pair.Local.LastSent()
	agent.checkKeepalive()
	newLastSent := pair.Local.LastSent()
	require.NotEqual(t, lastSent, newLastSent)

	lastSent = newLastSent
	// Wait for enough time to pass so there is difference in sent time of local candidate.
	require.Eventually(t, func() bool {
		agent.checkKeepalive()
		newLastSent = pair.Local.LastSent()

		return !lastSent.Equal(newLastSent)
	}, 1*time.Second, 50*time.Millisecond)
}
// TestRoleConflict connects two agents that both call Dial (or both call
// Accept), so both start in the same ICE role; role-conflict resolution
// must still allow them to reach Connected.
func TestRoleConflict(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// runTest drives one conflict scenario; doDial selects whether both
	// sides Dial (controlling/controlling) or both Accept (controlled/controlled).
	runTest := func(t *testing.T, doDial bool) {
		t.Helper()

		cfg := &AgentConfig{
			NetworkTypes:     supportedNetworkTypes(),
			MulticastDNSMode: MulticastDNSModeDisabled,
			InterfaceFilter:  problematicNetworkInterfaces,
		}

		aAgent, err := NewAgent(cfg)
		require.NoError(t, err)
		bAgent, err := NewAgent(cfg)
		require.NoError(t, err)

		isConnected := make(chan any)
		err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateConnected {
				close(isConnected)
			}
		})
		require.NoError(t, err)

		gatherAndExchangeCandidates(t, aAgent, bAgent)

		// Both sides intentionally use the same verb to force the conflict;
		// aAgent runs in a goroutine because Dial/Accept block until done.
		go func() {
			ufrag, pwd, routineErr := bAgent.GetLocalUserCredentials()
			require.NoError(t, routineErr)
			if doDial {
				_, routineErr = aAgent.Dial(context.TODO(), ufrag, pwd)
			} else {
				_, routineErr = aAgent.Accept(context.TODO(), ufrag, pwd)
			}
			require.NoError(t, routineErr)
		}()

		ufrag, pwd, err := aAgent.GetLocalUserCredentials()
		require.NoError(t, err)
		if doDial {
			_, err = bAgent.Dial(context.TODO(), ufrag, pwd)
		} else {
			_, err = bAgent.Accept(context.TODO(), ufrag, pwd)
		}
		require.NoError(t, err)

		<-isConnected
		require.NoError(t, aAgent.Close())
		require.NoError(t, bAgent.Close())
	}

	t.Run("Controlling", func(t *testing.T) {
		runTest(t, true)
	})
	t.Run("Controlled", func(t *testing.T) {
		runTest(t, false)
	})
}
// TestDefaultCandidateTypes ensures defaultCandidateTypes returns a fresh
// slice on every call, so mutating one result cannot affect later callers.
func TestDefaultCandidateTypes(t *testing.T) {
	want := []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}

	got := defaultCandidateTypes()
	require.Equal(t, want, got)

	// Mutate the returned slice; a second call must be unaffected.
	got[0] = CandidateTypeRelay
	require.Equal(t, want, defaultCandidateTypes())
}
// TestDefaultRelayAcceptanceMinWaitFor checks which relay acceptance wait
// defaultRelayAcceptanceMinWaitFor selects for different candidate-type sets.
func TestDefaultRelayAcceptanceMinWaitFor(t *testing.T) {
	cases := []struct {
		name  string
		types []CandidateType
		want  time.Duration
	}{
		{"relay only defaults to zero wait", []CandidateType{CandidateTypeRelay}, defaultRelayOnlyAcceptanceMinWait},
		{"empty candidate types uses general relay wait", nil, defaultRelayAcceptanceMinWait},
		{"mixed candidate types uses general relay wait", []CandidateType{CandidateTypeHost, CandidateTypeRelay}, defaultRelayAcceptanceMinWait},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, defaultRelayAcceptanceMinWaitFor(tc.types))
		})
	}
}
// TestAgentConfig_initWithDefaults_UsesProvidedValues verifies that values
// explicitly set on AgentConfig survive initWithDefaults instead of being
// replaced by the built-in defaults.
func TestAgentConfig_initWithDefaults_UsesProvidedValues(t *testing.T) {
	maxBindingReq := uint16(0)
	srflxWait := 111 * time.Millisecond
	prflxWait := 222 * time.Millisecond
	relayWait := 3 * time.Second
	stunTimeout := 4 * time.Second

	config := &AgentConfig{
		MaxBindingRequests:     &maxBindingReq,
		SrflxAcceptanceMinWait: &srflxWait,
		PrflxAcceptanceMinWait: &prflxWait,
		RelayAcceptanceMinWait: &relayWait,
		STUNGatherTimeout:      &stunTimeout,
	}

	var agent Agent
	config.initWithDefaults(&agent)

	require.Equal(t, maxBindingReq, agent.maxBindingRequests, "expected override for MaxBindingRequests")
	require.Equal(t, srflxWait, agent.srflxAcceptanceMinWait, "expected override for SrflxAcceptanceMinWait")
	require.Equal(t, prflxWait, agent.prflxAcceptanceMinWait, "expected override for PrflxAcceptanceMinWait")
	require.Equal(t, relayWait, agent.relayAcceptanceMinWait, "expected override for RelayAcceptanceMinWait")
	require.Equal(t, stunTimeout, agent.stunGatherTimeout, "expected override for STUNGatherTimeout")
}
// TestAutomaticRenominationWithVNet tests automatic renomination with simple vnet setup.
// This is a simplified test that verifies the renomination mechanism triggers correctly.
func TestAutomaticRenominationWithVNet(t *testing.T) {
	defer test.CheckRoutines(t)()

	loggerFactory := logging.NewDefaultLoggerFactory()

	// Create simple vnet with two agents on same network (no NAT)
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: loggerFactory,
	})
	require.NoError(t, err)

	net0, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net0))

	net1, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.2"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net1))
	require.NoError(t, wan.Start())
	defer wan.Stop() //nolint:errcheck

	// Create agents with automatic renomination.
	// Short intervals keep keepalive/check/renomination activity within the
	// test's runtime.
	keepaliveInterval := 100 * time.Millisecond
	checkInterval := 50 * time.Millisecond
	renominationInterval := 200 * time.Millisecond

	// agent1 enables renomination via options; agent2 flips the equivalent
	// fields directly on the struct below.
	agent1, err := newAgentFromConfig(&AgentConfig{
		NetworkTypes:      []NetworkType{NetworkTypeUDP4},
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net0,
		KeepaliveInterval: &keepaliveInterval,
		CheckInterval:     &checkInterval,
	},
		WithRenomination(DefaultNominationValueGenerator()),
		WithAutomaticRenomination(renominationInterval),
	)
	require.NoError(t, err)
	defer agent1.Close() //nolint:errcheck

	agent2, err := NewAgent(&AgentConfig{
		NetworkTypes:      []NetworkType{NetworkTypeUDP4},
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net1,
		KeepaliveInterval: &keepaliveInterval,
		CheckInterval:     &checkInterval,
	})
	require.NoError(t, err)
	defer agent2.Close() //nolint:errcheck
	agent2.enableRenomination = true
	agent2.nominationValueGenerator = DefaultNominationValueGenerator()

	// Connect the agents using the existing helper
	conn1, conn2 := connectWithVNet(t, agent1, agent2)

	// Verify connection works
	testData := []byte("test data")
	_, err = conn1.Write(testData)
	require.NoError(t, err)

	buf := make([]byte, len(testData))
	_, err = conn2.Read(buf)
	require.NoError(t, err)
	require.Equal(t, testData, buf)
}
// TestAutomaticRenominationRTTImprovement tests that automatic renomination
// triggers when RTT significantly improves.
func TestAutomaticRenominationRTTImprovement(t *testing.T) {
	defer test.CheckRoutines(t)()

	// This test verifies the RTT-based renomination logic.
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck

	// newHost builds a UDP host candidate at the given address/port.
	newHost := func(address string, port int) Candidate {
		cand, hostErr := NewCandidateHost(&CandidateHostConfig{
			Network:   "udp",
			Address:   address,
			Port:      port,
			Component: 1,
		})
		require.NoError(t, hostErr)

		return cand
	}

	localA := newHost("192.168.1.1", 10000)
	localB := newHost("192.168.1.3", 10001) // Different address
	remote := newHost("192.168.1.2", 20000)

	// succeededPair builds a pair in the Succeeded state with the given RTT.
	succeededPair := func(local Candidate, rtt time.Duration) *CandidatePair {
		pair := newCandidatePair(local, remote, true)
		pair.state = CandidatePairStateSucceeded
		pair.UpdateRoundTripTime(rtt)

		return pair
	}

	// Current pair with a high (100ms) RTT.
	current := succeededPair(localA, 100*time.Millisecond)

	// A pair with a significantly better RTT (>10ms improvement) should win.
	better := succeededPair(localB, 50*time.Millisecond)
	require.True(t, agent.shouldRenominate(current, better),
		"Should renominate for >10ms RTT improvement")

	// A small (<10ms) improvement must not trigger renomination.
	slightlyBetter := succeededPair(localB, 95*time.Millisecond)
	require.False(t, agent.shouldRenominate(current, slightlyBetter),
		"Should not renominate for <10ms RTT improvement")
}
// TestAutomaticRenominationRelayToDirect tests that automatic renomination
// always prefers direct connections over relay connections.
func TestAutomaticRenominationRelayToDirect(t *testing.T) {
	defer test.CheckRoutines(t)()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck

	// succeededPair builds a pair in the Succeeded state with the given RTT.
	succeededPair := func(local, remote Candidate, rtt time.Duration) *CandidatePair {
		pair := newCandidatePair(local, remote, true)
		pair.state = CandidatePairStateSucceeded
		pair.UpdateRoundTripTime(rtt)

		return pair
	}

	// Relay-to-relay pair with a 50ms RTT.
	localRelay, err := NewCandidateRelay(&CandidateRelayConfig{
		Network:   "udp",
		Address:   "10.0.0.1",
		Port:      30000,
		Component: 1,
		RelAddr:   "192.168.1.1",
		RelPort:   10000,
	})
	require.NoError(t, err)

	remoteRelay, err := NewCandidateRelay(&CandidateRelayConfig{
		Network:   "udp",
		Address:   "10.0.0.2",
		Port:      40000,
		Component: 1,
		RelAddr:   "192.168.1.2",
		RelPort:   20000,
	})
	require.NoError(t, err)

	relayPair := succeededPair(localRelay, remoteRelay, 50*time.Millisecond)

	// Host-to-host pair with a comparable (slightly better) RTT.
	localHost, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      10000,
		Component: 1,
	})
	require.NoError(t, err)

	remoteHost, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.2",
		Port:      20000,
		Component: 1,
	})
	require.NoError(t, err)

	hostPair := succeededPair(localHost, remoteHost, 45*time.Millisecond)

	// Direct connectivity must always win over relayed connectivity.
	require.True(t, agent.shouldRenominate(relayPair, hostPair),
		"Should always renominate from relay to direct connection")
}
// TestAgentUpdateOptions verifies UpdateOptions behavior: URLs are
// runtime-updatable, a closed agent rejects updates with ErrClosed, and
// every other agent option is rejected with ErrAgentOptionNotUpdatable.
func TestAgentUpdateOptions(t *testing.T) {
	defer test.CheckRoutines(t)()

	t.Run("URLs can be updated on a running agent", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, a.Close())
		}()

		newURLs := []*stun.URI{
			{Scheme: SchemeTypeSTUN, Host: "1.2.3.4", Port: 3478, Proto: stun.ProtoTypeUDP},
		}
		require.NoError(t, a.UpdateOptions(WithUrls(newURLs)))
	})

	t.Run("UpdateOptions on closed agent fails", func(t *testing.T) {
		a, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		require.NoError(t, a.Close())
		require.Equal(t, ErrClosed, a.UpdateOptions(WithUrls([]*stun.URI{})))
	})

	t.Run("Non-updatable options are rejected", func(t *testing.T) { //nolint:varnamelen
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		// All options except WithUrls should be rejected on a running agent.
		// When adding new options, add them here if they are not runtime-updatable.
		nonUpdatableOptions := map[string]AgentOption{
			"WithAddressRewriteRules":             WithAddressRewriteRules(AddressRewriteRule{External: []string{"1.2.3.4"}}),
			"WithICELite":                         WithICELite(true),
			"WithPortRange":                       WithPortRange(5000, 6000),
			"WithDisconnectedTimeout":             WithDisconnectedTimeout(time.Second),
			"WithFailedTimeout":                   WithFailedTimeout(time.Second),
			"WithKeepaliveInterval":               WithKeepaliveInterval(time.Second),
			"WithHostAcceptanceMinWait":           WithHostAcceptanceMinWait(time.Second),
			"WithSrflxAcceptanceMinWait":          WithSrflxAcceptanceMinWait(time.Second),
			"WithPrflxAcceptanceMinWait":          WithPrflxAcceptanceMinWait(time.Second),
			"WithRelayAcceptanceMinWait":          WithRelayAcceptanceMinWait(time.Second),
			"WithSTUNGatherTimeout":               WithSTUNGatherTimeout(time.Second),
			"WithIPFilter":                        WithIPFilter(func(net.IP) bool { return true }),
			"WithRemoteIPFilter":                  WithRemoteIPFilter(func(net.IP) bool { return true }),
			"WithNet":                             WithNet(nil),
			"WithMulticastDNSMode":                WithMulticastDNSMode(MulticastDNSModeDisabled),
			"WithMulticastDNSHostName":            WithMulticastDNSHostName("test.local"),
			"WithLocalCredentials":                WithLocalCredentials("", ""),
			"WithTCPMux":                          WithTCPMux(nil),
			"WithUDPMux":                          WithUDPMux(nil),
			"WithUDPMuxSrflx":                     WithUDPMuxSrflx(nil),
			"WithProxyDialer":                     WithProxyDialer(nil),
			"WithMaxBindingRequests":              WithMaxBindingRequests(10),
			"WithCheckInterval":                   WithCheckInterval(time.Second),
			"WithRenomination":                    WithRenomination(DefaultNominationValueGenerator()),
			"WithNominationAttribute":             WithNominationAttribute(0x0030),
			"WithIncludeLoopback":                 WithIncludeLoopback(),
			"WithTCPPriorityOffset":               WithTCPPriorityOffset(10),
			"WithDisableActiveTCP":                WithDisableActiveTCP(),
			"WithBindingRequestHandler":           WithBindingRequestHandler(nil),
			"WithEnableUseCandidateCheckPriority": WithEnableUseCandidateCheckPriority(),
			"WithContinualGatheringPolicy":        WithContinualGatheringPolicy(GatherOnce),
			"WithNetworkMonitorInterval":          WithNetworkMonitorInterval(time.Second),
			"WithNetworkTypes":                    WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
			"WithTURNTransportProtocols":          WithTURNTransportProtocols([]NetworkType{NetworkTypeTCP4}),
			"WithCandidateTypes":                  WithCandidateTypes([]CandidateType{CandidateTypeHost}),
			"WithAutomaticRenomination":           WithAutomaticRenomination(time.Second),
			"WithInterfaceFilter":                 WithInterfaceFilter(func(string) bool { return true }),
			"WithLoggerFactory":                   WithLoggerFactory(nil),
		}

		for name, opt := range nonUpdatableOptions {
			err := agent.UpdateOptions(opt)
			require.ErrorIs(t, err, ErrAgentOptionNotUpdatable, "option %s should not be updatable", name)
		}
	})
}
// TestRemoteDialIPForLocalInterface checks the IPv6 zone handling of
// remoteDialIPForLocalInterface: the local interface's zone is attached to a
// zone-less link-local remote, and all other addresses pass through untouched.
func TestRemoteDialIPForLocalInterface(t *testing.T) {
	cases := []struct {
		name   string
		remote string
		local  string
		want   string
	}{
		{"adds local zone for zone-less link-local IPv6", "fe80::1234", "fe80::1%eth0", "fe80::1234%eth0"},
		{"keeps existing remote zone", "fe80::1234%eth9", "fe80::1%eth0", "fe80::1234%eth9"},
		{"does not modify global IPv6", "2001:db8::1234", "fe80::1%eth0", "2001:db8::1234"},
		{"does not modify zone-less link-local when local has no zone", "fe80::1234", "2001:db8::1", "fe80::1234"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := remoteDialIPForLocalInterface(netip.MustParseAddr(tc.remote), netip.MustParseAddr(tc.local))
			require.Equal(t, netip.MustParseAddr(tc.want), got)
		})
	}
}
// TestMDNSQueryTimeout verifies how Agent.mDNSQueryTimeout derives its value:
// the STUN gather timeout when configured, otherwise the package default.
func TestMDNSQueryTimeout(t *testing.T) {
	t.Run("falls back to default when unset", func(t *testing.T) {
		var agent Agent
		require.Equal(t, defaultSTUNGatherTimeout, agent.mDNSQueryTimeout())
	})
	t.Run("uses configured stun gather timeout when set", func(t *testing.T) {
		agent := Agent{stunGatherTimeout: 3 * time.Second}
		require.Equal(t, 3*time.Second, agent.mDNSQueryTimeout())
	})
}
// TestAddRemoteCandidateIndependentFromTURNTransportSelection verifies that a
// remote UDP candidate (relay, host, or srflx) is accepted and stored even
// when the agent's own configured NetworkTypes are TCP-only and the TURN URL
// transport differs — remote candidate acceptance must not depend on the
// local TURN transport selection. addRemoteCandidate is run inside
// agent.loop.Run because it must execute on the agent loop.
func TestAddRemoteCandidateIndependentFromTURNTransportSelection(t *testing.T) {
	t.Run("accepts UDP relay candidate with tcp-only configured network types and TURN/TCP URL", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{
			NetworkTypes:   []NetworkType{NetworkTypeTCP4},
			CandidateTypes: []CandidateType{CandidateTypeRelay},
			Urls: []*stun.URI{{
				Scheme:   stun.SchemeTypeTURN,
				Proto:    stun.ProtoTypeTCP,
				Host:     "turn.example.com",
				Port:     3478,
				Username: "user",
				Password: "pass",
			}},
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		cand, err := NewCandidateRelay(&CandidateRelayConfig{
			Network:   udp,
			Address:   "198.51.100.2",
			Port:      5000,
			Component: ComponentRTP,
			RelAddr:   "192.0.2.10",
			RelPort:   4000,
		})
		require.NoError(t, err)

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			accepted := agent.addRemoteCandidate(cand) // nolint:contextcheck
			require.True(t, accepted)
			require.Len(t, agent.remoteCandidates[NetworkTypeUDP4], 1)
		}))
	})

	// nolint:dupl
	t.Run("accepts UDP host candidate with tcp-only configured network types and TURN/TCP URL", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{
			NetworkTypes:   []NetworkType{NetworkTypeTCP4},
			CandidateTypes: []CandidateType{CandidateTypeRelay},
			Urls: []*stun.URI{{
				Scheme:   stun.SchemeTypeTURN,
				Proto:    stun.ProtoTypeTCP,
				Host:     "turn.example.com",
				Port:     3478,
				Username: "user",
				Password: "pass",
			}},
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		// Host candidate arrives in SDP form.
		cand, err := UnmarshalCandidate("1052353102 1 udp 1675624447 198.51.100.20 5002 typ host")
		require.NoError(t, err)

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			accepted := agent.addRemoteCandidate(cand) // nolint:contextcheck
			require.True(t, accepted)
			require.Len(t, agent.remoteCandidates[NetworkTypeUDP4], 1)
		}))
	})

	// nolint:dupl
	t.Run("accepts UDP srflx candidate with tcp-only configured network types and TURN/TCP URL", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{
			NetworkTypes:   []NetworkType{NetworkTypeTCP4},
			CandidateTypes: []CandidateType{CandidateTypeRelay},
			Urls: []*stun.URI{{
				Scheme:   stun.SchemeTypeTURN,
				Proto:    stun.ProtoTypeTCP,
				Host:     "turn.example.com",
				Port:     3478,
				Username: "user",
				Password: "pass",
			}},
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		// Server-reflexive candidate arrives in SDP form.
		cand, err := UnmarshalCandidate(
			"1052353102 1 udp 1675624447 198.51.100.21 5003 typ srflx raddr 192.0.2.21 rport 4003")
		require.NoError(t, err)

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			accepted := agent.addRemoteCandidate(cand) // nolint:contextcheck
			require.True(t, accepted)
			require.Len(t, agent.remoteCandidates[NetworkTypeUDP4], 1)
		}))
	})

	t.Run("stores UDP relay candidate regardless of TURN URL transport", func(t *testing.T) {
		// Same scenario but with a TURN/UDP URL this time.
		agent, err := NewAgent(&AgentConfig{
			NetworkTypes:   []NetworkType{NetworkTypeTCP4},
			CandidateTypes: []CandidateType{CandidateTypeRelay},
			Urls: []*stun.URI{{
				Scheme: stun.SchemeTypeTURN,
				Proto:  stun.ProtoTypeUDP,
				Host:   "turn.example.com",
				Port:   3478,
			}},
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		cand, err := NewCandidateRelay(&CandidateRelayConfig{
			Network:   udp,
			Address:   "198.51.100.3",
			Port:      5001,
			Component: ComponentRTP,
			RelAddr:   "192.0.2.11",
			RelPort:   4001,
		})
		require.NoError(t, err)

		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
			accepted := agent.addRemoteCandidate(cand) // nolint:contextcheck
			require.True(t, accepted)
			require.Len(t, agent.remoteCandidates[NetworkTypeUDP4], 1)
		}))
	})
}
// localAddrTCPMux is a test double for a TCP mux whose LocalAddr reports a
// fixed address; it exists to exercise mDNSLocalAddressFromTCPMux.
type localAddrTCPMux struct {
	addr net.Addr
}

// Close is a no-op for the stub.
func (m *localAddrTCPMux) Close() error { return nil }

// GetConnByUfrag never hands out a connection; the stub has no real sockets.
func (m *localAddrTCPMux) GetConnByUfrag(string, bool, net.IP) (net.PacketConn, error) {
	return nil, nil //nolint:nilnil
}

// RemoveConnByUfrag is a no-op for the stub.
func (m *localAddrTCPMux) RemoveConnByUfrag(string) {}

// LocalAddr returns the fixed address configured on the stub.
func (m *localAddrTCPMux) LocalAddr() net.Addr { return m.addr }
// TestMDNSLocalAddressFromTCPMux covers how mDNSLocalAddressFromTCPMux
// derives (or refuses to derive) a local IP from a TCP mux listener address.
func TestMDNSLocalAddressFromTCPMux(t *testing.T) {
	t.Run("nil for mixed network types", func(t *testing.T) {
		tcpMux := &localAddrTCPMux{addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1")}}
		require.Nil(t, mDNSLocalAddressFromTCPMux(tcpMux, []NetworkType{NetworkTypeTCP6, NetworkTypeUDP6}))
	})
	t.Run("nil for link-local IPv6 listener", func(t *testing.T) {
		tcpMux := &localAddrTCPMux{addr: &net.TCPAddr{IP: net.ParseIP("fe80::1"), Zone: "eth0"}}
		require.Nil(t, mDNSLocalAddressFromTCPMux(tcpMux, []NetworkType{NetworkTypeTCP6}))
	})
	t.Run("uses listener IP for TCP-only global IPv6", func(t *testing.T) {
		listenIP := net.ParseIP("2001:db8::1")
		tcpMux := &localAddrTCPMux{addr: &net.TCPAddr{IP: listenIP, Zone: "wg0"}}
		require.Equal(t, listenIP.To16(), mDNSLocalAddressFromTCPMux(tcpMux, []NetworkType{NetworkTypeTCP6}))
	})
	t.Run("nil for TCP mux without LocalAddr", func(t *testing.T) {
		require.Nil(t, mDNSLocalAddressFromTCPMux(&stubTCPMux{}, []NetworkType{NetworkTypeTCP4}))
	})
}