diff --git a/.golangci.yml b/.golangci.yml index a3dfcbca..feb3d164 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -78,6 +78,12 @@ linters: - stringsbuilder - testingcontext + exclusions: + rules: + - linters: + - lll + source: "^\\s*// https?://" + formatters: enable: - gofmt diff --git a/go.mod b/go.mod index b67e28e5..45ab5974 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/pion/webrtc/v4 v4.2.11 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.49.0 + golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 golang.org/x/term v0.41.0 ) @@ -96,7 +97,6 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect golang.org/x/arch v0.22.0 // indirect - golang.org/x/net v0.52.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.12.0 // indirect google.golang.org/protobuf v1.36.10 // indirect diff --git a/internal/auth/manager.go b/internal/auth/manager.go index 5dfb4da2..6ac25642 100644 --- a/internal/auth/manager.go +++ b/internal/auth/manager.go @@ -185,13 +185,8 @@ func (m *Manager) authenticateHTTP(req *Request) (string, error) { Query: req.Query, }) - u, err := url.Parse(m.HTTPAddress) - if err != nil { - return "", err - } - tr := &http.Transport{ - TLSClientConfig: tls.MakeConfig(u.Hostname(), m.HTTPFingerprint), + TLSClientConfig: tls.MakeConfig(m.HTTPFingerprint), } defer tr.CloseIdleConnections() @@ -283,13 +278,8 @@ func (m *Manager) pullJWTJWKS() (jwt.Keyfunc, error) { defer m.mutex.Unlock() if now.Sub(m.jwksLastRefresh) >= jwksRefreshPeriod { - u, err := url.Parse(m.JWTJWKS) - if err != nil { - return nil, err - } - tr := &http.Transport{ - TLSClientConfig: tls.MakeConfig(u.Hostname(), m.JWTJWKSFingerprint), + TLSClientConfig: tls.MakeConfig(m.JWTJWKSFingerprint), } defer tr.CloseIdleConnections() diff --git a/internal/core/core.go b/internal/core/core.go index 58234642..ef272c79 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -470,7 +470,7 @@ func (p *Core) createResources(initial bool) error { MulticastIPRange: p.conf.MulticastIPRange, MulticastRTPPort: p.conf.MulticastRTPPort, MulticastRTCPPort: p.conf.MulticastRTCPPort, - IsTLS: false, + Encryption: false, ServerCert: "", ServerKey: "", RTSPAddress: p.conf.RTSPAddress, @@ -513,7 +513,7 @@ func (p *Core) createResources(initial bool) error { MulticastIPRange: p.conf.MulticastIPRange, MulticastRTPPort: p.conf.MulticastSRTPPort, MulticastRTCPPort: p.conf.MulticastSRTCPPort, - IsTLS: true, + Encryption: true, ServerCert: p.conf.RTSPServerCert, ServerKey: p.conf.RTSPServerKey, RTSPAddress: p.conf.RTSPAddress, @@ -542,7 +542,7 @@ func (p *Core) createResources(initial bool) error { DumpPackets: p.conf.DumpPackets, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, - IsTLS: false, + Encryption: false, ServerCert: "", ServerKey: "", RTSPAddress: p.conf.RTSPAddress, @@ -569,7 +569,7 @@ func (p *Core) createResources(initial bool) error { Address: p.conf.RTMPSAddress, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, - IsTLS: true, + Encryption: true, ServerCert: p.conf.RTMPServerCert, ServerKey: p.conf.RTMPServerKey, DumpPackets: p.conf.DumpPackets, diff --git a/internal/packetdumper/conn.go b/internal/packetdumper/conn.go index 68a77cfd..36e294a1 100644 --- a/internal/packetdumper/conn.go +++ b/internal/packetdumper/conn.go @@ -2,7 +2,9 @@ package packetdumper import ( + "encoding/binary" "fmt" + "io" "net" "os" "sync" @@ -14,39 +16,79 @@ import ( "github.com/google/uuid" ) -var _ net.Conn = (*Conn)(nil) +var _ net.Conn = (*conn)(nil) type direction int const ( - dirRead direction = iota - dirWrite + dirInbound direction = iota + dirOutbound dirHandshake + dirSecret ) +func writeDecryptionSecretsBlock(f io.Writer, data []byte) { + const ( + dsbSecretTypeTLS = 0x544c534b + blockType = 0x0000000A + fixedHeaderBytes = 4 + 4 + 4 + 4 // type + totalLen + secretsType + secretsLen + trailerBytes = 4 // repeated totalLen + overheadBytes = fixedHeaderBytes + trailerBytes + ) + + secretsLen := len(data) + paddedLen := (secretsLen + 3) &^ 3 + padBytes := paddedLen - secretsLen + + totalLen := uint32(overheadBytes + paddedLen) + + buf := make([]byte, totalLen) + pos := 0 + + binary.LittleEndian.PutUint32(buf[pos:], blockType) + pos += 4 + binary.LittleEndian.PutUint32(buf[pos:], totalLen) + pos += 4 + binary.LittleEndian.PutUint32(buf[pos:], dsbSecretTypeTLS) + pos += 4 + binary.LittleEndian.PutUint32(buf[pos:], uint32(secretsLen)) // unpadded length, per spec + pos += 4 + pos += copy(buf[pos:], data) + pos += padBytes // zero padding already present (make zeroes) + binary.LittleEndian.PutUint32(buf[pos:], totalLen) // trailing repeat + + f.Write(buf) //nolint:errcheck +} + type dumpEntry struct { ntp time.Time data []byte direction direction } -// Conn is a wrapper around net.Conn that dumps packets to disk. -type Conn struct { +// conn is a wrapper around net.Conn that dumps packets to disk. +type conn struct { Prefix string Conn net.Conn ServerSide bool - f *os.File - pw *pcapgo.NgWriter - once sync.Once + expectingSecrets int + f *os.File + pw *pcapgo.NgWriter + once sync.Once + local *net.TCPAddr + remote *net.TCPAddr + nextLocalSequence uint32 + nextRemoteSequence uint32 + delayed []dumpEntry queue chan dumpEntry terminated chan struct{} done chan struct{} } -// Initialize initializes Conn. -func (c *Conn) Initialize() error { +// Initialize initializes conn. +func (c *conn) Initialize() error { var err error c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String())) if err != nil { @@ -59,6 +101,17 @@ func (c *Conn) Initialize() error { return err } + c.local = c.Conn.LocalAddr().(*net.TCPAddr) + c.remote = c.Conn.RemoteAddr().(*net.TCPAddr) + + if c.ServerSide { + c.nextLocalSequence = uint32(2000) + c.nextRemoteSequence = uint32(1000) + } else { + c.nextLocalSequence = uint32(1000) + c.nextRemoteSequence = uint32(2000) + } + c.queue = make(chan dumpEntry, 64) c.terminated = make(chan struct{}) c.done = make(chan struct{}) @@ -71,7 +124,7 @@ func (c *Conn) Initialize() error { } // Close implements net.Conn. -func (c *Conn) Close() error { +func (c *conn) Close() error { c.once.Do(func() { close(c.terminated) }) @@ -79,29 +132,23 @@ func (c *Conn) Close() error { return c.Conn.Close() } -func (c *Conn) run() { +func (c *conn) run() { defer close(c.done) defer c.f.Close() - - local := c.Conn.LocalAddr().(*net.TCPAddr) - remote := c.Conn.RemoteAddr().(*net.TCPAddr) - - nextLocalSequence := uint32(1000) - nextRemoteSequence := uint32(2000) + defer c.pw.Flush() //nolint:errcheck for { select { case e := <-c.queue: - c.processEntry(e, local, remote, &nextLocalSequence, &nextRemoteSequence) + c.processEntry(e) case <-c.terminated: // Drain anything already in the queue before exiting. for { select { case e := <-c.queue: - c.processEntry(e, local, remote, &nextLocalSequence, &nextRemoteSequence) + c.processEntry(e) default: - c.pw.Flush() //nolint:errcheck return } } @@ -109,18 +156,19 @@ func (c *Conn) run() { } } -func (c *Conn) processEntry( - e dumpEntry, - local, remote *net.TCPAddr, - nextLocalSequence, nextRemoteSequence *uint32, -) { +func (c *conn) processEntry(e dumpEntry) { + if c.expectingSecrets > 0 && e.direction != dirSecret { + c.delayed = append(c.delayed, e) + return + } + switch e.direction { case dirHandshake: - clientAddr, serverAddr := local, remote // client side: local initiates - clientSeq, serverSeq := nextLocalSequence, nextRemoteSequence + clientAddr, serverAddr := c.local, c.remote // client side: local initiates + clientSeq, serverSeq := &c.nextLocalSequence, &c.nextRemoteSequence if c.ServerSide { - clientAddr, serverAddr = remote, local // server side: remote initiated - clientSeq, serverSeq = nextRemoteSequence, nextLocalSequence + clientAddr, serverAddr = c.remote, c.local // server side: remote initiates + clientSeq, serverSeq = &c.nextRemoteSequence, &c.nextLocalSequence } // SYN (client -> server) @@ -137,31 +185,43 @@ func (c *Conn) processEntry( c.writePacket(e.ntp, clientAddr, serverAddr, layers.TCP{ACK: true, Window: 65535, Seq: *clientSeq, Ack: *serverSeq}, nil) - case dirRead: - tcpFlags := layers.TCP{ - PSH: true, - ACK: true, - Window: 14600, - Seq: *nextRemoteSequence, - Ack: *nextLocalSequence, - } - c.writePacket(e.ntp, remote, local, tcpFlags, e.data) - *nextRemoteSequence += uint32(len(e.data)) + case dirSecret: + c.pw.Flush() //nolint:errcheck + writeDecryptionSecretsBlock(c.f, e.data) - case dirWrite: + c.expectingSecrets-- + if c.expectingSecrets == 0 { + for _, e2 := range c.delayed { + c.processEntry(e2) + } + c.delayed = nil + } + + case dirInbound: tcpFlags := layers.TCP{ PSH: true, ACK: true, Window: 14600, - Seq: *nextLocalSequence, - Ack: *nextRemoteSequence, + Seq: c.nextRemoteSequence, + Ack: c.nextLocalSequence, } - c.writePacket(e.ntp, local, remote, tcpFlags, e.data) - *nextLocalSequence += uint32(len(e.data)) + c.writePacket(e.ntp, c.remote, c.local, tcpFlags, e.data) + c.nextRemoteSequence += uint32(len(e.data)) + + case dirOutbound: + tcpFlags := layers.TCP{ + PSH: true, + ACK: true, + Window: 14600, + Seq: c.nextLocalSequence, + Ack: c.nextRemoteSequence, + } + c.writePacket(e.ntp, c.local, c.remote, tcpFlags, e.data) + c.nextLocalSequence += uint32(len(e.data)) } } -func (c *Conn) writePacket( +func (c *conn) writePacket( ntp time.Time, src, dst *net.TCPAddr, tcpFlags layers.TCP, @@ -207,35 +267,35 @@ func (c *Conn) writePacket( }, raw) } -func (c *Conn) enqueue(e dumpEntry) { +func (c *conn) enqueue(e dumpEntry) { select { case c.queue <- e: case <-c.terminated: } } -func (c *Conn) Read(p []byte) (n int, err error) { +func (c *conn) Read(p []byte) (n int, err error) { n, err = c.Conn.Read(p) if n != 0 { c.enqueue(dumpEntry{ ntp: time.Now(), data: append([]byte(nil), p[:n]...), - direction: dirRead, + direction: dirInbound, }) } return n, err } -func (c *Conn) Write(p []byte) (n int, err error) { +func (c *conn) Write(p []byte) (n int, err error) { n, err = c.Conn.Write(p) if err == nil { c.enqueue(dumpEntry{ ntp: time.Now(), data: append([]byte(nil), p...), - direction: dirWrite, + direction: dirOutbound, }) } @@ -243,16 +303,16 @@ func (c *Conn) Write(p []byte) (n int, err error) { } // LocalAddr implements net.Conn. -func (c *Conn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } +func (c *conn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } // RemoteAddr implements net.Conn. -func (c *Conn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() } +func (c *conn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() } // SetDeadline implements net.Conn. -func (c *Conn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) } +func (c *conn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) } // SetReadDeadline implements net.Conn. -func (c *Conn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) } +func (c *conn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) } // SetWriteDeadline implements net.Conn. -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } +func (c *conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } diff --git a/internal/packetdumper/conn_test.go b/internal/packetdumper/conn_test.go index 89227ecb..ffd93014 100644 --- a/internal/packetdumper/conn_test.go +++ b/internal/packetdumper/conn_test.go @@ -57,7 +57,7 @@ func TestConnInitialize_CreatesFile(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -69,7 +69,7 @@ func TestConnWrite(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -90,7 +90,7 @@ func TestConnRead(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -110,7 +110,7 @@ func TestConnServerSide(t *testing.T) { defer client.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: server, ServerSide: true} + c := &conn{Prefix: prefix, Conn: server, ServerSide: true} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -131,7 +131,7 @@ func TestConnMultipleWriteRead(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -164,7 +164,7 @@ func TestConnCloseIdempotent(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) @@ -178,7 +178,7 @@ func TestConnDelegatesAddrMethods(t *testing.T) { defer server.Close() prefix := filepath.Join(t.TempDir(), "capture") - c := &Conn{Prefix: prefix, Conn: client} + c := &conn{Prefix: prefix, Conn: client} require.NoError(t, c.Initialize()) defer cleanupPcapng(t, prefix) diff --git a/internal/packetdumper/dial_context.go b/internal/packetdumper/dial_context.go index ec39d84c..d258ed05 100644 --- a/internal/packetdumper/dial_context.go +++ b/internal/packetdumper/dial_context.go @@ -7,26 +7,25 @@ import ( // DialContext is a wrapper around net.Dialer.DialContext that dumps packets to disk. type DialContext struct { - Prefix string - DialContext func(ctx context.Context, network, address string) (net.Conn, error) + Prefix string } // Do mimics net.Dialer.DialContext. func (d *DialContext) Do(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.DialContext(ctx, network, address) + netConn, err := (&net.Dialer{}).DialContext(ctx, network, address) if err != nil { return nil, err } - c := &Conn{ + pdConn := &conn{ Prefix: d.Prefix, - Conn: conn, + Conn: netConn, } - err = c.Initialize() + err = pdConn.Initialize() if err != nil { - conn.Close() + netConn.Close() return nil, err } - return c, nil + return pdConn, nil } diff --git a/internal/packetdumper/dial_tls_context.go b/internal/packetdumper/dial_tls_context.go new file mode 100644 index 00000000..dfa383fd --- /dev/null +++ b/internal/packetdumper/dial_tls_context.go @@ -0,0 +1,43 @@ +package packetdumper + +import ( + "context" + "crypto/tls" + "net" +) + +// DialTLSContext provides the DialTLSContext function. +type DialTLSContext struct { + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + TLSConfig *tls.Config +} + +// Do provides DialTLSContext. +func (t *DialTLSContext) Do(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := t.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + // clone TLS config and fill ServerName if empty. + // this is the same behavior of http.Client. + // https://cs.opensource.google/go/go/+/master:src/net/http/transport.go;l=1754;drc=a4b534f5e42fe58d58c0ff0562d76680cedb0466 + tlsConfig := t.TLSConfig + + if tlsConfig == nil { + tlsConfig = &tls.Config{} + } else { + tlsConfig = tlsConfig.Clone() + } + + if tlsConfig.ServerName == "" { + host, _, _ := net.SplitHostPort(addr) + tlsConfig.ServerName = host + } + + pdConn := netConn.(*conn) + pdConn.expectingSecrets = 4 + tlsConfig.KeyLogWriter = &connKeyLogWriter{c: pdConn} + + return tls.Client(netConn, tlsConfig), nil +} diff --git a/internal/packetdumper/listen.go b/internal/packetdumper/listen.go index 2f88d651..feef1786 100644 --- a/internal/packetdumper/listen.go +++ b/internal/packetdumper/listen.go @@ -7,18 +7,17 @@ import ( // Listen is a wrapper around net.Listen that dumps packets to disk. type Listen struct { Prefix string - Listen func(network, address string) (net.Listener, error) } // Do mimics net.Listen. func (l *Listen) Do(network, address string) (net.Listener, error) { - ln, err := l.Listen(network, address) + netListener, err := net.Listen(network, address) if err != nil { return nil, err } - return &Listener{ + return &listener{ Prefix: l.Prefix, - Listener: ln, + Listener: netListener, }, nil } diff --git a/internal/packetdumper/listen_packet.go b/internal/packetdumper/listen_packet.go index 2e9824a0..f90690e7 100644 --- a/internal/packetdumper/listen_packet.go +++ b/internal/packetdumper/listen_packet.go @@ -6,25 +6,25 @@ import ( // ListenPacket is a wrapper around net.ListenPacket that dumps packets to disk. type ListenPacket struct { - Prefix string - ListenPacket func(network, address string) (net.PacketConn, error) + Prefix string } -// Do mimics net.ListenPacket +// Do mimics net.ListenPacket. func (l *ListenPacket) Do(network, address string) (net.PacketConn, error) { - pc, err := l.ListenPacket(network, address) + netPacketConn, err := net.ListenPacket(network, address) if err != nil { return nil, err } - d := &PacketConn{ + pdPacketConn := &packetConn{ Prefix: l.Prefix, - PacketConn: pc, + PacketConn: netPacketConn, } - err = d.Initialize() + err = pdPacketConn.Initialize() if err != nil { + netPacketConn.Close() return nil, err } - return d, nil + return pdPacketConn, nil } diff --git a/internal/packetdumper/listener.go b/internal/packetdumper/listener.go index 87bcd9ef..8ef9a576 100644 --- a/internal/packetdumper/listener.go +++ b/internal/packetdumper/listener.go @@ -2,41 +2,41 @@ package packetdumper import "net" -var _ net.Listener = (*Listener)(nil) +var _ net.Listener = (*listener)(nil) -// Listener is a wrapper around net.Listener that dumps packets to disk. -type Listener struct { +// listener is a wrapper around a net.Listener that dumps packets to disk. +type listener struct { Prefix string Listener net.Listener } // Accept implements net.Listener. -func (l *Listener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() +func (l *listener) Accept() (net.Conn, error) { + netConn, err := l.Listener.Accept() if err != nil { return nil, err } - cd := &Conn{ + pdConn := &conn{ Prefix: l.Prefix, - Conn: conn, + Conn: netConn, ServerSide: true, } - err = cd.Initialize() + err = pdConn.Initialize() if err != nil { - conn.Close() //nolint:errcheck + netConn.Close() //nolint:errcheck return nil, err } - return cd, nil + return pdConn, nil } // Close implements net.Listener. -func (l *Listener) Close() error { +func (l *listener) Close() error { return l.Listener.Close() } // Addr implements net.Listener. -func (l *Listener) Addr() net.Addr { +func (l *listener) Addr() net.Addr { return l.Listener.Addr() } diff --git a/internal/packetdumper/packet_conn.go b/internal/packetdumper/packet_conn.go index b6805fcd..1ed25b00 100644 --- a/internal/packetdumper/packet_conn.go +++ b/internal/packetdumper/packet_conn.go @@ -14,7 +14,7 @@ import ( "github.com/google/uuid" ) -var _ net.PacketConn = (*PacketConn)(nil) +var _ net.PacketConn = (*packetConn)(nil) type extendedPacketConn interface { net.PacketConn @@ -28,8 +28,8 @@ type packetDumpEntry struct { src, dst *net.UDPAddr } -// PacketConn is a wrapper around net.PacketConn that dumps packets to disk. -type PacketConn struct { +// packetConn is a wrapper around net.PacketConn that dumps packets to disk. +type packetConn struct { Prefix string PacketConn net.PacketConn @@ -42,8 +42,8 @@ type PacketConn struct { done chan struct{} } -// Initialize initializes PacketConn. -func (c *PacketConn) Initialize() error { +// Initialize initializes packetConn. +func (c *packetConn) Initialize() error { var err error c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String())) if err != nil { @@ -66,7 +66,7 @@ func (c *PacketConn) Initialize() error { } // Close implements net.PacketConn. -func (c *PacketConn) Close() error { +func (c *packetConn) Close() error { c.once.Do(func() { close(c.terminated) }) @@ -74,7 +74,7 @@ func (c *PacketConn) Close() error { return c.PacketConn.Close() } -func (c *PacketConn) run() { +func (c *packetConn) run() { defer close(c.done) defer c.f.Close() @@ -97,7 +97,7 @@ func (c *PacketConn) run() { } } -func (c *PacketConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload []byte) { +func (c *packetConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload []byte) { eth := &layers.Ethernet{ SrcMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0}, DstMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0}, @@ -130,7 +130,7 @@ func (c *PacketConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload [ }, raw) } -func (c *PacketConn) enqueue(e packetDumpEntry) { +func (c *packetConn) enqueue(e packetDumpEntry) { select { case c.queue <- e: case <-c.terminated: @@ -138,7 +138,7 @@ func (c *PacketConn) enqueue(e packetDumpEntry) { } // ReadFrom implements net.PacketConn. -func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) if n != 0 { @@ -157,7 +157,7 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } // WriteTo implements net.PacketConn. -func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { n, err = c.PacketConn.WriteTo(p, addr) if err == nil { @@ -176,23 +176,23 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } // LocalAddr implements net.PacketConn. -func (c *PacketConn) LocalAddr() net.Addr { return c.PacketConn.LocalAddr() } +func (c *packetConn) LocalAddr() net.Addr { return c.PacketConn.LocalAddr() } // SetDeadline implements net.PacketConn. -func (c *PacketConn) SetDeadline(t time.Time) error { return c.PacketConn.SetDeadline(t) } +func (c *packetConn) SetDeadline(t time.Time) error { return c.PacketConn.SetDeadline(t) } // SetReadDeadline implements net.PacketConn. -func (c *PacketConn) SetReadDeadline(t time.Time) error { return c.PacketConn.SetReadDeadline(t) } +func (c *packetConn) SetReadDeadline(t time.Time) error { return c.PacketConn.SetReadDeadline(t) } // SetWriteDeadline implements net.PacketConn. -func (c *PacketConn) SetWriteDeadline(t time.Time) error { return c.PacketConn.SetWriteDeadline(t) } +func (c *packetConn) SetWriteDeadline(t time.Time) error { return c.PacketConn.SetWriteDeadline(t) } // SetReadBuffer implements extendedPacketConn. -func (c *PacketConn) SetReadBuffer(bytes int) error { +func (c *packetConn) SetReadBuffer(bytes int) error { return c.PacketConn.(extendedPacketConn).SetReadBuffer(bytes) } // SyscallConn implements extendedPacketConn. -func (c *PacketConn) SyscallConn() (syscall.RawConn, error) { +func (c *packetConn) SyscallConn() (syscall.RawConn, error) { return c.PacketConn.(extendedPacketConn).SyscallConn() } diff --git a/internal/packetdumper/packet_conn_test.go b/internal/packetdumper/packet_conn_test.go index df43baf5..0f6dbe1b 100644 --- a/internal/packetdumper/packet_conn_test.go +++ b/internal/packetdumper/packet_conn_test.go @@ -51,7 +51,7 @@ func TestPacketConnInitialize_CreatesFile(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -63,7 +63,7 @@ func TestPacketConnWriteTo(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -85,7 +85,7 @@ func TestPacketConnReadFrom(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -107,7 +107,7 @@ func TestPacketConnMultipleWriteRead(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -150,7 +150,7 @@ func TestPacketConnCloseIdempotent(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -164,7 +164,7 @@ func TestPacketConnDelegatesAddrMethods(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) @@ -182,7 +182,7 @@ func TestPacketConnReadFromRecordsSource(t *testing.T) { defer server.Close() //nolint:errcheck prefix := filepath.Join(t.TempDir(), "capture") - c := &PacketConn{Prefix: prefix, PacketConn: client} + c := &packetConn{Prefix: prefix, PacketConn: client} require.NoError(t, c.Initialize()) defer cleanupPcapngPacket(t, prefix) diff --git a/internal/packetdumper/tls_listen.go b/internal/packetdumper/tls_listen.go new file mode 100644 index 00000000..05c753a5 --- /dev/null +++ b/internal/packetdumper/tls_listen.go @@ -0,0 +1,24 @@ +package packetdumper + +import ( + "crypto/tls" + "net" +) + +// TLSListen provides a tls.Listen that also dumps TLS master secrets to disk. +type TLSListen struct { + Listen func(network string, address string) (net.Listener, error) +} + +// Do mimics tls.Listen. +func (l *TLSListen) Do(network string, laddr string, tlsConfig *tls.Config) (net.Listener, error) { + netListener, err := l.Listen(network, laddr) + if err != nil { + return nil, err + } + + return &tlsListener{ + Listener: netListener, + TLSConfig: tlsConfig, + }, nil +} diff --git a/internal/packetdumper/tls_listener.go b/internal/packetdumper/tls_listener.go new file mode 100644 index 00000000..ad00815a --- /dev/null +++ b/internal/packetdumper/tls_listener.go @@ -0,0 +1,48 @@ +package packetdumper + +import ( + "crypto/tls" + "net" + "time" +) + +type connKeyLogWriter struct { + c *conn +} + +func (w *connKeyLogWriter) Write(p []byte) (int, error) { + w.c.enqueue(dumpEntry{ + ntp: time.Now(), + data: append([]byte(nil), p...), + direction: dirSecret, + }) + + return len(p), nil +} + +type tlsListener struct { + Listener net.Listener + TLSConfig *tls.Config +} + +func (l *tlsListener) Close() error { + return l.Listener.Close() +} + +func (l *tlsListener) Addr() net.Addr { + return l.Listener.Addr() +} + +func (l *tlsListener) Accept() (net.Conn, error) { + netConn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + tlsConfig := l.TLSConfig.Clone() + pdConn := netConn.(*conn) + pdConn.expectingSecrets = 4 + tlsConfig.KeyLogWriter = &connKeyLogWriter{c: pdConn} + + return tls.Server(netConn, tlsConfig), nil +} diff --git a/internal/protocols/httpp/server.go b/internal/protocols/httpp/server.go index 85a2d0f4..32e91cb2 100644 --- a/internal/protocols/httpp/server.go +++ b/internal/protocols/httpp/server.go @@ -15,6 +15,7 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/restrictnetwork" + "golang.org/x/net/http2" ) type nilWriter struct{} @@ -59,6 +60,7 @@ func (s *Server) Initialize() error { } var tlsConfig *tls.Config + if s.Encryption { if s.ServerCert == "" { return fmt.Errorf("server cert is missing") @@ -93,23 +95,6 @@ func (s *Server) Initialize() error { os.Remove(address) } - var err error - s.ln, err = net.Listen(network, address) - if err != nil { - return err - } - - if s.DumpPackets { - s.ln = &packetdumper.Listener{ - Prefix: s.DumpPacketsPrefix, - Listener: s.ln, - } - } - - if network == "unix" { - os.Chmod(address, 0o755) //nolint:errcheck - } - h := s.Handler h = &handlerOrigin{h, s.AllowOrigins} h = &handlerServerHeader{h} @@ -134,11 +119,48 @@ func (s *Server) Initialize() error { } if tlsConfig != nil { - go s.inner.ServeTLS(s.ln, "", "") - } else { - go s.inner.Serve(s.ln) + err := http2.ConfigureServer(s.inner, &http2.Server{}) + if err != nil { + return err + } } + var listen func(network string, address string) (net.Listener, error) + var tlsListen func(network string, laddr string, config *tls.Config) (net.Listener, error) + + if s.DumpPackets { + listen = (&packetdumper.Listen{ + Prefix: s.DumpPacketsPrefix, + }).Do + + tlsListen = (&packetdumper.TLSListen{ + Listen: listen, + }).Do + } else { + listen = net.Listen + tlsListen = tls.Listen + } + + if tlsConfig != nil { + var err error + s.ln, err = tlsListen(network, address, tlsConfig) + if err != nil { + return err + } + } else { + var err error + s.ln, err = listen(network, address) + if err != nil { + return err + } + } + + if network == "unix" { + os.Chmod(address, 0o755) //nolint:errcheck + } + + go s.inner.Serve(s.ln) + return nil } diff --git a/internal/protocols/httpp/server_test.go b/internal/protocols/httpp/server_test.go index 6359e1f3..3c71583e 100644 --- a/internal/protocols/httpp/server_test.go +++ b/internal/protocols/httpp/server_test.go @@ -26,8 +26,9 @@ func TestUnixSocket(t *testing.T) { err := s.Initialize() require.NoError(t, err) - _, err = os.Stat("http.sock") + info, err := os.Stat("http.sock") require.NoError(t, err) + require.Equal(t, os.FileMode(0o755), info.Mode().Perm()) conn, err := net.Dial("unix", "http.sock") require.NoError(t, err) diff --git a/internal/protocols/tls/make_config.go b/internal/protocols/tls/make_config.go index 7dfa16b8..47458354 100644 --- a/internal/protocols/tls/make_config.go +++ b/internal/protocols/tls/make_config.go @@ -9,15 +9,11 @@ import ( "strings" ) -// MakeConfig returns a tls.Config with: -// - server name indicator (SNI) support -// - fingerprint support -func MakeConfig(serverName string, fingerprint string) *tls.Config { - conf := &tls.Config{ - ServerName: serverName, - } - +// MakeConfig returns a tls.Config with fingerprint support. +func MakeConfig(fingerprint string) *tls.Config { if fingerprint != "" { + conf := &tls.Config{} + fingerprintLower := strings.ToLower(fingerprint) conf.InsecureSkipVerify = true conf.VerifyConnection = func(cs tls.ConnectionState) error { @@ -32,7 +28,9 @@ func MakeConfig(serverName string, fingerprint string) *tls.Config { return nil } + + return conf } - return conf + return nil } diff --git a/internal/protocols/tls/make_config_test.go b/internal/protocols/tls/make_config_test.go index bd45ebf8..facc5752 100644 --- a/internal/protocols/tls/make_config_test.go +++ b/internal/protocols/tls/make_config_test.go @@ -60,44 +60,6 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD -----END RSA PRIVATE KEY----- `) -func TestMakeConfigSNI(t *testing.T) { - l, err := net.Listen("tcp", "localhost:8556") - require.NoError(t, err) - defer l.Close() - - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - - go func() { - defer close(serverDone) - - nconn, err2 := l.Accept() - require.NoError(t, err2) - defer nconn.Close() - - cert, err2 := tls.X509KeyPair(testTLSCertPub, testTLSCertKey) - require.NoError(t, err2) - - tnconn := tls.Server(nconn, &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: true, - VerifyConnection: func(cs tls.ConnectionState) error { - require.Equal(t, "myhost", cs.ServerName) - return nil - }, - }) - - err2 = tnconn.Handshake() - require.EqualError(t, err2, "remote error: tls: bad certificate") - }() - - conf := MakeConfig("myhost", "") - - _, err = tls.Dial("tcp", "localhost:8556", conf) - require.EqualError(t, err, "tls: failed to verify certificate: x509: "+ - "certificate is not valid for any names, but wanted to match myhost") -} - func TestMakeConfigFingerprint(t *testing.T) { l, err := net.Listen("tcp", "localhost:8556") require.NoError(t, err) @@ -119,17 +81,13 @@ func TestMakeConfigFingerprint(t *testing.T) { tnconn := tls.Server(nconn, &tls.Config{ Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, - VerifyConnection: func(cs tls.ConnectionState) error { - require.Equal(t, "myhost", cs.ServerName) - return nil - }, }) err2 = tnconn.Handshake() require.NoError(t, err2) }() - conf := MakeConfig("myhost", "33949e05fffb5ff3e8aa16f8213a6251b4d9363804ba53233c4da9a46d6f2739") + conf := MakeConfig("33949e05fffb5ff3e8aa16f8213a6251b4d9363804ba53233c4da9a46d6f2739") conn, err := tls.Dial("tcp", "localhost:8556", conf) require.NoError(t, err) diff --git a/internal/protocols/unix/listener.go b/internal/protocols/unix/listener.go index 6963b901..02db5a20 100644 --- a/internal/protocols/unix/listener.go +++ b/internal/protocols/unix/listener.go @@ -11,7 +11,8 @@ import ( // Listener is a listener on a Unix socket. type Listener struct { - Path string + Path string + Listen func(network string, address string) (net.Listener, error) l net.Listener c net.Conn @@ -25,11 +26,14 @@ func (l *Listener) Initialize() error { if l.Path == "" { return fmt.Errorf("invalid unix path") } + if l.Listen == nil { + l.Listen = net.Listen + } os.Remove(l.Path) var err error - l.l, err = net.Listen("unix", l.Path) + l.l, err = l.Listen("unix", l.Path) if err != nil { return err } diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index 863e9cea..fbb0f914 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -58,11 +58,18 @@ func (s *httpServer) initialize() error { router.Use(s.onRequest) + var proto string + if s.encryption { + proto = "hlss" + } else { + proto = "hls" + } + s.inner = &httpp.Server{ Address: s.address, AllowOrigins: s.allowOrigins, DumpPackets: s.dumpPackets, - DumpPacketsPrefix: "hls_server_conn", + DumpPacketsPrefix: proto + "_server_conn", ReadTimeout: time.Duration(s.readTimeout), WriteTimeout: time.Duration(s.writeTimeout), Encryption: s.encryption, diff --git a/internal/servers/rtmp/conn.go b/internal/servers/rtmp/conn.go index 6fcd4e2a..e0b368f5 100644 --- a/internal/servers/rtmp/conn.go +++ b/internal/servers/rtmp/conn.go @@ -25,7 +25,7 @@ import ( type conn struct { parentCtx context.Context - isTLS bool + encryption bool rtspAddress string readTimeout conf.Duration writeTimeout conf.Duration @@ -297,7 +297,7 @@ func (c *conn) runPublish() error { func (c *conn) APIReaderDescribe() *defs.APIPathReader { return &defs.APIPathReader{ Type: func() defs.APIPathReaderType { - if c.isTLS { + if c.encryption { return defs.APIPathReaderTypeRTMPSConn } return defs.APIPathReaderTypeRTMPConn @@ -310,7 +310,7 @@ func (c *conn) APIReaderDescribe() *defs.APIPathReader { func (c *conn) APISourceDescribe() *defs.APIPathSource { return &defs.APIPathSource{ Type: func() defs.APIPathSourceType { - if c.isTLS { + if c.encryption { return defs.APIPathSourceTypeRTMPSConn } return defs.APIPathSourceTypeRTMPConn diff --git a/internal/servers/rtmp/server.go b/internal/servers/rtmp/server.go index 59e150a5..826437a5 100644 --- a/internal/servers/rtmp/server.go +++ b/internal/servers/rtmp/server.go @@ -77,7 +77,7 @@ type Server struct { DumpPackets bool ReadTimeout conf.Duration WriteTimeout conf.Duration - IsTLS bool + Encryption bool ServerCert string ServerKey string RTSPAddress string @@ -107,32 +107,51 @@ type Server struct { // Initialize initializes the server. func (s *Server) Initialize() error { - var err error - s.ln, err = net.Listen(restrictnetwork.Restrict("tcp", s.Address)) - if err != nil { - return err - } + var listen func(network string, address string) (net.Listener, error) + var tlsListen func(network string, laddr string, config *tls.Config) (net.Listener, error) if s.DumpPackets { - s.ln = &packetdumper.Listener{ - Prefix: "rtmp_server_conn", - Listener: s.ln, + var proto string + if s.Encryption { + proto = "rtmps" + } else { + proto = "rtmp" } + + listen = (&packetdumper.Listen{ + Prefix: proto + "_server_conn", + }).Do + + tlsListen = (&packetdumper.TLSListen{ + Listen: listen, + }).Do + } else { + listen = net.Listen + tlsListen = tls.Listen } - if s.IsTLS { + if s.Encryption { s.loader = &certloader.CertLoader{ CertPath: s.ServerCert, KeyPath: s.ServerKey, Parent: s.Parent, } - err = s.loader.Initialize() + err := s.loader.Initialize() if err != nil { - s.ln.Close() return err } - s.ln = tls.NewListener(s.ln, &tls.Config{GetCertificate: s.loader.GetCertificate()}) + net, addr := restrictnetwork.Restrict("tcp", s.Address) + s.ln, err = tlsListen(net, addr, &tls.Config{GetCertificate: s.loader.GetCertificate()}) + if err != nil { + return err + } + } else { + var err error + s.ln, err = listen(restrictnetwork.Restrict("tcp", s.Address)) + if err != nil { + return err + } } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) @@ -146,7 +165,7 @@ func (s *Server) Initialize() error { s.chAPIConnsKick = make(chan serverAPIConnsKickReq) str := "listener opened on " + s.Address - if s.IsTLS { + if s.Encryption { str += " (TCP/RTMPS)" } else { str += " (TCP/RTMP)" @@ -164,7 +183,7 @@ func (s *Server) Initialize() error { go s.run() if !interfaceIsEmpty(s.Metrics) { - if s.IsTLS { + if s.Encryption { s.Metrics.SetRTMPSServer(s) } else { s.Metrics.SetRTMPServer(s) @@ -177,7 +196,7 @@ func (s *Server) Initialize() error { // Log implements logger.Writer. func (s *Server) Log(level logger.Level, format string, args ...any) { label := func() string { - if s.IsTLS { + if s.Encryption { return "RTMPS" } return "RTMP" @@ -190,7 +209,7 @@ func (s *Server) Close() { s.Log(logger.Info, "listener is closing") if !interfaceIsEmpty((s.Metrics)) { - if s.IsTLS { + if s.Encryption { s.Metrics.SetRTMPSServer(nil) } else { s.Metrics.SetRTMPServer(nil) @@ -218,7 +237,7 @@ outer: case nconn := <-s.chNewConn: c := &conn{ parentCtx: s.ctx, - isTLS: s.IsTLS, + encryption: s.Encryption, rtspAddress: s.RTSPAddress, readTimeout: s.ReadTimeout, writeTimeout: s.WriteTimeout, diff --git a/internal/servers/rtmp/server_test.go b/internal/servers/rtmp/server_test.go index d3f6a0fa..9b14257f 100644 --- a/internal/servers/rtmp/server_test.go +++ b/internal/servers/rtmp/server_test.go @@ -130,7 +130,7 @@ func TestServerPublish(t *testing.T) { Address: "127.0.0.1:1939", ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), - IsTLS: encrypt == "tls", + Encryption: encrypt == "tls", ServerCert: serverCertFpath, ServerKey: serverKeyFpath, RTSPAddress: "", @@ -266,7 +266,7 @@ func TestServerRead(t *testing.T) { Address: "127.0.0.1:1939", ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), - IsTLS: encrypt == "tls", + Encryption: encrypt == "tls", ServerCert: serverCertFpath, ServerKey: serverKeyFpath, RTSPAddress: "", diff --git a/internal/servers/rtsp/conn.go b/internal/servers/rtsp/conn.go index b5bf14b4..e4c48520 100644 --- a/internal/servers/rtsp/conn.go +++ b/internal/servers/rtsp/conn.go @@ -54,7 +54,7 @@ type connParent interface { } type conn struct { - isTLS bool + encryption bool rtspAddress string authMethods []rtspauth.VerifyMethod readTimeout conf.Duration @@ -87,7 +87,7 @@ func (c *conn) initialize() { RTSPAddress: c.rtspAddress, Desc: defs.APIPathReader{ Type: func() defs.APIPathReaderType { - if c.isTLS { + if c.encryption { return defs.APIPathReaderTypeRTSPSConn } return defs.APIPathReaderTypeRTSPConn @@ -192,7 +192,7 @@ func (c *conn) onDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx, } var strm *gortsplib.ServerStream - if !c.isTLS { + if !c.encryption { strm = res.Stream.RTSPStream(c.rserver) } else { strm = res.Stream.RTSPSStream(c.rserver) diff --git a/internal/servers/rtsp/server.go b/internal/servers/rtsp/server.go index 58b89c6a..8e026e0a 100644 --- a/internal/servers/rtsp/server.go +++ b/internal/servers/rtsp/server.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "errors" "fmt" - "net" "reflect" "sort" "strings" @@ -102,7 +101,7 @@ type Server struct { MulticastIPRange string MulticastRTPPort int MulticastRTCPPort int - IsTLS bool + Encryption bool ServerCert string ServerKey string RTSPAddress string @@ -153,7 +152,7 @@ func (s *Server) Initialize() error { s.srv.MulticastRTCPPort = s.MulticastRTCPPort } - if s.IsTLS { + if s.Encryption { s.loader = &certloader.CertLoader{ CertPath: s.ServerCert, KeyPath: s.ServerKey, @@ -168,14 +167,23 @@ func (s *Server) Initialize() error { } if s.DumpPackets { + var proto string + if s.Encryption { + proto = "rtsps" + } else { + proto = "rtsp" + } + s.srv.Listen = (&packetdumper.Listen{ - Prefix: "rtsp_server_conn", - Listen: net.Listen, + Prefix: proto + "_server_conn", }).Do s.srv.ListenPacket = (&packetdumper.ListenPacket{ - Prefix: "rtsp_server_packetconn", - ListenPacket: net.ListenPacket, + Prefix: proto + "_server_packet_conn", + }).Do + + s.srv.TLSListen = (&packetdumper.TLSListen{ + Listen: s.srv.Listen, }).Do } @@ -190,7 +198,7 @@ func (s *Server) Initialize() error { go s.run() if !interfaceIsEmpty(s.Metrics) { - if s.IsTLS { + if s.Encryption { s.Metrics.SetRTSPSServer(s) } else { s.Metrics.SetRTSPServer(s) @@ -203,7 +211,7 @@ func (s *Server) Initialize() error { // Log implements logger.Writer. func (s *Server) Log(level logger.Level, format string, args ...any) { label := func() string { - if s.IsTLS { + if s.Encryption { return "RTSPS" } return "RTSP" @@ -216,7 +224,7 @@ func (s *Server) Close() { s.Log(logger.Info, "listener is closing") if !interfaceIsEmpty(s.Metrics) { - if s.IsTLS { + if s.Encryption { s.Metrics.SetRTSPSServer(nil) } else { s.Metrics.SetRTSPServer(nil) @@ -257,7 +265,7 @@ outer: // OnConnOpen implements gortsplib.ServerHandlerOnConnOpen. func (s *Server) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) { c := &conn{ - isTLS: s.IsTLS, + encryption: s.Encryption, rtspAddress: s.RTSPAddress, authMethods: s.AuthMethods, readTimeout: s.ReadTimeout, @@ -302,7 +310,7 @@ func (s *Server) OnResponse(sc *gortsplib.ServerConn, res *base.Response) { // OnSessionOpen implements gortsplib.ServerHandlerOnSessionOpen. func (s *Server) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) { se := &session{ - isTLS: s.IsTLS, + encryption: s.Encryption, transports: s.Transports, rsession: ctx.Session, rconn: ctx.Conn, diff --git a/internal/servers/rtsp/session.go b/internal/servers/rtsp/session.go index 8212e9ce..e3bd488b 100644 --- a/internal/servers/rtsp/session.go +++ b/internal/servers/rtsp/session.go @@ -58,7 +58,7 @@ type sessionParent interface { } type session struct { - isTLS bool + encryption bool transports conf.RTSPTransports rsession *gortsplib.ServerSession rconn *gortsplib.ServerConn @@ -229,7 +229,7 @@ func (s *session) onAnnounce(c *conn, ctx *gortsplib.ServerHandlerOnAnnounceCtx) } func (s *session) rtspStream() *gortsplib.ServerStream { - if !s.isTLS { + if !s.encryption { return s.stream.RTSPStream(s.rserver) } return s.stream.RTSPSStream(s.rserver) @@ -439,7 +439,7 @@ func (s *session) onPause(_ *gortsplib.ServerHandlerOnPauseCtx) (*base.Response, func (s *session) APIReaderDescribe() *defs.APIPathReader { return &defs.APIPathReader{ Type: func() defs.APIPathReaderType { - if s.isTLS { + if s.encryption { return defs.APIPathReaderTypeRTSPSSession } return defs.APIPathReaderTypeRTSPSession @@ -452,7 +452,7 @@ func (s *session) APIReaderDescribe() *defs.APIPathReader { func (s *session) APISourceDescribe() *defs.APIPathSource { return &defs.APIPathSource{ Type: func() defs.APIPathSourceType { - if s.isTLS { + if s.encryption { return defs.APIPathSourceTypeRTSPSSession } return defs.APIPathSourceTypeRTSPSession diff --git a/internal/servers/webrtc/http_server.go b/internal/servers/webrtc/http_server.go index 0adc4692..dc037713 100644 --- a/internal/servers/webrtc/http_server.go +++ b/internal/servers/webrtc/http_server.go @@ -96,11 +96,18 @@ func (s *httpServer) initialize() error { router.Use(s.onRequest) + var proto string + if s.encryption { + proto = "webrtcs" + } else { + proto = "webrtc" + } + s.inner = &httpp.Server{ Address: s.address, AllowOrigins: s.allowOrigins, DumpPackets: s.dumpPackets, - DumpPacketsPrefix: "webrtc_server_conn", + DumpPacketsPrefix: proto + "_server_conn", ReadTimeout: time.Duration(s.readTimeout), WriteTimeout: time.Duration(s.writeTimeout), Encryption: s.encryption, diff --git a/internal/staticsources/hls/source.go b/internal/staticsources/hls/source.go index 01439ef0..663b801b 100644 --- a/internal/staticsources/hls/source.go +++ b/internal/staticsources/hls/source.go @@ -2,10 +2,9 @@ package hls import ( - "net" "net/http" "net/http/cookiejar" - "net/url" + "strings" "time" "github.com/bluenviron/gohlslib/v2" @@ -17,7 +16,7 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/protocols/hls" - "github.com/bluenviron/mediamtx/internal/protocols/tls" + ptls "github.com/bluenviron/mediamtx/internal/protocols/tls" "github.com/bluenviron/mediamtx/internal/stream" ) @@ -62,25 +61,30 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { decodeErrors.Start() defer decodeErrors.Stop() - u, err := url.Parse(params.ResolvedSource) - if err != nil { - return err - } + tr := &http.Transport{} + defer tr.CloseIdleConnections() - dialContext := (&net.Dialer{}).DialContext + tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint) if s.DumpPackets { - dialContext = (&packetdumper.DialContext{ - Prefix: "hls_source_conn", - DialContext: dialContext, - }).Do - } + var proto string + if strings.HasPrefix(params.ResolvedSource, "https") { + proto = "hlss" + } else { + proto = "hls" + } - tr := &http.Transport{ - DialContext: dialContext, - TLSClientConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint), + tr.DialContext = (&packetdumper.DialContext{ + Prefix: proto + "_source_conn", + }).Do + + tr.DialTLSContext = (&packetdumper.DialTLSContext{ + DialContext: tr.DialContext, + TLSConfig: tlsConfig, + }).Do + } else { + tr.TLSClientConfig = tlsConfig } - defer tr.CloseIdleConnections() jar, _ := cookiejar.New(nil) @@ -128,7 +132,7 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { }, } - err = c.Start() + err := c.Start() if err != nil { return err } diff --git a/internal/staticsources/mpegts/source.go b/internal/staticsources/mpegts/source.go index 62f62e40..a0fef765 100644 --- a/internal/staticsources/mpegts/source.go +++ b/internal/staticsources/mpegts/source.go @@ -56,6 +56,13 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { l := &unix.Listener{ Path: params.Path, } + + if s.DumpPackets { + l.Listen = (&packetdumper.Listen{ + Prefix: "mpegts_source_unix_conn", + }).Do + } + err = l.Initialize() if err != nil { return err @@ -68,36 +75,20 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { udpReadBufferSize = *params.Conf.MPEGTSUDPReadBufferSize } - listenPacket := net.ListenPacket - - if s.DumpPackets { - listenPacket = func(network, address string) (net.PacketConn, error) { - pc, err2 := net.ListenPacket(network, address) - if err2 != nil { - return nil, err2 - } - - d := &packetdumper.PacketConn{ - Prefix: "mpegts_source_packetconn", - PacketConn: pc, - } - err2 = d.Initialize() - if err2 != nil { - return nil, err2 - } - - return d, nil - } - } - params := udp.URLToParams(u) l := &udp.Listener{ Address: params.Address, Source: params.Source, IntfName: params.IntfName, UDPReadBufferSize: int(udpReadBufferSize), - ListenPacket: listenPacket, } + + if s.DumpPackets { + l.ListenPacket = (&packetdumper.ListenPacket{ + Prefix: "mpegts_source_packet_conn", + }).Do + } + err = l.Initialize() if err != nil { return err diff --git a/internal/staticsources/rtmp/source.go b/internal/staticsources/rtmp/source.go index 6c44bfdd..5e4060d9 100644 --- a/internal/staticsources/rtmp/source.go +++ b/internal/staticsources/rtmp/source.go @@ -16,7 +16,7 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/protocols/rtmp" - "github.com/bluenviron/mediamtx/internal/protocols/tls" + ptls "github.com/bluenviron/mediamtx/internal/protocols/tls" "github.com/bluenviron/mediamtx/internal/stream" ) @@ -58,23 +58,28 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { } } - dialContext := (&net.Dialer{}).DialContext - - if s.DumpPackets { - dialContext = (&packetdumper.DialContext{ - Prefix: "rtmp_source_conn", - DialContext: dialContext, - }).Do - } - connectCtx, connectCtxCancel := context.WithTimeout(params.Context, time.Duration(s.ReadTimeout)) conn := &gortmplib.Client{ - URL: u, - TLSConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint), - Publish: false, - DialContext: dialContext, + URL: u, + Publish: false, } + + tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint) + + if s.DumpPackets { + conn.DialContext = (&packetdumper.DialContext{ + Prefix: u.Scheme + "_source_conn", + }).Do + + conn.DialTLSContext = (&packetdumper.DialTLSContext{ + DialContext: conn.DialContext, + TLSConfig: tlsConfig, + }).Do + } else { + conn.TLSConfig = tlsConfig + } + err = conn.Initialize(connectCtx) connectCtxCancel() if err != nil { diff --git a/internal/staticsources/rtp/source.go b/internal/staticsources/rtp/source.go index 5b759570..95bf9bca 100644 --- a/internal/staticsources/rtp/source.go +++ b/internal/staticsources/rtp/source.go @@ -66,10 +66,11 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { var nc net.Conn switch u.Scheme { - case "unix+rtp": + case "unix+rtp": // deprecated params := unix.URLToParams(u) l := &unix.Listener{ - Path: params.Path, + Path: params.Path, + Listen: net.Listen, } err = l.Initialize() if err != nil { @@ -83,36 +84,20 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { udpReadBufferSize = *params.Conf.RTPUDPReadBufferSize } - listenPacket := net.ListenPacket - - if s.DumpPackets { - listenPacket = func(network, address string) (net.PacketConn, error) { - pc, err2 := net.ListenPacket(network, address) - if err2 != nil { - return nil, err2 - } - - d := &packetdumper.PacketConn{ - Prefix: "rtp_source_packetconn", - PacketConn: pc, - } - err2 = d.Initialize() - if err2 != nil { - return nil, err2 - } - - return d, nil - } - } - params := udp.URLToParams(u) l := &udp.Listener{ Address: params.Address, Source: params.Source, IntfName: params.IntfName, UDPReadBufferSize: int(udpReadBufferSize), - ListenPacket: listenPacket, } + + if s.DumpPackets { + l.ListenPacket = (&packetdumper.ListenPacket{ + Prefix: "rtp_source_packet_conn", + }).Do + } + err = l.Initialize() if err != nil { return err diff --git a/internal/staticsources/rtsp/source.go b/internal/staticsources/rtsp/source.go index e9147203..27f72add 100644 --- a/internal/staticsources/rtsp/source.go +++ b/internal/staticsources/rtsp/source.go @@ -3,7 +3,6 @@ package rtsp import ( "fmt" - "net" "net/url" "time" @@ -19,7 +18,7 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/protocols/rtsp" - "github.com/bluenviron/mediamtx/internal/protocols/tls" + ptls "github.com/bluenviron/mediamtx/internal/protocols/tls" "github.com/bluenviron/mediamtx/internal/stream" ) @@ -122,6 +121,11 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { decodeErrors.Start() defer decodeErrors.Stop() + u0, err := url.Parse(params.ResolvedSource) + if err != nil { + return err + } + c := &gortsplib.Client{ Protocol: params.Conf.RTSPTransport.Protocol, ReadTimeout: time.Duration(s.ReadTimeout), @@ -150,11 +154,6 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { }, } - u0, err := url.Parse(params.ResolvedSource) - if err != nil { - return err - } - switch u0.Scheme { case "rtsp+http", "rtsps+http": c.Tunnel = gortsplib.TunnelHTTP @@ -167,7 +166,6 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { u0.Scheme = "rtsp" default: u0.Scheme = "rtsps" - c.TLSConfig = tls.MakeConfig(u0.Hostname(), params.Conf.SourceFingerprint) } u, err := base.ParseURL(u0.String()) @@ -182,16 +180,23 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { s.UDPReadBufferSize = *params.Conf.RTSPUDPReadBufferSize } + tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint) + if s.DumpPackets { c.DialContext = (&packetdumper.DialContext{ - Prefix: "rtsp_source_conn", - DialContext: (&net.Dialer{}).DialContext, + Prefix: u.Scheme + "_source_conn", }).Do c.ListenPacket = (&packetdumper.ListenPacket{ - Prefix: "rtsp_source_packetconn", - ListenPacket: net.ListenPacket, + Prefix: u.Scheme + "_source_packet_conn", }).Do + + c.DialTLSContext = (&packetdumper.DialTLSContext{ + DialContext: c.DialContext, + TLSConfig: tlsConfig, + }).Do + } else { + c.TLSConfig = tlsConfig } err = c.Start() diff --git a/internal/staticsources/webrtc/source.go b/internal/staticsources/webrtc/source.go index f6f56e07..855ebb01 100644 --- a/internal/staticsources/webrtc/source.go +++ b/internal/staticsources/webrtc/source.go @@ -3,7 +3,6 @@ package webrtc import ( "fmt" - "net" "net/http" "net/url" "strings" @@ -15,7 +14,7 @@ import ( "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" - "github.com/bluenviron/mediamtx/internal/protocols/tls" + ptls "github.com/bluenviron/mediamtx/internal/protocols/tls" "github.com/bluenviron/mediamtx/internal/protocols/webrtc" "github.com/bluenviron/mediamtx/internal/protocols/whip" "github.com/bluenviron/mediamtx/internal/stream" @@ -49,22 +48,25 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { return err } - u.Scheme = strings.ReplaceAll(u.Scheme, "whep", "http") + tr := &http.Transport{} + defer tr.CloseIdleConnections() - dialContext := (&net.Dialer{}).DialContext + tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint) if s.DumpPackets { - dialContext = (&packetdumper.DialContext{ - Prefix: "webrtc_source_conn", - DialContext: dialContext, + tr.DialContext = (&packetdumper.DialContext{ + Prefix: u.Scheme + "_source_conn", }).Do + + tr.DialTLSContext = (&packetdumper.DialTLSContext{ + DialContext: tr.DialContext, + TLSConfig: tlsConfig, + }).Do + } else { + tr.TLSClientConfig = tlsConfig } - tr := &http.Transport{ - DialContext: dialContext, - TLSClientConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint), - } - defer tr.CloseIdleConnections() + u.Scheme = strings.ReplaceAll(u.Scheme, "whep", "http") client := whip.Client{ URL: u,