Fix race and disconnection after candidate replace (#913)

This PR fixes two race conditions introduced by peer-reflexive candidate
replacement and addresses a connectivity regression where an agent could
transition from Connected to Disconnected/Failed shortly after
replacement.

## Problems

1. Data race in candidate pair access:

- `replaceRedundantPeerReflexiveCandidates` updated `pair.Remote` in
place.
- At the same time, hot-path reads (for example `CandidatePair.Write`)
could read `pair.Remote` without synchronization.

2. Data race in remote candidate cache map:

- `candidateBase.handleInboundPacket` (non-STUN path) read/wrote
`remoteCandidateCaches` from recv loop goroutines.
- Concurrent writes also happened from agent-loop code
(`replaceRemoteCandidateCacheValues`).

3. Connectivity regression after replacement:

- When replacing a selected pair's remote candidate (prflx -> signaled),
the new candidate could have zero `LastReceived`/`LastSent`.
- `validateSelectedPair` uses `selectedPair.Remote.LastReceived()`,
which could cause quick `Connected -> Disconnected -> Failed`
transitions.

## Root Cause

Replacement logic preserved pair priority but mutated shared objects in
place and did not preserve candidate activity timestamps across
candidate object replacement.

## Fixes

### 1) Replace candidate pair objects instead of mutating `pair.Remote`

- Added helper `replacePairRemote(pair, remote)` to clone pair state
into a new `CandidatePair`.
- Preserved pair fields and runtime stats:
  - `id`, role, state, nomination flags, binding counters.
  - RTT fields, packet/byte counters, request/response counters.
  - timestamp `atomic.Value` fields.
- Replaced references in:
  - `a.checklist[i]`
  - `a.pairsByID[pair.id]`
- If old pair was selected, published replacement via
`a.setSelectedPair(replacement)`.

### 2) Serialize remote cache access on agent loop

- Updated non-STUN path in `candidateBase.handleInboundPacket`:
- Cache validation, remote candidate lookup, `seen(false)`, and cache
insert now run inside `agent.loop.Run(...)`.
- This removes concurrent map read/write between recv loop and
agent-loop updates.

### 3) Preserve candidate activity on replacement

- Added `copyCandidateActivity(dst, src)` to transfer:
  - `LastReceived`
  - `LastSent`
- Applied before replacing references so selected-pair liveness checks
remain stable.
This commit is contained in:
sirzooro
2026-04-21 06:27:53 +02:00
committed by GitHub
parent e6483a6066
commit ceccc15191
8 changed files with 207 additions and 67 deletions
+122 -28
View File
@@ -1118,6 +1118,118 @@ func remoteDialIPForLocalInterface(remoteIP, localIP netip.Addr) netip.Addr {
return remoteIP
}
func copyAtomicValue(dst, src *atomic.Value) {
if value := src.Load(); value != nil {
dst.Store(value)
}
}
// candidateActivitySetter is implemented by candidate types whose
// last-received/last-sent activity timestamps can be overwritten.
// It is used to carry liveness timestamps across a candidate object
// replacement (see copyCandidateActivity).
type candidateActivitySetter interface {
	setLastReceived(time.Time)
	setLastSent(time.Time)
}
// copyCandidateActivity transfers the LastReceived/LastSent activity
// timestamps from src to dst so that liveness checks stay stable when a
// candidate object is replaced. A timestamp is only copied when src has a
// non-zero value and dst does not already have one; candidates that do not
// expose setters are left alone.
func copyCandidateActivity(dst, src Candidate) {
	activity, canSet := dst.(candidateActivitySetter)
	if !canSet {
		return
	}
	if received := src.LastReceived(); !received.IsZero() && dst.LastReceived().IsZero() {
		activity.setLastReceived(received)
	}
	if sent := src.LastSent(); !sent.IsZero() && dst.LastSent().IsZero() {
		activity.setLastSent(sent)
	}
}
// replacePairRemote builds a clone of pair whose Remote is swapped for
// remote, carrying over the pair's identity, negotiation state, and all
// runtime statistics. The original pair is never mutated, so concurrent
// hot-path readers (e.g. CandidatePair.Write) cannot observe a torn update.
func replacePairRemote(pair *CandidatePair, remote Candidate) *CandidatePair {
	clone := newCandidatePair(pair.Local, remote, pair.iceRoleControlling)

	// Identity and negotiation state.
	clone.id = pair.id
	clone.bindingRequestCount = pair.bindingRequestCount
	clone.state = pair.state
	clone.nominated = pair.nominated
	clone.nominateOnBindingSuccess = pair.nominateOnBindingSuccess

	// RTT accumulators.
	atomic.StoreInt64(&clone.currentRoundTripTime, atomic.LoadInt64(&pair.currentRoundTripTime))
	atomic.StoreInt64(&clone.totalRoundTripTime, atomic.LoadInt64(&pair.totalRoundTripTime))

	// Traffic counters.
	atomic.StoreUint32(&clone.packetsSent, atomic.LoadUint32(&pair.packetsSent))
	atomic.StoreUint32(&clone.packetsReceived, atomic.LoadUint32(&pair.packetsReceived))
	atomic.StoreUint64(&clone.bytesSent, atomic.LoadUint64(&pair.bytesSent))
	atomic.StoreUint64(&clone.bytesReceived, atomic.LoadUint64(&pair.bytesReceived))

	// STUN request/response counters.
	atomic.StoreUint64(&clone.requestsReceived, atomic.LoadUint64(&pair.requestsReceived))
	atomic.StoreUint64(&clone.requestsSent, atomic.LoadUint64(&pair.requestsSent))
	atomic.StoreUint64(&clone.responsesReceived, atomic.LoadUint64(&pair.responsesReceived))
	atomic.StoreUint64(&clone.responsesSent, atomic.LoadUint64(&pair.responsesSent))

	// Timestamp fields stored as atomic.Values.
	copyAtomicValue(&clone.lastPacketSentAt, &pair.lastPacketSentAt)
	copyAtomicValue(&clone.lastPacketReceivedAt, &pair.lastPacketReceivedAt)
	copyAtomicValue(&clone.firstRequestSentAt, &pair.firstRequestSentAt)
	copyAtomicValue(&clone.lastRequestSentAt, &pair.lastRequestSentAt)
	copyAtomicValue(&clone.firstResponseReceivedAt, &pair.firstResponseReceivedAt)
	copyAtomicValue(&clone.lastResponseReceivedAt, &pair.lastResponseReceivedAt)
	copyAtomicValue(&clone.firstRequestReceivedAt, &pair.firstRequestReceivedAt)
	copyAtomicValue(&clone.lastRequestReceivedAt, &pair.lastRequestReceivedAt)

	return clone
}
// retargetKnownPairHolders walks the components known to hold long-lived
// *CandidatePair references (currently the selectors' nominated pair) and
// swaps any reference to oldPair for newPair.
func (a *Agent) retargetKnownPairHolders(oldPair, newPair *CandidatePair) {
	switch sel := a.getSelector().(type) {
	case *controllingSelector:
		if sel.nominatedPair == oldPair {
			sel.nominatedPair = newPair
		}
	case *liteSelector:
		controlling, ok := sel.pairCandidateSelector.(*controllingSelector)
		if ok && controlling.nominatedPair == oldPair {
			controlling.nominatedPair = newPair
		}
	}
}
// removeRedundantPrflxFromSet drops from set every peer-reflexive candidate
// that shares cand's transport address (redundant per RFC 8838 §11.4).
// It returns the compacted set together with the removed candidates. The
// filter is done in place, reusing set's backing array.
func removeRedundantPrflxFromSet(set []Candidate, cand Candidate) ([]Candidate, []Candidate) {
	var replacedPrflx []Candidate
	kept := set[:0]
	for _, existing := range set {
		if existing.Type() == CandidateTypePeerReflexive && existing.transportAddressEqual(cand) {
			replacedPrflx = append(replacedPrflx, existing)
			continue
		}
		kept = append(kept, existing)
	}
	return kept, replacedPrflx
}
// replaceRemoteInPairs replaces every checklist pair whose Remote is
// oldRemote with a cloned pair pointing at newRemote, preserving the pair's
// priority. The checklist slot, the pairs-by-ID index, selector-held
// references, and — when the affected pair is currently selected — the
// selected pair are all updated to the replacement object.
func (a *Agent) replaceRemoteInPairs(oldRemote, newRemote Candidate) {
	for i, existing := range a.checklist {
		if existing.Remote != oldRemote {
			continue
		}
		priorityBefore := existing.priority()
		updated := replacePairRemote(existing, newRemote)
		updated.setPriorityOverride(priorityBefore)
		a.checklist[i] = updated
		a.pairsByID[updated.id] = updated
		a.retargetKnownPairHolders(existing, updated)
		if a.getSelectedPair() == existing {
			a.setSelectedPair(updated)
		}
	}
}
// replaceRemoteInLocalCaches rewrites every local candidate's remote
// candidate cache so that entries referencing oldRemote resolve to
// newRemote instead.
func (a *Agent) replaceRemoteInLocalCaches(oldRemote, newRemote Candidate) {
	for _, candidates := range a.localCandidates {
		for _, localCand := range candidates {
			localCand.replaceRemoteCandidateCacheValues(oldRemote, newRemote)
		}
	}
}
// replaceRedundantPeerReflexiveCandidates removes any peer-reflexive candidates
// from the given set that have the same transport address as cand.
// It also updates any candidate pairs and local candidate caches that
@@ -1125,36 +1237,18 @@ func remoteDialIPForLocalInterface(remoteIP, localIP netip.Addr) netip.Addr {
// It is implemented according to RFC 8838 §11.4.
// It returns the updated set of candidates.
func (a *Agent) replaceRedundantPeerReflexiveCandidates(set []Candidate, cand Candidate) []Candidate {
if cand.Type() != CandidateTypePeerReflexive {
var replacedPrflx []Candidate
for i := 0; i < len(set); i++ {
existing := set[i]
if existing.Type() == CandidateTypePeerReflexive && existing.transportAddressEqual(cand) {
replacedPrflx = append(replacedPrflx, existing)
set = append(set[:i], set[i+1:]...)
i--
}
}
for _, oldRemote := range replacedPrflx {
for _, pair := range a.checklist {
if pair.Remote == oldRemote {
oldPriority := pair.priority()
pair.Remote = cand
pair.setPriorityOverride(oldPriority)
}
}
for _, locals := range a.localCandidates {
for _, local := range locals {
local.replaceRemoteCandidateCacheValues(oldRemote, cand)
}
}
}
if cand.Type() == CandidateTypePeerReflexive {
return set
}
return set
updatedSet, replacedPrflx := removeRedundantPrflxFromSet(set, cand)
for _, oldRemote := range replacedPrflx {
copyCandidateActivity(cand, oldRemote)
a.replaceRemoteInPairs(oldRemote, cand)
a.replaceRemoteInLocalCaches(oldRemote, cand)
}
return updatedSet
}
// addRemoteCandidate assumes you are holding the lock (must be executed using a.run).
+57 -16
View File
@@ -65,7 +65,8 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
}()
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
sel := &controllingSelector{agent: agent, log: agent.log}
agent.selector = sel
hostConfig := CandidateHostConfig{
Network: "udp",
@@ -115,7 +116,8 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
}()
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
sel := &controllingSelector{agent: agent, log: agent.log}
agent.selector = sel
local, err := NewCandidateHost(&CandidateHostConfig{
Network: "udp",
@@ -159,7 +161,8 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
}()
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
sel := &controllingSelector{agent: agent, log: agent.log}
agent.selector = sel
local, err := NewCandidateHost(&CandidateHostConfig{
Network: "udp",
@@ -194,7 +197,13 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
require.Equal(t, prflx, pair.Remote)
local.addRemoteCandidateCache(prflx, remote)
prflx.seen(false)
prflx.seen(true)
oldLastReceived := prflx.LastReceived()
oldLastSent := prflx.LastSent()
oldPriority := pair.priority()
agent.setSelectedPair(pair)
sel.nominatedPair = pair
host, err := NewCandidateHost(&CandidateHostConfig{
Network: "udp",
@@ -209,9 +218,22 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
require.Len(t, set, 1)
require.Equal(t, CandidateTypeHost, set[0].Type())
require.Equal(t, host, set[0])
require.Equal(t, host, pair.Remote)
require.Equal(t, oldPriority, pair.priority())
require.Equal(t, host, local.remoteCandidateCaches[toAddrPort(remote)])
updatedPair := agent.findPair(local, host)
require.NotNil(t, updatedPair)
require.NotSame(t, pair, updatedPair)
require.Equal(t, host, updatedPair.Remote)
require.Equal(t, oldPriority, updatedPair.priority())
require.False(t, updatedPair.Remote.LastReceived().IsZero())
require.WithinDuration(t, oldLastReceived, updatedPair.Remote.LastReceived(), 10*time.Millisecond)
require.False(t, updatedPair.Remote.LastSent().IsZero())
require.WithinDuration(t, oldLastSent, updatedPair.Remote.LastSent(), 10*time.Millisecond)
require.Equal(t, prflx, pair.Remote)
require.Same(t, updatedPair, agent.getSelectedPair())
require.Same(t, updatedPair, sel.nominatedPair)
cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
require.True(t, ok)
require.Equal(t, host, cached)
}))
})
@@ -275,9 +297,16 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
require.Len(t, set, 1)
require.Equal(t, CandidateTypeServerReflexive, set[0].Type())
require.Equal(t, srflx, set[0])
require.Equal(t, srflx, pair.Remote)
require.Equal(t, oldPriority, pair.priority())
require.Equal(t, srflx, local.remoteCandidateCaches[toAddrPort(remote)])
updatedPair := agent.findPair(local, srflx)
require.NotNil(t, updatedPair)
require.NotSame(t, pair, updatedPair)
require.Equal(t, srflx, updatedPair.Remote)
require.Equal(t, oldPriority, updatedPair.priority())
require.Equal(t, prflx, pair.Remote)
cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
require.True(t, ok)
require.Equal(t, srflx, cached)
}))
})
@@ -341,9 +370,16 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
require.Len(t, set, 1)
require.Equal(t, CandidateTypeRelay, set[0].Type())
require.Equal(t, relay, set[0])
require.Equal(t, relay, pair.Remote)
require.Equal(t, oldPriority, pair.priority())
require.Equal(t, relay, local.remoteCandidateCaches[toAddrPort(remote)])
updatedPair := agent.findPair(local, relay)
require.NotNil(t, updatedPair)
require.NotSame(t, pair, updatedPair)
require.Equal(t, relay, updatedPair.Remote)
require.Equal(t, oldPriority, updatedPair.priority())
require.Equal(t, prflx, pair.Remote)
cached, ok := local.remoteCandidateCaches.Load(toAddrPort(remote))
require.True(t, ok)
require.Equal(t, relay, cached)
}))
})
@@ -403,16 +439,21 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
})
require.NoError(t, err)
agent.addRemoteCandidate(signaled) // nolint:contextcheck
require.Equal(t, signaled, pair.Remote)
require.Equal(t, CandidateTypeHost, pair.Remote.Type())
updatedPair := agent.findPair(local, signaled)
require.NotNil(t, updatedPair)
require.NotSame(t, pair, updatedPair)
require.Equal(t, CandidateTypePeerReflexive, pair.Remote.Type())
require.Equal(t, signaled, updatedPair.Remote)
require.Equal(t, CandidateTypeHost, updatedPair.Remote.Type())
// Now the (updated) pair should be nominatable and become nominated.
sel.ContactCandidates()
require.NotNil(t, sel.nominatedPair)
require.Same(t, pair, sel.nominatedPair)
require.Same(t, updatedPair, sel.nominatedPair)
require.Equal(t, CandidateTypeHost, sel.nominatedPair.Local.Type())
require.Equal(t, CandidateTypeHost, sel.nominatedPair.Remote.Type())
require.True(t, pair.nominated)
require.True(t, updatedPair.nominated)
}))
})
+16 -8
View File
@@ -45,7 +45,7 @@ type candidateBase struct {
relayLocalPreference uint16
remoteCandidateCaches map[AddrPort]Candidate
remoteCandidateCaches sync.Map // map[AddrPort]Candidate
isLocationTracked bool
extensions []CandidateExtension
}
@@ -287,8 +287,13 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
}
func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool {
if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok {
candidate.seen(false)
if candidate, ok := c.remoteCandidateCaches.Load(toAddrPort(addr)); ok {
remoteCandidate, ok := candidate.(Candidate)
if !ok {
return false
}
remoteCandidate.seen(false)
return true
}
@@ -300,15 +305,18 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net
if c.validateSTUNTrafficCache(srcAddr) {
return
}
c.remoteCandidateCaches[toAddrPort(srcAddr)] = candidate
c.remoteCandidateCaches.Store(toAddrPort(srcAddr), candidate)
}
func (c *candidateBase) replaceRemoteCandidateCacheValues(oldRemote, newRemote Candidate) {
for k, v := range c.remoteCandidateCaches {
if v == oldRemote {
c.remoteCandidateCaches[k] = newRemote
c.remoteCandidateCaches.Range(func(key, value any) bool {
candidate, ok := value.(Candidate)
if ok && candidate == oldRemote {
c.remoteCandidateCaches.Store(key, newRemote)
}
}
return true
})
}
func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
+9 -10
View File
@@ -38,16 +38,15 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
candidateHost := &CandidateHost{
candidateBase: candidateBase{
id: candidateID,
address: config.Address,
candidateType: CandidateTypeHost,
component: config.Component,
port: config.Port,
tcpType: config.TCPType,
foundationOverride: config.Foundation,
priorityOverride: config.Priority,
remoteCandidateCaches: map[AddrPort]Candidate{},
isLocationTracked: config.IsLocationTracked,
id: candidateID,
address: config.Address,
candidateType: CandidateTypeHost,
component: config.Component,
port: config.Port,
tcpType: config.TCPType,
foundationOverride: config.Foundation,
priorityOverride: config.Priority,
isLocationTracked: config.IsLocationTracked,
},
network: config.Network,
}
-1
View File
@@ -60,7 +60,6 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate
Address: config.RelAddr,
Port: config.RelPort,
},
remoteCandidateCaches: map[AddrPort]Candidate{},
},
}, nil
}
+1 -2
View File
@@ -78,8 +78,7 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
Address: config.RelAddr,
Port: config.RelPort,
},
relayLocalPreference: relayProtocolPreference(config.RelayProtocol),
remoteCandidateCaches: map[AddrPort]Candidate{},
relayLocalPreference: relayProtocolPreference(config.RelayProtocol),
},
relayProtocol: config.RelayProtocol,
onClose: config.OnClose,
-1
View File
@@ -62,7 +62,6 @@ func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*Candi
Address: config.RelAddr,
Port: config.RelPort,
},
remoteCandidateCaches: map[AddrPort]Candidate{},
},
}, nil
}
+2 -1
View File
@@ -13,6 +13,7 @@ import (
"io"
"net"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -124,7 +125,7 @@ func TestBindingRequestHandler(t *testing.T) {
for _, c := range controlledAgent.localCandidates[NetworkTypeUDP4] {
cast, ok := c.(*CandidateHost)
require.True(t, ok)
cast.remoteCandidateCaches = map[AddrPort]Candidate{}
cast.remoteCandidateCaches = sync.Map{}
}
controlledAgent.setSelectedPair(nil)