Refactor WebRTC handling to use per-peer broadcasters for video and audio tracks

This commit is contained in:
Cédric Verstraeten
2026-03-09 15:12:01 +00:00
parent ca0e426382
commit 1bf8006055
3 changed files with 237 additions and 52 deletions
+9 -7
View File
@@ -800,17 +800,19 @@ func HandleLiveStreamHD(livestreamCursor *packets.QueueCursor, configuration *mo
// Check if we need to enable the live stream
if config.Capture.Liveview != "false" {
// Should create a track here.
// Create per-peer broadcasters instead of shared tracks.
// Each viewer gets its own track with independent, non-blocking writes
// so a slow/congested peer cannot stall the others.
streams, _ := rtspClient.GetStreams()
videoTrack := webrtc.NewVideoTrack(streams)
audioTrack := webrtc.NewAudioTrack(streams)
videoBroadcaster := webrtc.NewVideoBroadcaster(streams)
audioBroadcaster := webrtc.NewAudioBroadcaster(streams)
if videoTrack == nil && audioTrack == nil {
log.Log.Error("cloud.HandleLiveStreamHD(): failed to create both video and audio tracks")
if videoBroadcaster == nil && audioBroadcaster == nil {
log.Log.Error("cloud.HandleLiveStreamHD(): failed to create both video and audio broadcasters")
return
}
go webrtc.WriteToTrack(livestreamCursor, configuration, communication, mqttClient, videoTrack, audioTrack, rtspClient)
go webrtc.WriteToTrack(livestreamCursor, configuration, communication, mqttClient, videoBroadcaster, audioBroadcaster, rtspClient)
if config.Capture.ForwardWebRTC == "true" {
@@ -818,7 +820,7 @@ func HandleLiveStreamHD(livestreamCursor *packets.QueueCursor, configuration *mo
log.Log.Info("cloud.HandleLiveStreamHD(): Waiting for peer connections.")
for handshake := range communication.HandleLiveHDHandshake {
log.Log.Info("cloud.HandleLiveStreamHD(): setting up a peer connection.")
go webrtc.InitializeWebRTCConnection(configuration, communication, mqttClient, videoTrack, audioTrack, handshake)
go webrtc.InitializeWebRTCConnection(configuration, communication, mqttClient, videoBroadcaster, audioBroadcaster, handshake)
}
}
+137
View File
@@ -0,0 +1,137 @@
package webrtc
import (
"io"
"sync"
"github.com/kerberos-io/agent/machinery/src/log"
pionWebRTC "github.com/pion/webrtc/v4"
pionMedia "github.com/pion/webrtc/v4/pkg/media"
)
const (
// peerSampleBuffer controls how many samples can be buffered per peer before
// dropping. Keeps slow peers from blocking the broadcaster.
peerSampleBuffer = 60
)
// peerTrack is a per-peer track with its own non-blocking sample channel.
type peerTrack struct {
track *pionWebRTC.TrackLocalStaticSample
samples chan pionMedia.Sample
done chan struct{}
}
// TrackBroadcaster fans out media samples to multiple peer-specific tracks
// without blocking. Each peer gets its own TrackLocalStaticSample and a
// goroutine that drains samples independently, so a slow/congested peer
// cannot stall the others.
type TrackBroadcaster struct {
mu sync.RWMutex
peers map[string]*peerTrack
mimeType string
id string
streamID string
}
// NewTrackBroadcaster creates a new broadcaster for either video or audio.
func NewTrackBroadcaster(mimeType string, id string, streamID string) *TrackBroadcaster {
return &TrackBroadcaster{
peers: make(map[string]*peerTrack),
mimeType: mimeType,
id: id,
streamID: streamID,
}
}
// AddPeer creates a new per-peer track and starts a writer goroutine.
// Returns the track to be added to the PeerConnection via AddTrack().
func (b *TrackBroadcaster) AddPeer(sessionKey string) (*pionWebRTC.TrackLocalStaticSample, error) {
track, err := pionWebRTC.NewTrackLocalStaticSample(
pionWebRTC.RTPCodecCapability{MimeType: b.mimeType},
b.id,
b.streamID,
)
if err != nil {
return nil, err
}
pt := &peerTrack{
track: track,
samples: make(chan pionMedia.Sample, peerSampleBuffer),
done: make(chan struct{}),
}
b.mu.Lock()
b.peers[sessionKey] = pt
b.mu.Unlock()
// Per-peer writer goroutine — drains samples independently.
go func() {
defer close(pt.done)
for sample := range pt.samples {
if err := pt.track.WriteSample(sample); err != nil {
if err == io.ErrClosedPipe {
return
}
log.Log.Error("webrtc.broadcaster.peerWriter(): error writing sample for " + sessionKey + ": " + err.Error())
}
}
}()
log.Log.Info("webrtc.broadcaster.AddPeer(): added peer track for " + sessionKey)
return track, nil
}
// RemovePeer stops the writer goroutine and removes the peer.
func (b *TrackBroadcaster) RemovePeer(sessionKey string) {
b.mu.Lock()
pt, exists := b.peers[sessionKey]
if exists {
delete(b.peers, sessionKey)
}
b.mu.Unlock()
if exists {
close(pt.samples)
<-pt.done // wait for writer goroutine to finish
log.Log.Info("webrtc.broadcaster.RemovePeer(): removed peer track for " + sessionKey)
}
}
// WriteSample fans out a sample to all connected peers without blocking.
// If a peer's buffer is full (slow consumer), the sample is dropped for
// that peer only — other peers are unaffected.
func (b *TrackBroadcaster) WriteSample(sample pionMedia.Sample) {
b.mu.RLock()
defer b.mu.RUnlock()
for sessionKey, pt := range b.peers {
select {
case pt.samples <- sample:
default:
log.Log.Warning("webrtc.broadcaster.WriteSample(): dropping sample for slow peer " + sessionKey)
}
}
}
// PeerCount returns the current number of connected peers.
func (b *TrackBroadcaster) PeerCount() int {
b.mu.RLock()
defer b.mu.RUnlock()
return len(b.peers)
}
// Close removes all peers and stops all writer goroutines.
func (b *TrackBroadcaster) Close() {
b.mu.Lock()
keys := make([]string, 0, len(b.peers))
for k := range b.peers {
keys = append(keys, k)
}
b.mu.Unlock()
for _, key := range keys {
b.RemovePeer(key)
}
}
+91 -45
View File
@@ -48,13 +48,16 @@ type ConnectionManager struct {
// peerConnectionWrapper wraps a peer connection with additional metadata
type peerConnectionWrapper struct {
conn *pionWebRTC.PeerConnection
cancelCtx context.CancelFunc
done chan struct{}
closeOnce sync.Once
connected atomic.Bool
disconnectMu sync.Mutex
disconnectTimer *time.Timer
conn *pionWebRTC.PeerConnection
cancelCtx context.CancelFunc
done chan struct{}
closeOnce sync.Once
connected atomic.Bool
disconnectMu sync.Mutex
disconnectTimer *time.Timer
sessionKey string
videoBroadcaster *TrackBroadcaster
audioBroadcaster *TrackBroadcaster
}
var globalConnectionManager = NewConnectionManager()
@@ -153,6 +156,15 @@ func cleanupPeerConnection(sessionKey string, wrapper *peerConnectionWrapper) {
log.Log.Info("webrtc.main.cleanupPeerConnection(): Peer disconnected. Active peers: " + strconv.FormatInt(count, 10))
}
// Remove per-peer tracks from broadcasters so the fan-out stops
// writing to this peer immediately.
if wrapper.videoBroadcaster != nil {
wrapper.videoBroadcaster.RemovePeer(sessionKey)
}
if wrapper.audioBroadcaster != nil {
wrapper.audioBroadcaster.RemovePeer(sessionKey)
}
globalConnectionManager.CloseCandidateChannel(sessionKey)
if wrapper.conn != nil {
@@ -239,7 +251,7 @@ func RegisterDefaultInterceptors(mediaEngine *pionWebRTC.MediaEngine, intercepto
return nil
}
func InitializeWebRTCConnection(configuration *models.Configuration, communication *models.Communication, mqttClient mqtt.Client, videoTrack *pionWebRTC.TrackLocalStaticSample, audioTrack *pionWebRTC.TrackLocalStaticSample, handshake models.RequestHDStreamPayload) {
func InitializeWebRTCConnection(configuration *models.Configuration, communication *models.Communication, mqttClient mqtt.Client, videoBroadcaster *TrackBroadcaster, audioBroadcaster *TrackBroadcaster, handshake models.RequestHDStreamPayload) {
config := configuration.Config
deviceKey := config.Key
@@ -319,14 +331,25 @@ func InitializeWebRTCConnection(configuration *models.Configuration, communicati
// Create context for this connection
ctx, cancel := context.WithCancel(context.Background())
wrapper := &peerConnectionWrapper{
conn: peerConnection,
cancelCtx: cancel,
done: make(chan struct{}),
conn: peerConnection,
cancelCtx: cancel,
done: make(chan struct{}),
sessionKey: sessionKey,
videoBroadcaster: videoBroadcaster,
audioBroadcaster: audioBroadcaster,
}
// Create a per-peer video track from the broadcaster so writes
// to this peer are independent and non-blocking.
var videoSender *pionWebRTC.RTPSender = nil
if videoTrack != nil {
if videoSender, err = peerConnection.AddTrack(videoTrack); err != nil {
if videoBroadcaster != nil {
peerVideoTrack, trackErr := videoBroadcaster.AddPeer(sessionKey)
if trackErr != nil {
log.Log.Error("webrtc.main.InitializeWebRTCConnection(): error creating per-peer video track: " + trackErr.Error())
cleanupPeerConnection(sessionKey, wrapper)
return
}
if videoSender, err = peerConnection.AddTrack(peerVideoTrack); err != nil {
log.Log.Error("webrtc.main.InitializeWebRTCConnection(): error adding video track: " + err.Error())
cleanupPeerConnection(sessionKey, wrapper)
return
@@ -357,9 +380,16 @@ func InitializeWebRTCConnection(configuration *models.Configuration, communicati
}()
}
// Create a per-peer audio track from the broadcaster.
var audioSender *pionWebRTC.RTPSender = nil
if audioTrack != nil {
if audioSender, err = peerConnection.AddTrack(audioTrack); err != nil {
if audioBroadcaster != nil {
peerAudioTrack, trackErr := audioBroadcaster.AddPeer(sessionKey)
if trackErr != nil {
log.Log.Error("webrtc.main.InitializeWebRTCConnection(): error creating per-peer audio track: " + trackErr.Error())
cleanupPeerConnection(sessionKey, wrapper)
return
}
if audioSender, err = peerConnection.AddTrack(peerAudioTrack); err != nil {
log.Log.Error("webrtc.main.InitializeWebRTCConnection(): error adding audio track: " + err.Error())
cleanupPeerConnection(sessionKey, wrapper)
return
@@ -598,6 +628,32 @@ func InitializeWebRTCConnection(configuration *models.Configuration, communicati
}
}
func NewVideoBroadcaster(streams []packets.Stream) *TrackBroadcaster {
// Verify H264 is available (same check as NewVideoTrack)
for _, s := range streams {
if s.Name == "H264" {
return NewTrackBroadcaster(pionWebRTC.MimeTypeH264, "video", trackStreamID)
}
}
log.Log.Error("webrtc.main.NewVideoBroadcaster(): no H264 stream found")
return nil
}
func NewAudioBroadcaster(streams []packets.Stream) *TrackBroadcaster {
for _, s := range streams {
switch s.Name {
case "OPUS":
return NewTrackBroadcaster(pionWebRTC.MimeTypeOpus, "audio", trackStreamID)
case "PCM_MULAW":
return NewTrackBroadcaster(pionWebRTC.MimeTypePCMU, "audio", trackStreamID)
case "PCM_ALAW":
return NewTrackBroadcaster(pionWebRTC.MimeTypePCMA, "audio", trackStreamID)
}
}
log.Log.Error("webrtc.main.NewAudioBroadcaster(): no supported audio codec found")
return nil
}
func NewVideoTrack(streams []packets.Stream) *pionWebRTC.TrackLocalStaticSample {
mimeType := pionWebRTC.MimeTypeH264
outboundVideoTrack, err := pionWebRTC.NewTrackLocalStaticSample(pionWebRTC.RTPCodecCapability{MimeType: mimeType}, "video", trackStreamID)
@@ -711,17 +767,13 @@ func updateStreamState(communication *models.Communication, state *streamState)
}
// writeFinalSamples writes any remaining buffered samples
func writeFinalSamples(state *streamState, videoTrack, audioTrack *pionWebRTC.TrackLocalStaticSample) {
if state.lastVideoSample != nil && videoTrack != nil {
if err := videoTrack.WriteSample(*state.lastVideoSample); err != nil && err != io.ErrClosedPipe {
log.Log.Error("webrtc.main.writeFinalSamples(): error writing final video sample: " + err.Error())
}
func writeFinalSamples(state *streamState, videoBroadcaster, audioBroadcaster *TrackBroadcaster) {
if state.lastVideoSample != nil && videoBroadcaster != nil {
videoBroadcaster.WriteSample(*state.lastVideoSample)
}
if state.lastAudioSample != nil && audioTrack != nil {
if err := audioTrack.WriteSample(*state.lastAudioSample); err != nil && err != io.ErrClosedPipe {
log.Log.Error("webrtc.main.writeFinalSamples(): error writing final audio sample: " + err.Error())
}
if state.lastAudioSample != nil && audioBroadcaster != nil {
audioBroadcaster.WriteSample(*state.lastAudioSample)
}
}
@@ -760,9 +812,9 @@ func sampleDuration(current packets.Packet, previousTimestamp uint32, fallback t
return fallback
}
// processVideoPacket processes a video packet and writes samples to the track
func processVideoPacket(pkt packets.Packet, state *streamState, videoTrack *pionWebRTC.TrackLocalStaticSample, config models.Config) {
if videoTrack == nil {
// processVideoPacket processes a video packet and writes samples to the broadcaster
func processVideoPacket(pkt packets.Packet, state *streamState, videoBroadcaster *TrackBroadcaster, config models.Config) {
if videoBroadcaster == nil {
return
}
@@ -785,18 +837,15 @@ func processVideoPacket(pkt packets.Packet, state *streamState, videoTrack *pion
if state.lastVideoSample != nil {
state.lastVideoSample.Duration = sampleDuration(pkt, state.lastVideoSample.PacketTimestamp, 33*time.Millisecond)
if err := videoTrack.WriteSample(*state.lastVideoSample); err != nil && err != io.ErrClosedPipe {
log.Log.Error("webrtc.main.processVideoPacket(): error writing video sample: " + err.Error())
}
videoBroadcaster.WriteSample(*state.lastVideoSample)
}
state.lastVideoSample = &sample
}
// processAudioPacket processes an audio packet and writes samples to the track
func processAudioPacket(pkt packets.Packet, state *streamState, audioTrack *pionWebRTC.TrackLocalStaticSample, hasAAC bool) {
if audioTrack == nil {
// processAudioPacket processes an audio packet and writes samples to the broadcaster
func processAudioPacket(pkt packets.Packet, state *streamState, audioBroadcaster *TrackBroadcaster, hasAAC bool) {
if audioBroadcaster == nil {
return
}
@@ -810,10 +859,7 @@ func processAudioPacket(pkt packets.Packet, state *streamState, audioTrack *pion
if state.lastAudioSample != nil {
state.lastAudioSample.Duration = sampleDuration(pkt, state.lastAudioSample.PacketTimestamp, 20*time.Millisecond)
if err := audioTrack.WriteSample(*state.lastAudioSample); err != nil && err != io.ErrClosedPipe {
log.Log.Error("webrtc.main.processAudioPacket(): error writing audio sample: " + err.Error())
}
audioBroadcaster.WriteSample(*state.lastAudioSample)
}
state.lastAudioSample = &sample
@@ -828,13 +874,13 @@ func shouldDropPacketForLatency(pkt packets.Packet) bool {
return age > maxLivePacketAge
}
func WriteToTrack(livestreamCursor *packets.QueueCursor, configuration *models.Configuration, communication *models.Communication, mqttClient mqtt.Client, videoTrack *pionWebRTC.TrackLocalStaticSample, audioTrack *pionWebRTC.TrackLocalStaticSample, rtspClient capture.RTSPClient) {
func WriteToTrack(livestreamCursor *packets.QueueCursor, configuration *models.Configuration, communication *models.Communication, mqttClient mqtt.Client, videoBroadcaster *TrackBroadcaster, audioBroadcaster *TrackBroadcaster, rtspClient capture.RTSPClient) {
config := configuration.Config
// Check if at least one track is available
if videoTrack == nil && audioTrack == nil {
log.Log.Error("webrtc.main.WriteToTrack(): both video and audio tracks are nil, cannot proceed")
// Check if at least one broadcaster is available
if videoBroadcaster == nil && audioBroadcaster == nil {
log.Log.Error("webrtc.main.WriteToTrack(): both video and audio broadcasters are nil, cannot proceed")
return
}
@@ -857,7 +903,7 @@ func WriteToTrack(livestreamCursor *packets.QueueCursor, configuration *models.C
}
defer func() {
writeFinalSamples(state, videoTrack, audioTrack)
writeFinalSamples(state, videoBroadcaster, audioBroadcaster)
log.Log.Info("webrtc.main.WriteToTrack(): stopped writing to track")
}()
@@ -923,9 +969,9 @@ func WriteToTrack(livestreamCursor *packets.QueueCursor, configuration *models.C
// Process video or audio packets
if pkt.IsVideo {
processVideoPacket(pkt, state, videoTrack, config)
processVideoPacket(pkt, state, videoBroadcaster, config)
} else if pkt.IsAudio {
processAudioPacket(pkt, state, audioTrack, codecs.hasAAC)
processAudioPacket(pkt, state, audioBroadcaster, codecs.hasAAC)
}
}
}