added some hosts and server sinks

This commit is contained in:
harshabose
2025-06-05 15:45:39 +05:30
parent 3657f4bc09
commit ce1cf615e2
11 changed files with 2306 additions and 47 deletions
+7 -2
View File
@@ -1,4 +1,4 @@
package interfaces
package mediasink
import (
"context"
@@ -9,6 +9,11 @@ import (
// Host is a media sink endpoint: it connects to a remote peer, accepts RTP
// packets or raw payload bytes for delivery, and can be closed.
//
// Fix: the previous text carried diff residue — a stale `Write(*rtp.Packet) error`
// line alongside `Write([]byte) error`, i.e. two methods named Write with
// conflicting signatures, which is invalid Go.
type Host interface {
	// Connect starts the host's connection lifecycle, bounded by the context.
	Connect(context.Context)
	// WriteRTP sends one RTP packet to the sink.
	WriteRTP(*rtp.Packet) error
	// Write sends a raw payload to the sink.
	Write([]byte) error
	io.Closer
}
// CanCallBackPayload is implemented by hosts that can deliver payloads
// received from the remote side back to the source via a callback.
type CanCallBackPayload interface {
	// SetOnPayloadCallback registers f to be invoked for each received payload.
	SetOnPayloadCallback(f func([]byte) error)
}
+303
View File
@@ -1 +1,304 @@
package loopback
import (
"context"
"fmt"
"net"
"os/exec"
"sync"
"time"
"github.com/pion/rtp"
)
// metrics tracks cumulative byte counters and derived per-second rates for a
// LoopBack. All fields are guarded by mu; rates are refreshed by updateRates.
type metrics struct {
	DataSent      int64   // total bytes written to the UDP peer
	DataRecvd     int64   // total bytes read from the UDP peer
	DataSentRate  float32 // bytes/sec over the last updateRates interval
	DataRecvdRate float32 // bytes/sec over the last updateRates interval

	mu         sync.RWMutex
	lastUpdate time.Time // timestamp of the previous updateRates call
	lastSent   int64     // DataSent snapshot at lastUpdate
	lastRecvd  int64     // DataRecvd snapshot at lastUpdate
}
// updateRates recomputes the send/receive byte rates from the deltas since
// the previous call, then snapshots the counters for the next interval.
// The first call only records the baseline (no rates computed).
func (m *metrics) updateRates() {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	if !m.lastUpdate.IsZero() {
		if elapsed := now.Sub(m.lastUpdate).Seconds(); elapsed > 0 {
			m.DataSentRate = float32(m.DataSent-m.lastSent) / float32(elapsed)
			m.DataRecvdRate = float32(m.DataRecvd-m.lastRecvd) / float32(elapsed)
		}
	}
	m.lastUpdate = now
	m.lastSent = m.DataSent
	m.lastRecvd = m.DataRecvd
}
// addSent adds n bytes to the cumulative sent counter.
func (m *metrics) addSent(n int64) {
	m.mu.Lock()
	m.DataSent += n
	m.mu.Unlock()
}

// addRecvd adds n bytes to the cumulative received counter.
func (m *metrics) addRecvd(n int64) {
	m.mu.Lock()
	m.DataRecvd += n
	m.mu.Unlock()
}

// GetStats returns a consistent snapshot of counters and rates.
func (m *metrics) GetStats() (sent, recvd int64, sentRate, recvdRate float32) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	sent, recvd = m.DataSent, m.DataRecvd
	sentRate, recvdRate = m.DataSentRate, m.DataRecvdRate
	return
}
// OnMessage handles a payload received from the loopback peer.
type OnMessage = func([]byte) error

// LoopBack relays payloads over a local UDP socket, typically to/from an
// external process (see WithMAVProxy). The remote address is either set via
// an Option or auto-discovered from the first received packet.
type LoopBack struct {
	bindPort *net.UDPConn // local UDP socket; nil until an Option sets it
	remote   *net.UDPAddr // peer address; nil until configured or discovered

	f OnMessage // callback for received payloads; may be nil

	metrics metrics

	// cmd is an optional external process configured by WithMAVProxy.
	// NOTE(review): nothing in the visible code calls cmd.Start — confirm
	// where (or whether) the process is launched.
	cmd *exec.Cmd

	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup // tracks readLoop and metricsLoop
	mu     sync.RWMutex   // guards bindPort, remote, ctx/cancel
}
// NewLoopBack builds a LoopBack and applies the given options in order.
// The first option that fails aborts construction with its error.
func NewLoopBack(options ...Option) (*LoopBack, error) {
	lb := &LoopBack{metrics: metrics{lastUpdate: time.Now()}}
	for _, opt := range options {
		if err := opt(lb); err != nil {
			return nil, err
		}
	}
	return lb, nil
}
// Connect starts the read and metrics goroutines bound to ctx. Calling
// Connect again cancels any previous session before starting a new one.
//
// Fix: previously a missing bind port caused a nil-pointer panic (both at
// the final Printf and inside readLoop); now it is reported and Connect
// returns without starting goroutines.
func (l *LoopBack) Connect(ctx context.Context) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.bindPort == nil {
		fmt.Println("LoopBack not connected: bind port not configured (use WithBindPort or WithRandomBindPort)")
		return
	}
	if l.cancel != nil {
		l.cancel()
	}
	l.ctx, l.cancel = context.WithCancel(ctx)

	// Start reading goroutine
	l.wg.Add(1)
	go l.readLoop()

	// Start metrics update goroutine
	l.wg.Add(1)
	go l.metricsLoop()

	fmt.Printf("LoopBack connected on %s\n", l.bindPort.LocalAddr().String())
}
// WriteRTP marshals packet and forwards the bytes through Write.
func (l *LoopBack) WriteRTP(packet *rtp.Packet) error {
	if packet == nil {
		return fmt.Errorf("packet cannot be nil")
	}
	raw, err := packet.Marshal()
	if err != nil {
		return fmt.Errorf("failed to marshal RTP packet: %v", err)
	}
	return l.Write(raw)
}
// Write sends payload as a single UDP datagram to the configured or
// auto-discovered remote address. It errors if the socket or the remote is
// not yet available, or if the datagram was only partially written.
func (l *LoopBack) Write(payload []byte) error {
	l.mu.RLock()
	defer l.mu.RUnlock()

	if l.bindPort == nil {
		return fmt.Errorf("bind port not yet set. Skipping message")
	}
	if l.remote == nil {
		return fmt.Errorf("remote port not yet discovered. Skipping message")
	}

	n, err := l.bindPort.WriteToUDP(payload, l.remote)
	if err != nil {
		return fmt.Errorf("failed to write UDP message: %v", err)
	}
	if n != len(payload) {
		return fmt.Errorf("written bytes (%d) != message length (%d)", n, len(payload))
	}

	l.metrics.addSent(int64(n))
	return nil
}
// Close stops the goroutines, closes the UDP socket and waits for the
// read/metrics loops to exit. Returns the socket close error, if any.
//
// Fix: the old code held l.mu across wg.Wait(); readLoop (via
// readMessageFromUDPPort) also acquires l.mu, so Close could deadlock.
// The wait now happens after releasing the lock.
func (l *LoopBack) Close() error {
	fmt.Println("Closing LoopBack...")

	l.mu.Lock()
	// Cancel context to stop goroutines.
	if l.cancel != nil {
		l.cancel()
	}
	// Close the UDP connection; this also unblocks a pending ReadFromUDP.
	var err error
	if l.bindPort != nil {
		err = l.bindPort.Close()
		l.bindPort = nil
	}
	l.mu.Unlock()

	// Wait outside the lock — readLoop takes l.mu while auto-discovering
	// the remote address and would otherwise block forever.
	l.wg.Wait()
	fmt.Println("LoopBack closed")
	return err
}
// readLoop receives UDP datagrams until the context is cancelled, feeding
// payloads to the registered callback and updating receive metrics.
// It runs on its own goroutine and is registered with l.wg by Connect.
func (l *LoopBack) readLoop() {
	defer l.wg.Done()
	fmt.Println("Starting LoopBack read loop")
	for {
		select {
		case <-l.ctx.Done():
			fmt.Println("LoopBack read loop stopped")
			return
		default:
			// Short read deadline so the loop re-checks ctx at least once a second.
			// NOTE(review): bindPort is dereferenced without holding l.mu; if
			// Close nils it concurrently this can panic — confirm shutdown order.
			l.bindPort.SetReadDeadline(time.Now().Add(1 * time.Second))
			buffer, nRead := l.readMessageFromUDPPort()
			if nRead > 0 && buffer != nil {
				// Update metrics
				l.metrics.addRecvd(int64(nRead))
				// Deliver the payload to the callback, if one is registered.
				// NOTE(review): l.f is read without the mutex — set it before Connect.
				if l.f != nil {
					if err := l.f(buffer[:nRead]); err != nil {
						fmt.Printf("Error processing message: %v\n", err)
					}
				}
			}
		}
	}
}
// metricsLoop refreshes the throughput rates once per second until the
// context is cancelled. Registered with l.wg by Connect.
func (l *LoopBack) metricsLoop() {
	defer l.wg.Done()
	tick := time.NewTicker(1 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			l.metrics.updateRates()
		case <-l.ctx.Done():
			return
		}
	}
}
// readMessageFromUDPPort performs one UDP read. Returns (nil, 0) on timeout
// (expected — the deadline lets readLoop poll the context) or error.
// The first sender's address is auto-discovered as the remote peer.
//
// Fixes: l.remote was re-read after releasing the lock (racing with
// SetRemoteAddress); it is now snapshotted under the lock. Also guards
// against a nil senderAddr before dereferencing it.
func (l *LoopBack) readMessageFromUDPPort() ([]byte, int) {
	buffer := make([]byte, 1500) // Standard MTU size
	nRead, senderAddr, err := l.bindPort.ReadFromUDP(buffer)
	if err != nil {
		// A timeout is expected; anything else is logged.
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			return nil, 0
		}
		fmt.Printf("Error while reading message from bind port: %v\n", err)
		return nil, 0
	}

	// Auto-discover remote address from first received packet, and snapshot
	// the remote while still holding the lock.
	l.mu.Lock()
	if l.remote == nil && senderAddr != nil {
		l.remote = &net.UDPAddr{IP: senderAddr.IP, Port: senderAddr.Port}
		fmt.Printf("Auto-discovered remote address: %s\n", l.remote.String())
	}
	remote := l.remote
	l.mu.Unlock()

	// Validate sender (optional security check).
	if senderAddr != nil && remote != nil && senderAddr.Port != remote.Port {
		fmt.Printf("Warning: expected port %d but got %d\n", remote.Port, senderAddr.Port)
	}
	return buffer, nRead
}
// SetRemoteAddress manually sets the remote address (alternative to auto-discovery).
func (l *LoopBack) SetRemoteAddress(address string) error {
	resolved, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return fmt.Errorf("failed to resolve remote address: %v", err)
	}
	l.mu.Lock()
	l.remote = resolved
	l.mu.Unlock()
	fmt.Printf("Remote address set to: %s\n", resolved.String())
	return nil
}
// GetMetrics returns current metrics
func (l *LoopBack) GetMetrics() (sent, recvd int64, sentRate, recvdRate float32) {
	return l.metrics.GetStats()
}

// GetLocalAddress returns the local binding address
func (l *LoopBack) GetLocalAddress() string {
	l.mu.RLock()
	defer l.mu.RUnlock()
	if conn := l.bindPort; conn != nil {
		return conn.LocalAddr().String()
	}
	return ""
}

// GetRemoteAddress returns the current remote address
func (l *LoopBack) GetRemoteAddress() string {
	l.mu.RLock()
	defer l.mu.RUnlock()
	if addr := l.remote; addr != nil {
		return addr.String()
	}
	return ""
}

// IsConnected returns true if the LoopBack is actively running
func (l *LoopBack) IsConnected() bool {
	l.mu.RLock()
	defer l.mu.RUnlock()
	if l.ctx == nil {
		return false
	}
	return l.ctx.Err() == nil
}
// SetOnPayloadCallback registers f as the handler for received payloads.
// The write now happens under l.mu (it previously was unsynchronized).
// NOTE(review): readLoop still reads l.f without the lock, so prefer
// registering the callback before calling Connect.
func (l *LoopBack) SetOnPayloadCallback(f OnMessage) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.f = f
}
+70
View File
@@ -1 +1,71 @@
package loopback
import (
"errors"
"fmt"
"net"
"os"
"os/exec"
)
// Option configures a LoopBack during construction.
type Option = func(*LoopBack) error

// WithRandomBindPort binds the LoopBack to an OS-assigned UDP port on 127.0.0.1.
func WithRandomBindPort(loopback *LoopBack) error {
	conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	if err != nil {
		return err
	}
	loopback.bindPort = conn
	return nil
}

// WithBindPort binds the LoopBack to the given UDP port on 127.0.0.1.
func WithBindPort(port int) Option {
	return func(loopback *LoopBack) error {
		conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port})
		if err != nil {
			return err
		}
		loopback.bindPort = conn
		return nil
	}
}

// WithLoopBackPort pre-sets the remote peer to 127.0.0.1:port instead of
// relying on auto-discovery from the first received packet.
func WithLoopBackPort(port int) Option {
	return func(loopback *LoopBack) error {
		loopback.remote = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port}
		return nil
	}
}

// WithCallback registers f as the handler for received payloads.
func WithCallback(f OnMessage) Option {
	return func(loopback *LoopBack) error {
		loopback.SetOnPayloadCallback(f)
		return nil
	}
}
// WithMAVProxy configures an external MAVProxy process bridging deviceStr
// (a serial port or network address) to this LoopBack's UDP port. The bind
// port must be set first (WithBindPort/WithRandomBindPort).
// NOTE(review): the command is only constructed and stored on loopback.cmd —
// nothing in this file calls cmd.Start(); confirm where the process launches.
func WithMAVProxy(path string, deviceStr string) Option {
	return func(loopback *LoopBack) error {
		if loopback.bindPort == nil {
			return errors.New("bindPortConn not initialized, call WithBindPort or WithRandomBindPort first")
		}
		port := loopback.bindPort.LocalAddr().(*net.UDPAddr).Port
		// MAVProxy uses --master for the connection string, and --out for output connections
		// Format depends on the device: could be a serial port or network address
		args := []string{
			"--master", deviceStr,
			"--out", fmt.Sprintf("udpout:127.0.0.1:%d", port),
			"--daemon",
		}
		cmd := exec.Command(path, args...)
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		loopback.cmd = cmd
		return nil
	}
}
+397
View File
@@ -1 +1,398 @@
package rtsp
import (
"context"
"errors"
"fmt"
"net/url"
"sync"
"time"
"github.com/bluenviron/gortsplib/v4"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/pion/rtp"
)
// HostState describes the lifecycle stage of a NewHost connection.
type HostState int

const (
	HostStateDisconnected HostState = iota // not connected; manager stopped
	HostStateConnecting                    // attempting to reach the server
	HostStateConnected                     // session established, not publishing
	HostStateRecording                     // actively publishing RTP
	HostStateError                         // last attempt or write failed
)
// NewHost publishes RTP to an RTSP server, reconnecting automatically via
// its connection manager (see Connect/connectionManager).
type NewHost struct {
	// Connection details
	serverAddr string
	streamPath string
	rtspURL    string // rtsp://serverAddr:port/streamPath, built in NewNewHost

	// RTSP client
	client      *gortsplib.Client
	description *description.Session // media tracks announced to the server

	// State management
	state    HostState
	stateMux sync.RWMutex // guards state
	ctx      context.Context
	cancel   context.CancelFunc

	// Configuration
	config *HostConfig

	// Error handling
	lastError error
	errorMux  sync.RWMutex // guards lastError

	// Reconnection — nudged (non-blocking) by WriteRTP on write failure.
	reconnectChan chan struct{}

	// Metrics.
	// NOTE(review): these are written by WriteRTP and read by monitor
	// goroutines without synchronization — possible data race to confirm.
	packetsWritten uint64
	bytesWritten   uint64
	lastWriteTime  time.Time
}
// HostConfig carries the connection parameters for a NewHost.
type HostConfig struct {
	ServerAddr        string        // RTSP server hostname or IP
	ServerPort        int           // RTSP server port
	StreamPath        string        // path component of the publish URL
	ReadTimeout       time.Duration // gortsplib client read timeout
	WriteTimeout      time.Duration // gortsplib client write timeout
	DialTimeout       time.Duration // NOTE(review): not used in this file — confirm
	ReconnectAttempts int           // 0 = no retries; see connectionManager
	ReconnectDelay    time.Duration // initial retry delay (backs off to 30s max)
	UserAgent         string        // User-Agent header sent to the server
}
// LocalHostConfig returns a HostConfig for publishing to an RTSP server on
// this machine (rtsp://localhost:8554/stream) with 10s timeouts and three
// reconnect attempts.
func LocalHostConfig() *HostConfig {
	cfg := &HostConfig{}
	cfg.ServerAddr = "localhost"
	cfg.ServerPort = 8554
	cfg.StreamPath = "stream"
	cfg.ReadTimeout = 10 * time.Second
	cfg.WriteTimeout = 10 * time.Second
	cfg.DialTimeout = 10 * time.Second
	cfg.ReconnectAttempts = 3
	cfg.ReconnectDelay = 2 * time.Second
	cfg.UserAgent = "GoRTSP-Host/1.0"
	return cfg
}
// SimpleSkylineSonataHostConfig is a placeholder that currently returns nil.
// Passing its result to NewNewHost makes NewNewHost fall back to
// LocalHostConfig (nil configs are substituted there).
func SimpleSkylineSonataHostConfig() *HostConfig {
	return nil
}
// NewNewHost builds an RTSP publishing host from config (nil selects
// LocalHostConfig) and a session description (nil starts empty), then
// applies options in order. After options run, at least one media must be
// present in the description or construction fails.
func NewNewHost(config *HostConfig, des *description.Session, options ...HostOption) (*NewHost, error) {
	if config == nil {
		config = LocalHostConfig()
	}
	if des == nil {
		des = &description.Session{}
	}

	h := &NewHost{
		serverAddr:    config.ServerAddr,
		streamPath:    config.StreamPath,
		rtspURL:       fmt.Sprintf("rtsp://%s:%d/%s", config.ServerAddr, config.ServerPort, config.StreamPath),
		description:   des,
		config:        config,
		state:         HostStateDisconnected,
		reconnectChan: make(chan struct{}, 1),
	}

	for i := range options {
		if err := options[i](h); err != nil {
			return nil, fmt.Errorf("failed to apply option: %w", err)
		}
	}

	// Options are expected to populate at least one media track.
	if len(h.description.Medias) == 0 {
		return nil, errors.New("options do not set media options")
	}

	fmt.Printf("Host configured for RTSP URL: %s\n", h.rtspURL)
	return h, nil
}
// Connect establishes the RTSP session and blocks, supervising the
// connection (including reconnects) until ctx is cancelled.
func (h *NewHost) Connect(ctx context.Context) {
	h.ctx, h.cancel = context.WithCancel(ctx)
	h.setState(HostStateConnecting)
	fmt.Printf("Connecting to RTSP server: %s with persistent connection management\n", h.rtspURL)
	// This blocks and manages connection for the lifetime of the context
	h.connectionManager()
}
// connectionManager is the host's supervision loop: it attempts to connect,
// retries failed attempts with exponential backoff (factor 1.5, capped at
// 30s), and re-enters the connecting state when an established connection
// is reported lost by monitorConnection. It returns only when h.ctx ends.
func (h *NewHost) connectionManager() {
	currentDelay := h.config.ReconnectDelay
	for {
		// Check if we should stop
		select {
		case <-h.ctx.Done():
			fmt.Printf("Connection manager stopping due to context cancellation\n")
			h.setState(HostStateDisconnected)
			return
		default:
		}
		// Attempt connection
		fmt.Printf("Attempting to connect to RTSP server...\n")
		if err := h.attemptConnection(); err != nil {
			h.setState(HostStateError)
			h.setError(err)
			fmt.Printf("Connection failed: %v\n", err)
			// Check if we should retry
			if h.config.ReconnectAttempts == 0 {
				// No retries configured, exit
				fmt.Printf("No retries configured, stopping connection attempts\n")
				return
			}
			// For ReconnectAttempts == -1 (infinite) or > 0 (limited), continue retrying
			// Note: We don't track attempt count for infinite retries
			// TODO: ADD COUNTER — positive ReconnectAttempts values are
			// currently treated the same as infinite retries.
			fmt.Printf("Retrying connection in %v...\n", currentDelay)
			select {
			case <-h.ctx.Done():
				fmt.Printf("Connection manager stopping during retry delay\n")
				h.setState(HostStateDisconnected)
				return
			case <-time.After(currentDelay):
				// Continue to next attempt
			}
			// Exponential backoff (cap at 30 seconds)
			currentDelay = time.Duration(float64(currentDelay) * 1.5)
			if currentDelay > 30*time.Second {
				currentDelay = 30 * time.Second
			}
			continue
		}
		// Connection successful
		h.setState(HostStateRecording)
		fmt.Printf("Successfully connected and recording to RTSP server\n")
		currentDelay = h.config.ReconnectDelay // Reset delay
		// Blocks until the connection is considered lost or ctx ends.
		h.monitorConnection()
		select {
		case <-h.ctx.Done():
			fmt.Printf("Connection manager stopping - context cancelled\n")
			h.setState(HostStateDisconnected)
			return
		default:
			// Connection was lost, loop back to retry
			fmt.Printf("Connection lost, attempting to reconnect...\n")
			h.setState(HostStateConnecting)
		}
	}
}
// monitorConnection polls connection health every 5 seconds and returns when
// the context ends or the client becomes nil; returning with a live context
// is interpreted by connectionManager as "connection lost".
// NOTE(review): nothing in this file sets h.client back to nil after a
// failure, so the nil-client exit may never fire — confirm how loss is
// actually detected in practice (write errors set the error state instead).
func (h *NewHost) monitorConnection() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			// Check connection health
			if h.client == nil {
				fmt.Printf("Client is nil, connection lost\n")
				return
			}
			// NOTE(review): lastWriteTime is read here without synchronization
			// against WriteRTP's unlocked update — benign for logging, but a race.
			if time.Since(h.lastWriteTime) > 60*time.Second {
				fmt.Printf("No write activity for 60s, connection might be stale\n")
				// Don't return here, just log - let write errors trigger reconnection
			}
			// Additional health checks could go here
			// For now, write errors are to detect connection issues
		}
	}
}
// attemptConnection builds a fresh client and performs the RTSP publishing
// handshake (ANNOUNCE, SETUP, RECORD) against the configured URL. On success
// the host is left in the recording state with the write clock reset.
func (h *NewHost) attemptConnection() error {
	// A new client per attempt avoids reusing a broken transport.
	h.client = &gortsplib.Client{
		ReadTimeout:  h.config.ReadTimeout,
		WriteTimeout: h.config.WriteTimeout,
		UserAgent:    h.config.UserAgent,
	}

	u, err := url.Parse(h.rtspURL)
	if err != nil {
		return fmt.Errorf("invalid RTSP URL: %w", err)
	}

	if err := h.client.StartRecording(u.String(), h.description); err != nil {
		return fmt.Errorf("failed to start recording: %w", err)
	}

	h.setState(HostStateRecording)
	h.lastWriteTime = time.Now()
	return nil
}
// connectionMonitor watches connection health and services reconnection
// requests arriving on reconnectChan (posted by WriteRTP on write failure).
// NOTE(review): nothing in this file calls connectionMonitor —
// connectionManager uses monitorConnection instead. Confirm whether this is
// dead code or invoked elsewhere before relying on its reconnect path.
func (h *NewHost) connectionMonitor() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			// Check connection health
			if h.getState() == HostStateRecording {
				if time.Since(h.lastWriteTime) > 30*time.Second {
					fmt.Printf("No write activity for 30s, connection might be stale\n")
				}
			}
		case <-h.reconnectChan:
			// Reconnection requested
			if h.getState() == HostStateError {
				fmt.Printf("Attempting reconnection...\n")
				if err := h.attemptConnection(); err != nil {
					fmt.Printf("Reconnection failed: %v\n", err)
				} else {
					fmt.Printf("Reconnection successful\n")
				}
			}
		}
	}
}
// WriteRTP publishes one RTP packet to the server. Nil packets are silently
// dropped; writing is only allowed in the recording state. The packet's
// payload type is overwritten to match the first configured media format
// when it differs. A write failure flips the host into the error state and
// nudges reconnectChan without blocking.
// NOTE(review): packetsWritten/bytesWritten/lastWriteTime are updated here
// without synchronization while monitor goroutines read them — data race to
// confirm and fix (e.g. with atomics).
func (h *NewHost) WriteRTP(packet *rtp.Packet) error {
	if packet == nil {
		fmt.Println("Nil packet, not sending to server")
		return nil
	}
	state := h.getState()
	if state != HostStateRecording {
		return fmt.Errorf("cannot write: host state is %v, expected %v", state, HostStateRecording)
	}
	if h.client == nil {
		return errors.New("client not initialized")
	}
	// Validate and fix payload type if needed
	if len(h.description.Medias) == 0 || len(h.description.Medias[0].Formats) == 0 {
		return errors.New("no media formats available")
	}
	expectedPayloadType := h.description.Medias[0].Formats[0].PayloadType()
	if packet.PayloadType != expectedPayloadType {
		// Mutates the caller's packet in place.
		fmt.Printf("Fixing payload type: expected %d, got %d\n", expectedPayloadType, packet.PayloadType)
		packet.PayloadType = expectedPayloadType
	}
	// Write packet to server
	if err := h.client.WritePacketRTP(h.description.Medias[0], packet); err != nil {
		h.setState(HostStateError)
		h.setError(err)
		// Trigger reconnection attempt (non-blocking; channel capacity 1)
		select {
		case h.reconnectChan <- struct{}{}:
		default:
			// Channel full, skip
		}
		return fmt.Errorf("failed to write RTP packet: %w", err)
	}
	// Update metrics
	h.packetsWritten++
	h.bytesWritten += uint64(len(packet.Payload))
	h.lastWriteTime = time.Now()
	return nil
}
// Write satisfies the raw-byte sink interface but is a no-op for the RTSP
// host: only RTP packets (via WriteRTP) are published. It always returns
// nil, silently discarding the payload.
func (h *NewHost) Write(_ []byte) error {
	return nil
}
// Close stops the connection manager (via context cancellation), closes the
// RTSP client if present and marks the host disconnected. Always returns nil.
func (h *NewHost) Close() error {
	fmt.Printf("Closing RTSP host connection\n")
	if h.cancel != nil {
		h.cancel() // stops connectionManager / monitors
	}
	if c := h.client; c != nil {
		c.Close()
	}
	h.setState(HostStateDisconnected)
	fmt.Printf("RTSP host connection closed\n")
	return nil
}
// State management methods

// getState returns the current connection state under the read lock.
func (h *NewHost) getState() HostState {
	h.stateMux.RLock()
	s := h.state
	h.stateMux.RUnlock()
	return s
}

// setState records a new connection state under the write lock.
func (h *NewHost) setState(state HostState) {
	h.stateMux.Lock()
	h.state = state
	h.stateMux.Unlock()
}

// setError records the most recent connection error.
func (h *NewHost) setError(err error) {
	h.errorMux.Lock()
	h.lastError = err
	h.errorMux.Unlock()
}

// GetLastError returns the most recently recorded connection error.
func (h *NewHost) GetLastError() error {
	h.errorMux.RLock()
	defer h.errorMux.RUnlock()
	return h.lastError
}
// Status methods

// IsConnected reports whether the host is connected or actively recording.
func (h *NewHost) IsConnected() bool {
	switch h.getState() {
	case HostStateConnected, HostStateRecording:
		return true
	default:
		return false
	}
}

// IsRecording reports whether the host is currently publishing.
func (h *NewHost) IsRecording() bool {
	return h.getState() == HostStateRecording
}

// GetStats returns packets written, payload bytes written and the last write time.
func (h *NewHost) GetStats() (uint64, uint64, time.Time) {
	return h.packetsWritten, h.bytesWritten, h.lastWriteTime
}

// GetRTSPURL returns the fully-formed RTSP URL this host publishes to.
func (h *NewHost) GetRTSPURL() string {
	return h.rtspURL
}

// AppendRTSPMediaDescription adds a media track to the session description.
func (h *NewHost) AppendRTSPMediaDescription(media *description.Media) {
	h.description.Medias = append(h.description.Medias, media)
}
+582
View File
@@ -1 +1,583 @@
package rtsp
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/bluenviron/gortsplib/v4"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/pion/rtp"
)
// ClientSession records one subscriber attached to a stream.
type ClientSession struct {
	ID         string                   // unique per session: remoteAddr-unixNano (see OnSessionOpen)
	Session    *gortsplib.ServerSession // underlying RTSP session
	RemoteAddr string                   // client's host:port
	IsLocal    bool                     // true if the client connected from loopback
	ConnTime   time.Time                // when the session was opened
	LastActive time.Time                // refreshed on PLAY and on publisher traffic
}
// StreamInfo tracks one published stream: its server-side stream object,
// the publishing session, and all subscribed clients.
type StreamInfo struct {
	Stream              *gortsplib.ServerStream
	Publisher           *gortsplib.ServerSession
	PublisherLastActive time.Time // refreshed per received RTP packet (OnRecord)
	Description         *description.Session
	Clients             map[string]*ClientSession // keyed by ClientSession.ID
	CreatedAt           time.Time
	mutex               sync.RWMutex // guards Clients and PublisherLastActive
}
// AddClient registers client on this stream, keyed by its ID.
func (si *StreamInfo) AddClient(client *ClientSession) {
	si.mutex.Lock()
	si.Clients[client.ID] = client
	si.mutex.Unlock()
}

// RemoveClient drops the client with the given ID, if present.
func (si *StreamInfo) RemoveClient(clientID string) {
	si.mutex.Lock()
	delete(si.Clients, clientID)
	si.mutex.Unlock()
}

// GetClientCount returns the number of registered clients.
func (si *StreamInfo) GetClientCount() int {
	si.mutex.RLock()
	defer si.mutex.RUnlock()
	return len(si.Clients)
}

// GetClients returns a snapshot slice of the registered clients.
func (si *StreamInfo) GetClients() []*ClientSession {
	si.mutex.RLock()
	defer si.mutex.RUnlock()
	out := make([]*ClientSession, 0, len(si.Clients))
	for _, c := range si.Clients {
		out = append(out, c)
	}
	return out
}
// NewServer is an RTSP server that relays published streams to subscribers.
// It implements the gortsplib.ServerHandler callbacks (OnConnOpen, OnAnnounce,
// OnSetup, OnRecord, ...).
type NewServer struct {
	server *gortsplib.Server
	config *Config

	// Stream management
	streams map[string]*StreamInfo // keyed by RTSP path
	mutex   sync.RWMutex           // guards streams (and totalConnections)

	// Context management
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup // tracks cleanupRoutine and metricsRoutine

	// Metrics
	// TODO: SHIFT TO METRICS STRUCT AND EXPOSE A SIMPLE API ENDPOINT
	totalConnections int64
}
// Config carries the tunables for NewServer.
type Config struct {
	Port                    int           // RTSP listen port
	MaxClients              int           // cap on clients across all streams
	MaxStreams              int           // cap on concurrently published streams
	ReadTimeout             time.Duration // gortsplib read timeout
	WriteTimeout            time.Duration // gortsplib write timeout
	PublisherSessionTimeout time.Duration // publisher idle limit before stream teardown
	ClientSessionTimeout    time.Duration // client idle limit before eviction
	AllowLocalOnly          bool          // reject non-loopback connections when true
	UDPRTPAddress           string        // RTP listen address, e.g. ":8000"
	UDPRTCPAddress          string        // RTCP listen address, e.g. ":8001"
	MulticastIPRange        string        // multicast settings (flagged NOT NEEDED in NewNewServer)
	MulticastRTPPort        int
	MulticastRTCPPort       int
	WriteQueueSize          int // per-connection outgoing packet queue
}
// DefaultConfig returns the stock server configuration: port 8554, up to 100
// clients across 10 streams, 60s session timeouts and the standard UDP and
// multicast port assignments.
func DefaultConfig() *Config {
	cfg := &Config{
		Port:       8554,
		MaxClients: 100,
		MaxStreams: 10,

		ReadTimeout:             10 * time.Second,
		WriteTimeout:            10 * time.Second,
		PublisherSessionTimeout: 60 * time.Second,
		ClientSessionTimeout:    60 * time.Second,

		AllowLocalOnly: false,

		UDPRTPAddress:     ":8000",
		UDPRTCPAddress:    ":8001",
		MulticastIPRange:  "224.1.0.0/16",
		MulticastRTPPort:  8002,
		MulticastRTCPPort: 8003,
		WriteQueueSize:    4096,
	}
	return cfg
}
// NewNewServer constructs an RTSP server (nil config selects DefaultConfig).
// The returned value acts as its own gortsplib handler; call Start to listen.
func NewNewServer(config *Config) *NewServer {
	if config == nil {
		config = DefaultConfig()
	}
	srv := &NewServer{
		config:  config,
		streams: make(map[string]*StreamInfo),
	}
	srv.server = &gortsplib.Server{
		Handler:           srv,
		RTSPAddress:       fmt.Sprintf("0.0.0.0:%d", config.Port),
		ReadTimeout:       config.ReadTimeout,
		WriteTimeout:      config.WriteTimeout,
		WriteQueueSize:    config.WriteQueueSize,
		UDPRTPAddress:     config.UDPRTPAddress,
		UDPRTCPAddress:    config.UDPRTCPAddress,
		MulticastIPRange:  config.MulticastIPRange,  // NOTE: NOT NEEDED
		MulticastRTPPort:  config.MulticastRTPPort,  // NOTE: NOT NEEDED
		MulticastRTCPPort: config.MulticastRTCPPort, // NOTE: NOT NEEDED
	}
	return srv
}
// Start launches the RTSP listener and the background maintenance routines.
//
// Fix: previously a listen failure was logged but the background goroutines
// were still spawned and "started successfully" was printed anyway; Start now
// returns early on failure.
func (s *NewServer) Start(ctx context.Context) {
	s.setSession(ctx)
	fmt.Printf("Starting RTSP server on port %d\n", s.config.Port)
	if err := s.server.Start(); err != nil {
		fmt.Println("failed to start RTSP server:", err.Error())
		return
	}
	// Start background tasks
	s.wg.Add(2)
	go s.cleanupRoutine()
	go s.metricsRoutine()
	fmt.Println("RTSP server started successfully")
}
// setSession replaces the server's lifetime context, cancelling any previous
// one so Start may be called again after Close.
func (s *NewServer) setSession(ctx context.Context) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if s.cancel != nil {
		s.cancel()
	}
	s.ctx, s.cancel = context.WithCancel(ctx)
}
// Close tears down all streams and publishers, stops the listener and waits
// for the background routines. Always returns nil.
//
// Fix: s.cancel is nil until Start/setSession has run, so Close before Start
// previously panicked; the cancel call is now guarded.
func (s *NewServer) Close() error {
	fmt.Println("Stopping RTSP server...")
	// Close all streams
	s.mutex.Lock()
	if s.cancel != nil {
		s.cancel()
	}
	for path, streamInfo := range s.streams {
		if streamInfo.Stream != nil {
			streamInfo.Stream.Close()
		}
		if streamInfo.Publisher != nil {
			streamInfo.Publisher.Close()
		}
		fmt.Printf("Closed stream: %s\n", path)
	}
	s.streams = make(map[string]*StreamInfo)
	s.mutex.Unlock()
	// Stop server
	s.server.Close()
	// Wait for background routines
	s.wg.Wait()
	s.totalConnections = 0
	fmt.Println("RTSP server stopped")
	return nil
}
// cleanupRoutine evicts inactive publishers and clients every 30 seconds
// until the server context ends.
func (s *NewServer) cleanupRoutine() {
	defer s.wg.Done()
	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			s.cleanupInactiveSessions()
		case <-s.ctx.Done():
			return
		}
	}
}

// metricsRoutine logs stream and client counts every 60 seconds until the
// server context ends.
func (s *NewServer) metricsRoutine() {
	defer s.wg.Done()
	tick := time.NewTicker(60 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			s.logMetrics()
		case <-s.ctx.Done():
			return
		}
	}
}
// cleanupInactiveSessions removes streams whose publisher has been idle
// beyond PublisherSessionTimeout (closing stream and publisher) and evicts
// clients idle beyond ClientSessionTimeout. Called from cleanupRoutine while
// holding s.mutex for the whole sweep; per-stream state is additionally
// guarded by streamInfo.mutex, which is released early on the teardown path
// because Stream.Close/Publisher.Close must not run under it.
func (s *NewServer) cleanupInactiveSessions() {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	now := time.Now()
	for path, streamInfo := range s.streams {
		streamInfo.mutex.Lock()
		publisherInactive := now.Sub(streamInfo.PublisherLastActive) > s.config.PublisherSessionTimeout
		if publisherInactive {
			// Publisher inactive - close stream.
			// Unlock first: the Close calls below may invoke handler callbacks
			// that take the same per-stream lock.
			streamInfo.mutex.Unlock()
			fmt.Printf("Publisher inactive for stream %s, closing stream\n", path)
			if streamInfo.Stream != nil {
				streamInfo.Stream.Close()
			}
			if streamInfo.Publisher != nil {
				streamInfo.Publisher.Close()
			}
			// Deleting during range is safe in Go.
			delete(s.streams, path)
			continue
		}
		// Collect stale client IDs first, then delete, to avoid mutating the
		// map while iterating it.
		toRemove := make([]string, 0)
		for clientID, client := range streamInfo.Clients {
			if now.Sub(client.LastActive) > s.config.ClientSessionTimeout {
				toRemove = append(toRemove, clientID)
			}
		}
		for _, clientID := range toRemove {
			delete(streamInfo.Clients, clientID)
			fmt.Printf("Removed inactive client %s from stream %s\n", clientID, path)
		}
		streamInfo.mutex.Unlock()
	}
}
// logMetrics prints the per-stream client counts plus overall totals.
func (s *NewServer) logMetrics() {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	total := 0
	for path, info := range s.streams {
		n := info.GetClientCount()
		total += n
		fmt.Printf("Stream %s: %d clients\n", path, n)
	}
	fmt.Printf("Total streams: %d, Total clients: %d\n", len(s.streams), total)
}
// isLocalhost reports whether remoteAddr ("host:port") refers to this
// machine: a loopback IP or the literal hostname "localhost".
func (s *NewServer) isLocalhost(remoteAddr string) bool {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return false
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.IsLoopback()
	}
	return strings.ToLower(host) == "localhost"
}
// validateConnection enforces the local-only policy and the global client cap.
func (s *NewServer) validateConnection(remoteAddr string) error {
	if s.config.AllowLocalOnly && !s.isLocalhost(remoteAddr) {
		return fmt.Errorf("only localhost connections allowed")
	}
	// Count clients across all streams.
	s.mutex.RLock()
	clients := 0
	for _, info := range s.streams {
		clients += info.GetClientCount()
	}
	s.mutex.RUnlock()
	if clients >= s.config.MaxClients {
		return fmt.Errorf("maximum client limit reached")
	}
	return nil
}
// OnConnOpen is called when a publisher/client completes the TCP handshake.
// Connections rejected by the local-only policy or the client cap are closed.
//
// Fix: totalConnections was incremented without synchronization while
// handlers can run concurrently; the counter update now happens under s.mutex
// (shared with OnConnClose).
func (s *NewServer) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) {
	remote := ctx.Conn.NetConn().RemoteAddr()
	if err := s.validateConnection(remote.String()); err != nil {
		fmt.Printf("Connection rejected: %v\n", err)
		ctx.Conn.Close()
		return
	}
	// TODO: MAYBE REMOVE totalConnections
	s.mutex.Lock()
	s.totalConnections++
	total := s.totalConnections
	s.mutex.Unlock()
	fmt.Printf("Connection opened from %s (total: %d)\n", remote, total)
}
// OnConnClose is called when a publisher/client disconnects TCP.
//
// Fix: totalConnections was decremented without synchronization; the update
// now happens under s.mutex (shared with OnConnOpen).
func (s *NewServer) OnConnClose(ctx *gortsplib.ServerHandlerOnConnCloseCtx) {
	s.mutex.Lock()
	s.totalConnections--
	s.mutex.Unlock()
	fmt.Printf("Connection closed from %s: %v\n", ctx.Conn.NetConn().RemoteAddr(), ctx.Error)
}
// OnSessionOpen is called after OnConnOpen and indicates RTSP session start.
// It tags the session with a unique client ID and connection metadata that
// later handlers (OnSetup/OnPlay/OnSessionClose) read back via UserData.
func (s *NewServer) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) {
	remote := ctx.Conn.NetConn().RemoteAddr()
	clientID := fmt.Sprintf("%s-%d", remote, time.Now().UnixNano())
	fmt.Printf("Session opened: %s from %s\n", clientID, remote)
	ctx.Session.SetUserData(map[string]interface{}{
		"clientID":   clientID,
		"remoteAddr": remote.String(),
		"isLocal":    s.isLocalhost(remote.String()),
		"connTime":   time.Now(),
	})
}
// OnSessionClose is called after OnConnClose and indicates RTSP session close.
// It removes the session's client entry from every stream, and if the closing
// session was a stream's publisher, the whole stream is torn down.
func (s *NewServer) OnSessionClose(ctx *gortsplib.ServerHandlerOnSessionCloseCtx) {
	userData := ctx.Session.UserData()
	if userData == nil {
		// No metadata means OnSessionOpen never tagged this session.
		return
	}
	userMap, ok := userData.(map[string]interface{})
	if !ok {
		return
	}
	clientID, _ := userMap["clientID"].(string)
	fmt.Printf("Session closed: %s\n", clientID)
	s.mutex.Lock()
	defer s.mutex.Unlock()
	// Remove client from all streams (deleting during range is safe in Go).
	for path, streamInfo := range s.streams {
		streamInfo.RemoveClient(clientID)
		// If this was the publisher, close the stream
		if streamInfo.Publisher == ctx.Session {
			if streamInfo.Stream != nil {
				streamInfo.Stream.Close()
			}
			delete(s.streams, path)
			fmt.Printf("Publisher disconnected, stream %s closed\n", path)
		}
	}
}
// OnDescribe serves the stream description for an existing path, or 404 when
// the path is unknown or has no active server stream.
func (s *NewServer) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, *gortsplib.ServerStream, error) {
	fmt.Printf("Describe request for path: %s from %s\n", ctx.Path, ctx.Conn.NetConn().RemoteAddr())

	s.mutex.RLock()
	info, ok := s.streams[ctx.Path]
	s.mutex.RUnlock()

	if !ok || info.Stream == nil {
		fmt.Printf("Stream not found: %s\n", ctx.Path)
		return &base.Response{StatusCode: base.StatusNotFound}, nil, nil
	}
	return &base.Response{StatusCode: base.StatusOK}, info.Stream, nil
}
// OnAnnounce handles a publisher registering a new stream at ctx.Path.
// It enforces MaxStreams, replaces (tearing down) any existing stream at the
// same path, and records the session as the stream's publisher.
func (s *NewServer) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) {
	path := ctx.Path
	fmt.Printf("Announce request for path: %s from %s\n", path, ctx.Conn.NetConn().RemoteAddr())
	s.mutex.Lock()
	defer s.mutex.Unlock()
	// Check max streams limit
	if len(s.streams) >= s.config.MaxStreams {
		fmt.Println("Maximum stream limit reached")
		return &base.Response{
			StatusCode: base.StatusServiceUnavailable,
		}, nil
	}
	// Close existing stream if it exists (last announcer wins).
	if existingStream, exists := s.streams[path]; exists {
		if existingStream.Stream != nil {
			existingStream.Stream.Close()
		}
		if existingStream.Publisher != nil {
			existingStream.Publisher.Close()
		}
		fmt.Printf("Replaced existing stream: %s\n", path)
	}
	// Create new stream
	stream := gortsplib.NewServerStream(s.server, ctx.Description)
	streamInfo := &StreamInfo{
		Stream:              stream,
		Publisher:           ctx.Session,
		PublisherLastActive: time.Now(),
		Description:         ctx.Description,
		Clients:             make(map[string]*ClientSession),
		CreatedAt:           time.Now(),
	}
	s.streams[path] = streamInfo
	fmt.Printf("Stream created: %s\n", path)
	return &base.Response{
		StatusCode: base.StatusOK,
	}, nil
}
// OnSetup handles a SETUP request. For subscribers (any session that is not
// the stream's publisher) it registers a ClientSession built from the
// metadata stored by OnSessionOpen; the publisher itself is not added to the
// client list. Returns 404 for unknown paths.
func (s *NewServer) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, *gortsplib.ServerStream, error) {
	path := ctx.Path
	fmt.Printf("Setup request for path: %s from %s\n", path, ctx.Conn.NetConn().RemoteAddr())
	s.mutex.RLock()
	streamInfo, exists := s.streams[path]
	s.mutex.RUnlock()
	if !exists || streamInfo.Stream == nil {
		fmt.Printf("Stream not found for setup: %s\n", path)
		return &base.Response{
			StatusCode: base.StatusNotFound,
		}, nil, nil
	}
	isPublisher := streamInfo.Publisher == ctx.Session
	if !isPublisher {
		// Add client to stream, using the metadata tagged in OnSessionOpen.
		userData := ctx.Session.UserData()
		if userData != nil {
			if userMap, ok := userData.(map[string]interface{}); ok {
				clientID, _ := userMap["clientID"].(string)
				remoteAddr, _ := userMap["remoteAddr"].(string)
				isLocal, _ := userMap["isLocal"].(bool)
				connTime, _ := userMap["connTime"].(time.Time)
				client := &ClientSession{
					ID:         clientID,
					Session:    ctx.Session,
					RemoteAddr: remoteAddr,
					IsLocal:    isLocal,
					ConnTime:   connTime,
					LastActive: time.Now(),
				}
				streamInfo.AddClient(client)
				fmt.Printf("Client %s added to stream %s\n", clientID, path)
			}
		}
	}
	return &base.Response{
		StatusCode: base.StatusOK,
	}, streamInfo.Stream, nil
}
// OnPlay handles a PLAY request: it refreshes the requesting client's
// LastActive stamp (when registered on the stream) and accepts.
func (s *NewServer) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) {
	fmt.Printf("Play request for path: %s from %s\n", ctx.Path, ctx.Conn.NetConn().RemoteAddr())

	s.mutex.RLock()
	if info, found := s.streams[ctx.Path]; found {
		if userMap, ok := ctx.Session.UserData().(map[string]interface{}); ok {
			clientID, _ := userMap["clientID"].(string)
			info.mutex.Lock()
			if client, known := info.Clients[clientID]; known {
				client.LastActive = time.Now()
			}
			info.mutex.Unlock()
		}
	}
	s.mutex.RUnlock()

	return &base.Response{StatusCode: base.StatusOK}, nil
}
// OnRecord handles the publisher starting to send media. It wires a per-packet
// callback that forwards every RTP packet into the server stream and refreshes
// the publisher's (and, currently, every client's) LastActive stamp.
func (s *NewServer) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) {
	path := ctx.Path
	fmt.Printf("Record request for path: %s from %s\n", path, ctx.Conn.NetConn().RemoteAddr())
	s.mutex.RLock()
	streamInfo, exists := s.streams[path]
	s.mutex.RUnlock()
	if !exists || streamInfo.Stream == nil {
		fmt.Printf("Stream not found for record: %s\n", path)
		return &base.Response{
			StatusCode: base.StatusNotFound,
		}, nil
	}
	// Set up packet handling: runs once per received RTP packet.
	ctx.Session.OnPacketRTPAny(func(media *description.Media, format format.Format, pkt *rtp.Packet) {
		if err := streamInfo.Stream.WritePacketRTP(media, pkt); err != nil {
			fmt.Printf("Error writing RTP packet to stream %s: %v\n", path, err)
		}
		// TODO: CONSIDER NOT ADDING PUBLISHER TO CLIENTS LIST AS THEIR LIFELINE NEEDS TO BE SEPARATE FROM OTHER CLIENTS
		// NOTE(review): refreshing every client's LastActive on publisher
		// traffic means idle clients are never evicted while the stream is
		// live — confirm this is intended.
		streamInfo.mutex.Lock()
		now := time.Now()
		streamInfo.PublisherLastActive = now
		for _, client := range streamInfo.Clients {
			client.LastActive = now
		}
		streamInfo.mutex.Unlock()
	})
	return &base.Response{
		StatusCode: base.StatusOK,
	}, nil
}
// GetStreamInfo returns information about a specific stream
func (s *NewServer) GetStreamInfo(path string) (*StreamInfo, bool) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	info, ok := s.streams[path]
	return info, ok
}

// GetAllStreams returns information about all active streams
func (s *NewServer) GetAllStreams() map[string]*StreamInfo {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	snapshot := make(map[string]*StreamInfo, len(s.streams))
	for path, info := range s.streams {
		snapshot[path] = info
	}
	return snapshot
}
+1 -1
View File
@@ -1,4 +1,4 @@
package interfaces
package mediasink
import (
"context"
+202
View File
@@ -1 +1,203 @@
package socket
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/coder/websocket"
"github.com/pion/rtp"
)
// HostConfig carries tunables for the websocket Host.
type HostConfig struct {
	WriteTimeout time.Duration // NOTE(review): not referenced in this file — confirm intended use
	// ReadTimout (sic: field name has a typo; kept because renaming the
	// exported field would break callers) — also not referenced in this file.
	ReadTimout   time.Duration
	ConnectRetry bool // when true, Connect retries lost connections every 5s
}
// OnServerMessage handles a message pushed by the websocket server.
type OnServerMessage = func(msg []byte) error

// Host is a websocket client sink: it dials ws://addr:port/ws, forwards RTP
// packets / raw payloads to the server, and relays server messages back to
// the source via the registered callback.
type Host struct {
	addr   string
	port   uint16
	server *websocket.Conn // active connection; nil while disconnected

	config HostConfig

	f OnServerMessage // callback for server-pushed messages; may be nil

	mux    sync.RWMutex // guards server, f and ctx/cancel
	ctx    context.Context
	cancel context.CancelFunc
}
// NewHost builds a websocket host that will dial ws://addr:port/ws.
// f (which may be nil) receives messages pushed by the server.
func NewHost(addr string, port uint16, config HostConfig, f OnServerMessage) *Host {
	h := &Host{config: config, f: f}
	h.addr = addr
	h.port = port
	return h
}
// Connect dials the websocket server and pumps messages until ctx is
// cancelled. With ConnectRetry set, failed dials and lost connections are
// retried every 5 seconds; otherwise Connect returns after the first failure
// or disconnect. Connect blocks and closes the connection on return.
//
// Fix: a failed dial previously returned immediately even when ConnectRetry
// was set (the retry policy only covered lost connections); dial failures now
// honor it too. Also fixes the "dailing" typo in the error log.
func (h *Host) Connect(ctx context.Context) {
	h.setSession(ctx)
	defer h.close()

	url := fmt.Sprintf("ws://%s:%d/ws", h.addr, h.port)
	for {
		select {
		case <-h.ctx.Done():
			return
		default:
		}

		conn, _, err := websocket.Dial(h.ctx, url, nil)
		if err != nil || conn == nil {
			if err != nil {
				fmt.Printf("error while dialing to the socket server; err: %s\n", err.Error())
			} else {
				fmt.Println("error while dialing to the socket server; conn is nil")
			}
			if !h.config.ConnectRetry {
				return
			}
			fmt.Println("connection lost, retrying in 5 seconds...")
			time.Sleep(5 * time.Second)
			continue
		}

		h.setServer(conn)
		h.serverHandler() // blocks until the connection drops or ctx ends
		h.setServer(nil)

		if !h.config.ConnectRetry {
			return
		}
		fmt.Println("connection lost, retrying in 5 seconds...")
		time.Sleep(5 * time.Second)
	}
}
// setSession replaces the host's lifetime context, cancelling any previous one.
func (h *Host) setSession(ctx context.Context) {
	h.mux.Lock()
	defer h.mux.Unlock()
	if h.cancel != nil {
		h.cancel()
	}
	h.ctx, h.cancel = context.WithCancel(ctx)
}

// setServer swaps the active server connection (nil clears it).
func (h *Host) setServer(conn *websocket.Conn) {
	h.mux.Lock()
	h.server = conn
	h.mux.Unlock()
}
// WriteRTP marshals packet and sends it to the server as a binary message.
// It fails fast when the Host has not connected yet.
func (h *Host) WriteRTP(packet *rtp.Packet) error {
	// Copy the connection out instead of holding RLock across the blocking
	// network write: the old code kept the lock for the whole send, which
	// could deadlock close() (it needs the write lock).
	h.mux.RLock()
	conn := h.server
	h.mux.RUnlock()

	// Check the connection before paying for the marshal.
	if conn == nil {
		return fmt.Errorf("yet to connect to server. skipping message")
	}
	msg, err := packet.Marshal()
	if err != nil {
		return fmt.Errorf("error while marshalling the packet; err: %w", err)
	}
	return h.send(conn, websocket.MessageBinary, msg)
}
// Write sends a raw binary payload to the server.
// BUG FIX: the old version read h.server without the lock and without a nil
// check, so calling Write before Connect dereferenced a nil connection.
func (h *Host) Write(msg []byte) error {
	h.mux.RLock()
	conn := h.server
	h.mux.RUnlock()
	if conn == nil {
		return fmt.Errorf("yet to connect to server. skipping message")
	}
	return h.send(conn, websocket.MessageBinary, msg)
}
// func (h *Host) Write(msgType websocket.MessageType, msg []byte) error {
// return h.send(h.server, msgType, msg)
// }
// send writes msg to conn with the given message type under the session context.
func (h *Host) send(conn *websocket.Conn, msgType websocket.MessageType, msg []byte) error {
	return conn.Write(h.ctx, msgType, msg)
}
// read blocks until one message arrives on conn or the session context ends.
func (h *Host) read(conn *websocket.Conn) ([]byte, error) {
	_, payload, err := conn.Read(h.ctx)
	if err != nil {
		return nil, err
	}
	return payload, nil
}
// readFromServer reads one message from the current server connection.
func (h *Host) readFromServer() ([]byte, error) {
	// Copy the connection out under the lock instead of holding RLock across
	// the blocking Read: the old code kept the read lock for the whole read,
	// which deadlocked close() (it needs the write lock before it can cancel
	// the context that would unblock the read). Also guard against nil.
	h.mux.RLock()
	conn := h.server
	h.mux.RUnlock()
	if conn == nil {
		return nil, errors.New("not connected to server")
	}
	return h.read(conn)
}
// sendToSource hands a server message to the registered callback.
func (h *Host) sendToSource(msg []byte) error {
	// Snapshot the callback so it runs outside the lock; a slow callback
	// must not block setServer/close, which need the write lock.
	h.mux.RLock()
	f := h.f
	h.mux.RUnlock()
	if f == nil {
		return errors.New("server sent a message but no handler was passed to send back to source")
	}
	return f(msg)
}
// serverHandler pumps messages from the server to the source callback until
// the session context is cancelled or the connection dies.
func (h *Host) serverHandler() {
	fmt.Println("starting server routine...")
	for {
		select {
		case <-h.ctx.Done():
			return
		default:
		}
		msg, err := h.readFromServer()
		if err != nil {
			fmt.Println("error while reading from server; err:", err.Error())
			// BUG FIX: context.Canceled used to `continue`, spinning on a
			// dead connection until the select noticed the cancellation.
			// Every read error ends this connection's pump.
			return
		}
		if err := h.sendToSource(msg); err != nil {
			fmt.Println("error while sending msg back to source; err:", err.Error())
		}
	}
}
// close cancels the session and tears down the server connection, if any.
func (h *Host) close() {
	h.mux.Lock()
	defer h.mux.Unlock()
	if h.cancel != nil {
		h.cancel()
	}
	if conn := h.server; conn != nil {
		_ = conn.Close(websocket.StatusNormalClosure, "host closing connection")
		h.server = nil
	}
}
// SetOnPayloadCallback installs f as the handler for server messages.
// BUG FIX: the old version wrote h.f without the lock, racing with
// sendToSource which reads it under RLock.
func (h *Host) SetOnPayloadCallback(f OnServerMessage) {
	h.mux.Lock()
	defer h.mux.Unlock()
	h.f = f
}
// Close implements io.Closer by tearing the session down; it always returns nil.
func (h *Host) Close() error {
	h.close()
	return nil
}
+479
View File
@@ -1 +1,480 @@
package socket
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/coder/websocket"
"github.com/google/uuid"
)
// Health states reported by the /status endpoint.
const (
	ServerDown string = "SERVER_OFFLINE" // HTTP listener is not accepting connections
	ServerUp   string = "SERVER_ONLINE"  // HTTP listener is up and serving
)
// isLoopBack reports whether remoteAddr ("host:port") refers to the local
// machine: a loopback IP (127.0.0.0/8, ::1) or the literal name "localhost"
// (case-insensitive). Unparseable addresses are treated as non-loopback.
func isLoopBack(remoteAddr string) bool {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return false
	}
	if parsed := net.ParseIP(host); parsed != nil {
		return parsed.IsLoopback()
	}
	return strings.EqualFold(host, "localhost")
}
// Config carries the listen address and limits for a socket Server.
type Config struct {
	Addr             string        // interface/host to bind, e.g. "localhost"
	Port             uint16        // TCP port to listen on
	ReadTimout       time.Duration // HTTP read-header timeout (NOTE: field-name typo kept for compatibility)
	WriteTimout      time.Duration // HTTP write timeout (NOTE: field-name typo kept for compatibility)
	TotalConnections uint8         // max simultaneous non-host client connections
	KeepHosting      bool          // when true, Start retries ListenAndServe every 5 seconds after failure
}
// metrics tracks connection counts and traffic volume for the server.
// All fields are guarded by mux; the exported ones are serialized by Marshal.
type metrics struct {
	Uptime            time.Duration `json:"uptime"` // NOTE(review): never written anywhere visible in this file — confirm intended
	ActiveConnections uint8         `json:"active_connections"`
	FailedConnections uint8         `json:"failed_connections"`
	TotalDataSent     int64         `json:"total_data_sent"`  // bytes written to peers
	TotalDataRecvd    int64         `json:"total_data_recvd"` // bytes read from peers
	mux               sync.RWMutex  // guards every field above
}
// active returns the current client connection count.
func (m *metrics) active() uint8 {
	m.mux.RLock()
	n := m.ActiveConnections
	m.mux.RUnlock()
	return n
}
// failed returns how many connection attempts have been rejected or errored.
func (m *metrics) failed() uint8 {
	m.mux.RLock()
	n := m.FailedConnections
	m.mux.RUnlock()
	return n
}
// increaseActiveConnections bumps the live-client counter by one.
func (m *metrics) increaseActiveConnections() {
	m.mux.Lock()
	m.ActiveConnections++
	m.mux.Unlock()
}
// decreaseActiveConnections drops the live-client counter by one.
// BUG FIX: the counter is a uint8, so an unpaired decrement used to wrap
// to 255 and permanently trip the connection limit; clamp at zero instead.
func (m *metrics) decreaseActiveConnections() {
	m.mux.Lock()
	defer m.mux.Unlock()
	if m.ActiveConnections > 0 {
		m.ActiveConnections--
	}
}
// increaseFailedConnections bumps the rejected/errored-connection counter.
func (m *metrics) increaseFailedConnections() {
	m.mux.Lock()
	m.FailedConnections++
	m.mux.Unlock()
}
// addDataSent accumulates n bytes into the outbound traffic total.
func (m *metrics) addDataSent(n int64) {
	m.mux.Lock()
	m.TotalDataSent += n
	m.mux.Unlock()
}
// addDataRecvd accumulates n bytes into the inbound traffic total.
func (m *metrics) addDataRecvd(n int64) {
	m.mux.Lock()
	m.TotalDataRecvd += n
	m.mux.Unlock()
}
// Marshal renders the metrics counters as JSON under a read lock.
func (m *metrics) Marshal() ([]byte, error) {
	m.mux.RLock()
	out, err := json.Marshal(m)
	m.mux.RUnlock()
	return out, err
}
// health is the payload served by the /status endpoint.
type health struct {
	State        string   `json:"state"`         // ServerUp or ServerDown
	RecentErrors []string `json:"recent_errors"` // TODO: IMPLEMENT THIS LATER
	mux          sync.RWMutex                    // guards State and RecentErrors
}
// Marshal renders the health snapshot as JSON under a read lock.
func (h *health) Marshal() ([]byte, error) {
	h.mux.RLock()
	out, err := json.Marshal(h)
	h.mux.RUnlock()
	return out, err
}
// Server is a websocket fan-out hub: one loopback "host" publishes messages
// relayed to every connected client, and client messages are relayed back to
// the host. It also serves /status and /metrics over HTTP.
type Server struct {
	httpServer *http.Server
	config     Config
	host       *websocket.Conn            // loopback publisher; nil until one connects
	clients    map[string]*websocket.Conn // remote subscribers keyed by generated UUID
	once       sync.Once                  // makes Close idempotent
	ctx        context.Context            // session context created by Start
	cancel     context.CancelFunc
	mux        sync.RWMutex // guards host, clients, ctx/cancel, and the metrics/health pointers
	metrics    *metrics
	health     *health
}
// NewServer wires up a Server with its HTTP routes (/ws, /status, /metrics)
// but does not start listening; call Start for that.
func NewServer(config Config) *Server {
	routes := http.NewServeMux()
	s := &Server{
		httpServer: &http.Server{
			Addr:              fmt.Sprintf("%s:%d", config.Addr, config.Port),
			ReadHeaderTimeout: config.ReadTimout,
			WriteTimeout:      config.WriteTimout,
			Handler:           routes,
		},
		config:  config,
		clients: make(map[string]*websocket.Conn),
		metrics: &metrics{},
		health:  &health{},
	}
	routes.HandleFunc("/ws", s.wsHandler)
	routes.HandleFunc("/status", s.statusHandler)
	routes.HandleFunc("/metrics", s.metricsHandler)
	return s
}
// Start runs the HTTP listener until ctx is cancelled. With KeepHosting set,
// a failed listener is restarted after 5 seconds; otherwise Start returns on
// the first failure. This is a blocking call and must be called in a separate
// goroutine.
func (s *Server) Start(ctx context.Context) {
	s.setSession(ctx)
	defer s.close()

	for {
		select {
		case <-s.ctx.Done():
			return
		default:
		}

		s.setStatus(ServerUp)
		err := s.httpServer.ListenAndServe()
		s.setStatus(ServerDown) // single status transition (was set twice before)

		// BUG FIX: http.ErrServerClosed is the normal result of Close/Shutdown,
		// not a hosting failure — don't log it and never retry after it.
		if errors.Is(err, http.ErrServerClosed) {
			return
		}
		if err != nil {
			fmt.Printf("error while serving socket; err: %s\n", err.Error())
		}
		if !s.config.KeepHosting {
			return
		}
		fmt.Println("failed to host server, retrying in 5 seconds...")
		time.Sleep(5 * time.Second)
	}
}
// setStatus records the server's health state.
// BUG FIX: this used to take s.mux while health.Marshal takes health.mux,
// so status writes raced with /status reads; both now use health's own lock.
// s.mux is held only long enough to snapshot the health pointer.
func (s *Server) setStatus(status string) {
	s.mux.RLock()
	h := s.health
	s.mux.RUnlock()
	h.mux.Lock()
	h.State = status
	h.mux.Unlock()
}
// setSession replaces the Server's session context, cancelling any previous one.
func (s *Server) setSession(ctx context.Context) {
	s.mux.Lock()
	defer s.mux.Unlock()
	if cancel := s.cancel; cancel != nil {
		cancel()
	}
	s.ctx, s.cancel = context.WithCancel(ctx)
}
// wsHandler upgrades incoming websocket requests. The first loopback caller
// becomes the host; everyone else is registered as a relayed client.
func (s *Server) wsHandler(w http.ResponseWriter, req *http.Request) {
	fmt.Println("got ws request from:", req.RemoteAddr)

	if s.getHost() == nil && isLoopBack(req.RemoteAddr) {
		defer s.setHost(nil)
		conn, err := s.upgradeRequest(true, w, req)
		if err != nil {
			fmt.Println("error while handling http request; err:", err.Error())
			return
		}
		s.setHost(conn)
		s.hostHandler() // blocks for the lifetime of the host connection
		return
	}

	conn, err := s.upgradeRequest(false, w, req)
	if err != nil {
		fmt.Println("error while handling http request; err:", err.Error())
		return
	}
	// BUG FIX: upgradeRequest already counted this connection; the old code
	// incremented again here, double-counting every client while only one
	// decrement ran on disconnect — only pair the decrement here.
	defer s.metrics.decreaseActiveConnections()

	fmt.Println("socket found a client with IP:", req.RemoteAddr)
	id := uuid.NewString()
	s.addClient(id, conn)
	s.clientHandler(id) // blocks for the lifetime of the client connection
}
// getHost returns the current host connection, or nil if none is attached.
func (s *Server) getHost() *websocket.Conn {
	s.mux.RLock()
	conn := s.host
	s.mux.RUnlock()
	return conn
}
// setHost swaps the host connection (nil clears it).
func (s *Server) setHost(conn *websocket.Conn) {
	s.mux.Lock()
	defer s.mux.Unlock()
	// BUG FIX: the old version logged "host set" even when clearing to nil.
	if conn != nil {
		fmt.Println("host set")
	} else {
		fmt.Println("host cleared")
	}
	s.host = conn
}
// upgradeRequest turns req into a websocket connection. For client (non-host)
// requests it also reserves a slot in the connection counter, releasing it
// again if the upgrade itself fails.
func (s *Server) upgradeRequest(asHost bool, w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) {
	if !asHost {
		// BUG FIX: check-and-increment now happens under one lock — the old
		// code read the counter, released the lock, then incremented, so two
		// concurrent upgrades could both pass the limit check. Arithmetic is
		// widened to uint16 so active+1 cannot overflow a uint8 at 255.
		s.metrics.mux.Lock()
		if uint16(s.metrics.ActiveConnections)+1 > uint16(s.config.TotalConnections) {
			s.metrics.FailedConnections++
			active := s.metrics.ActiveConnections
			s.metrics.mux.Unlock()
			fmt.Printf("current number of clients: %d; max allowed: %d\n", active, s.config.TotalConnections)
			return nil, errors.New("max clients reached")
		}
		s.metrics.ActiveConnections++
		s.metrics.mux.Unlock()
	}

	conn, err := websocket.Accept(w, req, nil)
	if err != nil {
		if !asHost {
			s.metrics.decreaseActiveConnections()
			s.metrics.increaseFailedConnections()
		}
		return nil, fmt.Errorf("error while upgrading http request to websocket; err: %w", err)
	}
	return conn, nil
}
// statusHandler serves the health state as JSON on /status.
func (s *Server) statusHandler(w http.ResponseWriter, _ *http.Request) {
	// BUG FIX: the old code returned on the marshal-error path while still
	// holding s.mux.RLock, wedging the server. Snapshot the pointer instead;
	// health.Marshal takes its own lock.
	s.mux.RLock()
	h := s.health
	s.mux.RUnlock()

	msg, err := h.Marshal()
	if err != nil {
		http.Error(w, "Failed to marshal health status", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(msg); err != nil {
		fmt.Println("error while sending status response")
	}
}
// metricsHandler serves the traffic/connection counters as JSON on /metrics.
func (s *Server) metricsHandler(w http.ResponseWriter, _ *http.Request) {
	// BUG FIX: same RLock leak as statusHandler — the marshal-error return
	// used to leave s.mux.RLock held forever.
	s.mux.RLock()
	m := s.metrics
	s.mux.RUnlock()

	msg, err := m.Marshal()
	if err != nil {
		http.Error(w, "Failed to marshal metrics", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(msg); err != nil {
		fmt.Println("error while sending status response")
	}
}
// hostHandler pumps host messages out to every client until the host read
// fails or the server context ends. Read deadline expiries are tolerated;
// any other read error ends the host session.
func (s *Server) hostHandler() {
	fmt.Println("starting host routine...")
	for {
		select {
		case <-s.ctx.Done():
			return
		default:
		}
		msgType, msg, err := s.readHost()
		if err != nil {
			fmt.Println("error while reading from host; err:", err.Error())
			if errors.Is(err, context.DeadlineExceeded) {
				continue
			}
			return
		}
		if err := s.sendAllClients(msgType, msg); err != nil {
			fmt.Println("error while sending msg from host to clients; err:", err.Error())
		}
	}
}
// clientHandler relays messages from one client to the host until the client
// read fails (deadline expiries tolerated) or the server context ends; the
// client is always deregistered on exit.
func (s *Server) clientHandler(id string) {
	defer s.removeClient(id)
	for {
		select {
		case <-s.ctx.Done():
			return
		default:
		}
		msgType, msg, err := s.readClient(id)
		if err != nil {
			if !errors.Is(err, context.DeadlineExceeded) {
				return
			}
			fmt.Println("error while reading from client; err:", err.Error())
			continue
		}
		if err := s.sendHost(msgType, msg); err != nil {
			fmt.Println("error while sending message from client to host; err:", err.Error())
		}
	}
}
// removeClient drops id from the registry and closes its connection so the
// peer is not left half-open (the old version only deleted the map entry).
func (s *Server) removeClient(id string) {
	s.mux.Lock()
	conn := s.clients[id]
	delete(s.clients, id)
	s.mux.Unlock()
	if conn != nil {
		_ = conn.Close(websocket.StatusNormalClosure, "server removing client")
	}
}
// addClient registers conn in the client map under id.
func (s *Server) addClient(id string, conn *websocket.Conn) {
	s.mux.Lock()
	s.clients[id] = conn
	s.mux.Unlock()
}
// readHost reads one message from the host connection.
func (s *Server) readHost() (websocket.MessageType, []byte, error) {
	s.mux.RLock()
	host := s.host
	s.mux.RUnlock()
	if host == nil {
		// Consistency fix: return the zero MessageType on error, matching
		// readClient, instead of the misleading MessageBinary.
		return 0, nil, errors.New("host is nil")
	}
	return s.read(host)
}
// sendHost forwards msg to the host connection, if one is attached.
func (s *Server) sendHost(msgType websocket.MessageType, msg []byte) error {
	s.mux.RLock()
	conn := s.host
	s.mux.RUnlock()
	if conn == nil {
		return errors.New("host not available. this is not normal. socket in invalid State")
	}
	return s.send(msgType, conn, msg)
}
// readClient reads one message from the client registered under id.
func (s *Server) readClient(id string) (websocket.MessageType, []byte, error) {
	s.mux.RLock()
	conn, ok := s.clients[id]
	s.mux.RUnlock()
	if !ok {
		return 0, nil, errors.New("read called on unknown client")
	}
	return s.read(conn)
}
// read receives one message from conn and accounts its size in the metrics.
func (s *Server) read(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
	kind, payload, err := conn.Read(s.ctx)
	if err != nil {
		return 0, nil, err
	}
	s.metrics.addDataRecvd(int64(len(payload)))
	return kind, payload, nil
}
func (s *Server) sendAllClients(msgType websocket.MessageType, msg []byte) error {
s.mux.RLock()
clientList := make([]*websocket.Conn, 0, len(s.clients))
for _, client := range s.clients {
clientList = append(clientList, client)
}
s.mux.RUnlock()
if len(clientList) <= 0 {
return errors.New("no clients to send")
}
for _, client := range clientList {
if err := s.send(msgType, client, msg); err != nil {
return err
}
}
return nil
}
// send writes msg to conn, accounting its size in the metrics on success.
func (s *Server) send(msgType websocket.MessageType, conn *websocket.Conn, msg []byte) error {
	err := conn.Write(s.ctx, msgType, msg)
	if err == nil {
		s.metrics.addDataSent(int64(len(msg)))
	}
	return err
}
// close tears down session state so the server can be restarted or dropped.
func (s *Server) close() {
	s.mux.Lock()
	defer s.mux.Unlock()
	// BUG FIX: cancel is nil until Start has run; closing an unstarted
	// server used to panic here.
	if s.cancel != nil {
		s.cancel()
	}
	s.host = nil
	s.clients = make(map[string]*websocket.Conn)
	s.metrics = &metrics{}
	s.health = &health{}
}
// Close shuts the server down exactly once, releasing the HTTP listener.
func (s *Server) Close() error {
	var closeErr error
	s.once.Do(func() {
		s.close()
		s.mux.Lock()
		s.config = Config{}
		s.mux.Unlock()
		closeErr = s.httpServer.Close()
	})
	return closeErr
}
+30 -30
View File
@@ -17,33 +17,33 @@ import (
)
// Test helper functions
func createTestServer(port uint16) *Server {
config := config{
addr: "localhost",
port: port, // Let OS choose port
readTimout: 5 * time.Second,
writeTimout: 5 * time.Second,
totalConnections: 5,
keepHosting: false,
func createTestServer(Port uint16) *Server {
config := Config{
Addr: "localhost",
Port: Port, // Let OS choose Port
ReadTimout: 5 * time.Second,
WriteTimout: 5 * time.Second,
TotalConnections: 5,
KeepHosting: false,
}
return NewServer(config)
}
func createTestHost(addr string, port uint16, onMessage OnServerMessage) *Host {
func createTestHost(Addr string, Port uint16, onMessage OnServerMessage) *Host {
config := HostConfig{
writeTimeout: 5 * time.Second,
readTimout: 5 * time.Second,
connectRetry: false,
WriteTimeout: 5 * time.Second,
ReadTimout: 5 * time.Second,
ConnectRetry: false,
}
return &Host{
addr: addr,
port: port,
addr: Addr,
port: Port,
config: config,
f: onMessage,
}
}
// Test port allocation - each test gets a dedicated port
// Test Port allocation - each test gets a dedicated Port
const (
testPortBase = 9000
testPortStatus = testPortBase + 1 // 9001
@@ -164,7 +164,7 @@ func TestHost_Connect_NoServer(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// Should fail to connect and return quickly since connectRetry is false
// Should fail to connect and return quickly since ConnectRetry is false
host.Connect(ctx)
// Verify no messages received
@@ -185,7 +185,7 @@ func TestHost_WriteWithoutConnection(t *testing.T) {
Payload: []byte("test payload"),
}
err := host.Write(packet)
err := host.WriteRTP(packet)
assert.Error(t, err)
assert.Contains(t, err.Error(), "yet to connect to server")
}
@@ -232,7 +232,7 @@ func TestServerHost_Integration(t *testing.T) {
}
host := createTestHost("localhost", testPortIntegration, onMessage)
host.config.connectRetry = false
host.config.ConnectRetry = false
// Connect host
hostCtx, hostCancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -260,8 +260,8 @@ func TestServerHost_MultipleClients(t *testing.T) {
// Create and start server
server := createTestServer(testPortMultiClient)
server.config.port = testPortMultiClient
server.config.totalConnections = 3
server.config.Port = testPortMultiClient
server.config.TotalConnections = 3
server.httpServer.Addr = fmt.Sprintf("localhost:%d", testPortMultiClient)
serverCtx, serverCancel := context.WithCancel(context.Background())
@@ -272,7 +272,7 @@ func TestServerHost_MultipleClients(t *testing.T) {
// Connect host first
host := createTestHost("localhost", testPortMultiClient, func(msg []byte) error { return nil })
host.config.connectRetry = false
host.config.ConnectRetry = false
hostCtx, hostCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer hostCancel()
@@ -319,7 +319,7 @@ func TestServer_ConnectionLimit(t *testing.T) {
// Connect host first
host := createTestHost("localhost", testPortConnLimit, func(msg []byte) error { return nil })
host.config.connectRetry = false
host.config.ConnectRetry = false
go host.Connect(context.Background())
time.Sleep(300 * time.Millisecond)
@@ -342,7 +342,7 @@ func TestServer_ConnectionLimit(t *testing.T) {
}
// Should respect connection limit
assert.LessOrEqual(t, successfulConnections, int(server.config.totalConnections))
assert.LessOrEqual(t, successfulConnections, int(server.config.TotalConnections))
// Check that failed connections were recorded
failed := server.metrics.failed()
@@ -352,13 +352,13 @@ func TestServer_ConnectionLimit(t *testing.T) {
}
func BenchmarkServer_WebSocketUpgrade(b *testing.B) {
config := config{
addr: "localhost",
port: testPortBenchmark, // Let OS choose port
readTimout: 100 * time.Millisecond,
writeTimout: 100 * time.Millisecond,
totalConnections: 255,
keepHosting: false,
config := Config{
Addr: "localhost",
Port: testPortBenchmark, // Let OS choose Port
ReadTimout: 100 * time.Millisecond,
WriteTimout: 100 * time.Millisecond,
TotalConnections: 255,
KeepHosting: false,
}
server := NewServer(config)
+234 -13
View File
@@ -7,6 +7,7 @@ import (
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -22,6 +23,12 @@ type (
ServerStatus string
)
var (
// Remove the start codes (0 0 0 1) for RTP - only keep the NAL unit
spsData = []byte{103, 66, 192, 31, 166, 128, 216, 61, 230, 225, 0, 0, 3, 0, 1, 0, 0, 3, 0, 50, 224, 32, 0, 19, 18, 208, 0, 9, 137, 110, 41, 32, 7, 140, 25, 80}
ppsData = []byte{104, 206, 62, 128}
)
const (
SDPNotInitialised SDPStatus = "not-initialised"
SDPRequested SDPStatus = "requested"
@@ -51,6 +58,9 @@ func isLoopBack(remoteAddr string) bool {
type ClientSession struct {
peerConnection *webrtc.PeerConnection
videoTrack *webrtc.TrackLocalStaticRTP
primarySSRC webrtc.SSRC
rtxSSRC webrtc.SSRC
RTPPayloadType webrtc.PayloadType
sdpStatus SDPStatus
ttl time.Duration
connectionState webrtc.PeerConnectionState
@@ -230,7 +240,6 @@ func NewHost(config Config) *Host {
Addr: fmt.Sprintf("%s:%d", config.Addr, config.Port),
ReadHeaderTimeout: config.ReadTimout,
WriteTimeout: config.WriteTimout,
Handler: router,
},
config: config,
clients: make(map[string]*ClientSession),
@@ -245,6 +254,8 @@ func NewHost(config Config) *Host {
router.HandleFunc("/api/webrtc-sink/metrics", server.metricsHandler)
router.HandleFunc("/api/webrtc-sink/health", server.healthHandler)
server.httpServer.Handler = server.enableCORS(router)
return server
}
@@ -265,7 +276,7 @@ func (h *Host) setSession(ctx context.Context) {
h.ctx, h.cancel = context.WithCancel(ctx)
}
func (h *Host) Start(ctx context.Context) {
func (h *Host) Connect(ctx context.Context) {
h.setSession(ctx)
defer h.close()
@@ -302,8 +313,25 @@ loop:
}
}
// enableCORS wraps next with permissive CORS headers and short-circuits
// OPTIONS preflight requests with a bare 200.
func (h *Host) enableCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, X-CSRF-Token")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}
func (h *Host) requestHandler(w http.ResponseWriter, req *http.Request) {
// Extract ID from path
fmt.Printf("got request")
id := req.PathValue("ID")
if id == "" {
http.Error(w, "Missing ID parameter", http.StatusBadRequest)
@@ -362,7 +390,7 @@ func (h *Host) handleOfferRequest(w http.ResponseWriter, req *http.Request, id s
return
}
// Set local description
// Set local description FIRST
if err := session.peerConnection.SetLocalDescription(offer); err != nil {
errMsg := fmt.Sprintf("Failed to set local description: %v", err)
h.health.addError(errMsg)
@@ -370,18 +398,34 @@ func (h *Host) handleOfferRequest(w http.ResponseWriter, req *http.Request, id s
return
}
// THEN wait for ICE gathering
fmt.Printf("Waiting for ICE gathering to complete for session: %s\n", id)
<-webrtc.GatheringCompletePromise(session.peerConnection)
fmt.Printf("ICE gathering complete for session: %s\n", id)
// Get the complete local description with ICE candidates
finalOffer := session.peerConnection.LocalDescription()
localSDP := session.peerConnection.LocalDescription()
primarySSRC, rtxSSRC := extractSSRCFromSDP(localSDP.SDP)
session.primarySSRC = webrtc.SSRC(primarySSRC)
session.rtxSSRC = webrtc.SSRC(rtxSSRC)
// Update session status
session.mu.Lock()
session.sdpStatus = SDPOfferSent
session.mu.Unlock()
// Send offer response
// Send offer response with complete ICE candidates
offerResponse := Offer{
ID: id,
Offer: offer,
Offer: *finalOffer, // Use the final description with ICE candidates
Timestamp: time.Now(),
}
fmt.Printf("offer: %s\n", finalOffer.SDP)
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(offerResponse); err != nil {
errMsg := fmt.Sprintf("Failed to encode response: %v", err)
@@ -390,7 +434,7 @@ func (h *Host) handleOfferRequest(w http.ResponseWriter, req *http.Request, id s
return
}
fmt.Printf("Sent offer for session: %h\n", id)
fmt.Printf("Sent offer for session: %s\n", id)
}
func (h *Host) answerHandler(w http.ResponseWriter, req *http.Request) {
@@ -433,6 +477,8 @@ func (h *Host) handleAnswerResponse(w http.ResponseWriter, req *http.Request, id
return
}
fmt.Printf("answer: %s\n", answer.Answer.SDP)
// Set remote description (answer)
if err := session.peerConnection.SetRemoteDescription(answer.Answer); err != nil {
errMsg := fmt.Sprintf("Failed to set remote description: %v", err)
@@ -509,9 +555,13 @@ func (h *Host) getOrCreateSession(id string, ttl time.Duration) (*ClientSession,
// Create video track
videoTrack, err := webrtc.NewTrackLocalStaticRTP(
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264},
webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH264,
ClockRate: 90000,
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f", // Match PT 102
},
"video",
fmt.Sprintf("%h-video", id),
fmt.Sprintf("%s-video", id),
)
if err != nil {
pc.Close()
@@ -557,7 +607,7 @@ func (h *Host) getOrCreateSession(id string, ttl time.Duration) (*ClientSession,
session.connectionState = state
session.mu.Unlock()
fmt.Printf("Session %h connection state: %h\n", id, state.String())
fmt.Printf("Session %s connection state: %s\n", id, state.String())
switch state {
case webrtc.PeerConnectionStateConnected:
@@ -565,6 +615,13 @@ func (h *Host) getOrCreateSession(id string, ttl time.Duration) (*ClientSession,
session.sdpStatus = SDPConnected
session.mu.Unlock()
h.metrics.increaseActiveConnections()
fmt.Printf("🚀 Connection established! Sending SPS/PPS for session %s\n", id)
if err := h.sendParameterSets(id); err != nil {
fmt.Printf("❌ Failed to send parameter sets: %v\n", err)
}
// Start periodic SPS/PPS sending
h.startParameterSetTimer(id)
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed:
session.mu.Lock()
session.sdpStatus = SDPFailed
@@ -586,7 +643,28 @@ func (h *Host) getOrCreateSession(id string, ttl time.Duration) (*ClientSession,
func (h *Host) forwardRTPPackets(session *ClientSession) {
for packet := range session.rtpChan {
// Forward to WebRTC track
fmt.Printf("Forwarding packet - PT: %d, Seq: %d, TS: %d, SSRC: %d, Size: %d, Marker: %t\n",
packet.Header.PayloadType,
packet.Header.SequenceNumber,
packet.Header.Timestamp,
packet.Header.SSRC,
len(packet.Payload),
packet.Header.Marker)
// Validate H.264 payload
if len(packet.Payload) > 0 {
nalType := packet.Payload[0] & 0x1F
nalHeader := packet.Payload[0]
fmt.Printf("H.264 NAL - Type: %d, Header: 0x%02x, First 8 bytes: %x\n",
nalType, nalHeader, packet.Payload[:min(8, len(packet.Payload))])
// Check for valid NAL unit types
if nalType == 0 || nalType > 23 {
fmt.Printf("WARNING: Invalid NAL unit type %d\n", nalType)
}
} else {
fmt.Printf("WARNING: Empty payload\n")
}
if err := session.videoTrack.WriteRTP(packet); err != nil {
continue
}
@@ -603,15 +681,19 @@ func (h *Host) ProcessRTPPacket(sessionID string, rtpData *rtp.Packet) error {
h.mux.RUnlock()
if !exists {
return fmt.Errorf("session %h not found", sessionID)
fmt.Printf("Session %s not found\n", sessionID)
return fmt.Errorf("session %s not found", sessionID)
}
rtpData.Header.SSRC = uint32(session.primarySSRC)
select {
case session.rtpChan <- rtpData:
session.updateActivity()
return nil
default:
return fmt.Errorf("RTP channel full for session %h", sessionID)
fmt.Printf("RTP channel full for session %s\n", sessionID)
return fmt.Errorf("RTP channel full for session %s", sessionID)
}
}
@@ -643,6 +725,145 @@ func (h *Host) WriteRTP(packet *rtp.Packet) error {
return nil
}
// extractSSRCFromSDP scans sdp for non-rtx "a=ssrc:<n> ..." attribute lines
// and returns the first two distinct SSRC values found, in order of first
// appearance (primary, then rtx). Missing values are returned as zero.
// BUG FIX: the old version deduplicated through a map and then iterated the
// map, so which SSRC became "primary" was nondeterministic across runs.
func extractSSRCFromSDP(sdp string) (primarySSRC uint32, rtxSSRC uint32) {
	seen := make(map[uint32]bool)
	ordered := make([]uint32, 0, 2)
	for _, line := range strings.Split(sdp, "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "a=ssrc:") || strings.Contains(line, "rtx") {
			continue
		}
		// "a=ssrc:123 cname:..." -> take the token before the first space,
		// then the digits after the colon.
		field := strings.SplitN(line, " ", 2)[0]
		value, err := strconv.ParseUint(strings.TrimPrefix(field, "a=ssrc:"), 10, 32)
		if err != nil {
			continue
		}
		ssrc := uint32(value)
		if !seen[ssrc] {
			seen[ssrc] = true
			ordered = append(ordered, ssrc)
		}
	}
	if len(ordered) >= 1 {
		primarySSRC = ordered[0]
	}
	if len(ordered) >= 2 {
		rtxSSRC = ordered[1]
	}
	return primarySSRC, rtxSSRC
}
// sendParameterSets queues the cached H.264 SPS and PPS NAL units onto the
// session's RTP channel so a decoder joining mid-stream can initialise.
// SPS and PPS share one timestamp and never carry the RTP marker bit.
// BUG FIX: removed the no-op `session.mu.Lock(); session.mu.Unlock()` pairs
// the old version executed after each enqueue, and deduplicated the packet
// construction.
func (h *Host) sendParameterSets(sessionID string) error {
	h.mux.RLock()
	session, exists := h.clients[sessionID]
	h.mux.RUnlock()
	if !exists {
		return fmt.Errorf("session %s not found", sessionID)
	}

	ts := uint32(time.Now().Unix() * 90000)
	newPacket := func(payload []byte) *rtp.Packet {
		return &rtp.Packet{
			Header: rtp.Header{
				Version:     2,
				PayloadType: 102,
				Timestamp:   ts,
				SSRC:        uint32(session.primarySSRC),
				Marker:      false,
			},
			Payload: payload,
		}
	}

	// SPS first, then PPS; a full channel aborts without blocking.
	select {
	case session.rtpChan <- newPacket(spsData):
		session.updateActivity()
	default:
		return fmt.Errorf("failed to send SPS: channel full")
	}
	select {
	case session.rtpChan <- newPacket(ppsData):
		session.updateActivity()
	default:
		return fmt.Errorf("failed to send PPS: channel full")
	}
	return nil
}
// startParameterSetTimer launches a goroutine that re-sends SPS/PPS every
// 3 seconds while the session stays connected. It exits when the host
// context ends, the session disappears from the registry, or the peer
// leaves the Connected state.
func (h *Host) startParameterSetTimer(sessionID string) {
	go func() {
		ticker := time.NewTicker(3 * time.Second) // Send every 3 seconds
		defer ticker.Stop()
		fmt.Printf("⏰ Started periodic SPS/PPS timer for session %s\n", sessionID)
		for {
			select {
			case <-h.ctx.Done():
				fmt.Printf("⏰ Stopping SPS/PPS timer for session %s (context done)\n", sessionID)
				return
			case <-ticker.C:
				// Look the session up fresh each tick; it may have been
				// cleaned up since the last send.
				h.mux.RLock()
				session, exists := h.clients[sessionID]
				h.mux.RUnlock()
				if !exists {
					fmt.Printf("⏰ Stopping SPS/PPS timer for session %s (session not found)\n", sessionID)
					return
				}
				session.mu.RLock()
				connected := session.connectionState == webrtc.PeerConnectionStateConnected
				session.mu.RUnlock()
				if connected {
					fmt.Printf("⏰ Periodic SPS/PPS send for session %s\n", sessionID)
					// Failures are logged but do not stop the timer.
					if err := h.sendParameterSets(sessionID); err != nil {
						fmt.Printf("❌ Failed periodic SPS/PPS send: %v\n", err)
					}
				} else {
					// Any non-connected state stops the timer for good.
					fmt.Printf("⏰ Stopping SPS/PPS timer for session %s (not connected)\n", sessionID)
					return
				}
			}
		}
	}()
}
// Write satisfies the sink interface but is a deliberate no-op for the
// WebRTC host: raw byte payloads are not forwarded (use WriteRTP instead).
func (h *Host) Write(_ []byte) error {
	return nil
}
func (h *Host) cleanupExpiredClients() {
defer h.wg.Done()
+1 -1
View File
@@ -1 +1 @@
package wasm
package main