dump unencrypted TLS sessions (#5624)

when dumpPackets is true, embed TLS master keys into the dump, in a
format which is natively compatible with Wireshark.
This commit is contained in:
Alessandro Ros
2026-04-04 14:46:43 +02:00
committed by GitHub
parent f52a63858c
commit d4c6f95291
34 changed files with 544 additions and 359 deletions
+6
View File
@@ -78,6 +78,12 @@ linters:
- stringsbuilder - stringsbuilder
- testingcontext - testingcontext
exclusions:
rules:
- linters:
- lll
source: "^\\s*// https?://"
formatters: formatters:
enable: enable:
- gofmt - gofmt
+1 -1
View File
@@ -39,6 +39,7 @@ require (
github.com/pion/webrtc/v4 v4.2.11 github.com/pion/webrtc/v4 v4.2.11
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.49.0 golang.org/x/crypto v0.49.0
golang.org/x/net v0.52.0
golang.org/x/sys v0.42.0 golang.org/x/sys v0.42.0
golang.org/x/term v0.41.0 golang.org/x/term v0.41.0
) )
@@ -96,7 +97,6 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect
golang.org/x/time v0.12.0 // indirect golang.org/x/time v0.12.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
+2 -12
View File
@@ -185,13 +185,8 @@ func (m *Manager) authenticateHTTP(req *Request) (string, error) {
Query: req.Query, Query: req.Query,
}) })
u, err := url.Parse(m.HTTPAddress)
if err != nil {
return "", err
}
tr := &http.Transport{ tr := &http.Transport{
TLSClientConfig: tls.MakeConfig(u.Hostname(), m.HTTPFingerprint), TLSClientConfig: tls.MakeConfig(m.HTTPFingerprint),
} }
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
@@ -283,13 +278,8 @@ func (m *Manager) pullJWTJWKS() (jwt.Keyfunc, error) {
defer m.mutex.Unlock() defer m.mutex.Unlock()
if now.Sub(m.jwksLastRefresh) >= jwksRefreshPeriod { if now.Sub(m.jwksLastRefresh) >= jwksRefreshPeriod {
u, err := url.Parse(m.JWTJWKS)
if err != nil {
return nil, err
}
tr := &http.Transport{ tr := &http.Transport{
TLSClientConfig: tls.MakeConfig(u.Hostname(), m.JWTJWKSFingerprint), TLSClientConfig: tls.MakeConfig(m.JWTJWKSFingerprint),
} }
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
+4 -4
View File
@@ -470,7 +470,7 @@ func (p *Core) createResources(initial bool) error {
MulticastIPRange: p.conf.MulticastIPRange, MulticastIPRange: p.conf.MulticastIPRange,
MulticastRTPPort: p.conf.MulticastRTPPort, MulticastRTPPort: p.conf.MulticastRTPPort,
MulticastRTCPPort: p.conf.MulticastRTCPPort, MulticastRTCPPort: p.conf.MulticastRTCPPort,
IsTLS: false, Encryption: false,
ServerCert: "", ServerCert: "",
ServerKey: "", ServerKey: "",
RTSPAddress: p.conf.RTSPAddress, RTSPAddress: p.conf.RTSPAddress,
@@ -513,7 +513,7 @@ func (p *Core) createResources(initial bool) error {
MulticastIPRange: p.conf.MulticastIPRange, MulticastIPRange: p.conf.MulticastIPRange,
MulticastRTPPort: p.conf.MulticastSRTPPort, MulticastRTPPort: p.conf.MulticastSRTPPort,
MulticastRTCPPort: p.conf.MulticastSRTCPPort, MulticastRTCPPort: p.conf.MulticastSRTCPPort,
IsTLS: true, Encryption: true,
ServerCert: p.conf.RTSPServerCert, ServerCert: p.conf.RTSPServerCert,
ServerKey: p.conf.RTSPServerKey, ServerKey: p.conf.RTSPServerKey,
RTSPAddress: p.conf.RTSPAddress, RTSPAddress: p.conf.RTSPAddress,
@@ -542,7 +542,7 @@ func (p *Core) createResources(initial bool) error {
DumpPackets: p.conf.DumpPackets, DumpPackets: p.conf.DumpPackets,
ReadTimeout: p.conf.ReadTimeout, ReadTimeout: p.conf.ReadTimeout,
WriteTimeout: p.conf.WriteTimeout, WriteTimeout: p.conf.WriteTimeout,
IsTLS: false, Encryption: false,
ServerCert: "", ServerCert: "",
ServerKey: "", ServerKey: "",
RTSPAddress: p.conf.RTSPAddress, RTSPAddress: p.conf.RTSPAddress,
@@ -569,7 +569,7 @@ func (p *Core) createResources(initial bool) error {
Address: p.conf.RTMPSAddress, Address: p.conf.RTMPSAddress,
ReadTimeout: p.conf.ReadTimeout, ReadTimeout: p.conf.ReadTimeout,
WriteTimeout: p.conf.WriteTimeout, WriteTimeout: p.conf.WriteTimeout,
IsTLS: true, Encryption: true,
ServerCert: p.conf.RTMPServerCert, ServerCert: p.conf.RTMPServerCert,
ServerKey: p.conf.RTMPServerKey, ServerKey: p.conf.RTMPServerKey,
DumpPackets: p.conf.DumpPackets, DumpPackets: p.conf.DumpPackets,
+116 -56
View File
@@ -2,7 +2,9 @@
package packetdumper package packetdumper
import ( import (
"encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
@@ -14,39 +16,79 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
var _ net.Conn = (*Conn)(nil) var _ net.Conn = (*conn)(nil)
type direction int type direction int
const ( const (
dirRead direction = iota dirInbound direction = iota
dirWrite dirOutbound
dirHandshake dirHandshake
dirSecret
) )
func writeDecryptionSecretsBlock(f io.Writer, data []byte) {
const (
dsbSecretTypeTLS = 0x544c534b
blockType = 0x0000000A
fixedHeaderBytes = 4 + 4 + 4 + 4 // type + totalLen + secretsType + secretsLen
trailerBytes = 4 // repeated totalLen
overheadBytes = fixedHeaderBytes + trailerBytes
)
secretsLen := len(data)
paddedLen := (secretsLen + 3) &^ 3
padBytes := paddedLen - secretsLen
totalLen := uint32(overheadBytes + paddedLen)
buf := make([]byte, totalLen)
pos := 0
binary.LittleEndian.PutUint32(buf[pos:], blockType)
pos += 4
binary.LittleEndian.PutUint32(buf[pos:], totalLen)
pos += 4
binary.LittleEndian.PutUint32(buf[pos:], dsbSecretTypeTLS)
pos += 4
binary.LittleEndian.PutUint32(buf[pos:], uint32(secretsLen)) // unpadded length, per spec
pos += 4
pos += copy(buf[pos:], data)
pos += padBytes // zero padding already present (make zeroes)
binary.LittleEndian.PutUint32(buf[pos:], totalLen) // trailing repeat
f.Write(buf) //nolint:errcheck
}
type dumpEntry struct { type dumpEntry struct {
ntp time.Time ntp time.Time
data []byte data []byte
direction direction direction direction
} }
// Conn is a wrapper around net.Conn that dumps packets to disk. // conn is a wrapper around net.Conn that dumps packets to disk.
type Conn struct { type conn struct {
Prefix string Prefix string
Conn net.Conn Conn net.Conn
ServerSide bool ServerSide bool
f *os.File expectingSecrets int
pw *pcapgo.NgWriter f *os.File
once sync.Once pw *pcapgo.NgWriter
once sync.Once
local *net.TCPAddr
remote *net.TCPAddr
nextLocalSequence uint32
nextRemoteSequence uint32
delayed []dumpEntry
queue chan dumpEntry queue chan dumpEntry
terminated chan struct{} terminated chan struct{}
done chan struct{} done chan struct{}
} }
// Initialize initializes Conn. // Initialize initializes conn.
func (c *Conn) Initialize() error { func (c *conn) Initialize() error {
var err error var err error
c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String())) c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String()))
if err != nil { if err != nil {
@@ -59,6 +101,17 @@ func (c *Conn) Initialize() error {
return err return err
} }
c.local = c.Conn.LocalAddr().(*net.TCPAddr)
c.remote = c.Conn.RemoteAddr().(*net.TCPAddr)
if c.ServerSide {
c.nextLocalSequence = uint32(2000)
c.nextRemoteSequence = uint32(1000)
} else {
c.nextLocalSequence = uint32(1000)
c.nextRemoteSequence = uint32(2000)
}
c.queue = make(chan dumpEntry, 64) c.queue = make(chan dumpEntry, 64)
c.terminated = make(chan struct{}) c.terminated = make(chan struct{})
c.done = make(chan struct{}) c.done = make(chan struct{})
@@ -71,7 +124,7 @@ func (c *Conn) Initialize() error {
} }
// Close implements net.Conn. // Close implements net.Conn.
func (c *Conn) Close() error { func (c *conn) Close() error {
c.once.Do(func() { c.once.Do(func() {
close(c.terminated) close(c.terminated)
}) })
@@ -79,29 +132,23 @@ func (c *Conn) Close() error {
return c.Conn.Close() return c.Conn.Close()
} }
func (c *Conn) run() { func (c *conn) run() {
defer close(c.done) defer close(c.done)
defer c.f.Close() defer c.f.Close()
defer c.pw.Flush() //nolint:errcheck
local := c.Conn.LocalAddr().(*net.TCPAddr)
remote := c.Conn.RemoteAddr().(*net.TCPAddr)
nextLocalSequence := uint32(1000)
nextRemoteSequence := uint32(2000)
for { for {
select { select {
case e := <-c.queue: case e := <-c.queue:
c.processEntry(e, local, remote, &nextLocalSequence, &nextRemoteSequence) c.processEntry(e)
case <-c.terminated: case <-c.terminated:
// Drain anything already in the queue before exiting. // Drain anything already in the queue before exiting.
for { for {
select { select {
case e := <-c.queue: case e := <-c.queue:
c.processEntry(e, local, remote, &nextLocalSequence, &nextRemoteSequence) c.processEntry(e)
default: default:
c.pw.Flush() //nolint:errcheck
return return
} }
} }
@@ -109,18 +156,19 @@ func (c *Conn) run() {
} }
} }
func (c *Conn) processEntry( func (c *conn) processEntry(e dumpEntry) {
e dumpEntry, if c.expectingSecrets > 0 && e.direction != dirSecret {
local, remote *net.TCPAddr, c.delayed = append(c.delayed, e)
nextLocalSequence, nextRemoteSequence *uint32, return
) { }
switch e.direction { switch e.direction {
case dirHandshake: case dirHandshake:
clientAddr, serverAddr := local, remote // client side: local initiates clientAddr, serverAddr := c.local, c.remote // client side: local initiates
clientSeq, serverSeq := nextLocalSequence, nextRemoteSequence clientSeq, serverSeq := &c.nextLocalSequence, &c.nextRemoteSequence
if c.ServerSide { if c.ServerSide {
clientAddr, serverAddr = remote, local // server side: remote initiated clientAddr, serverAddr = c.remote, c.local // server side: remote initiates
clientSeq, serverSeq = nextRemoteSequence, nextLocalSequence clientSeq, serverSeq = &c.nextRemoteSequence, &c.nextLocalSequence
} }
// SYN (client -> server) // SYN (client -> server)
@@ -137,31 +185,43 @@ func (c *Conn) processEntry(
c.writePacket(e.ntp, clientAddr, serverAddr, c.writePacket(e.ntp, clientAddr, serverAddr,
layers.TCP{ACK: true, Window: 65535, Seq: *clientSeq, Ack: *serverSeq}, nil) layers.TCP{ACK: true, Window: 65535, Seq: *clientSeq, Ack: *serverSeq}, nil)
case dirRead: case dirSecret:
tcpFlags := layers.TCP{ c.pw.Flush() //nolint:errcheck
PSH: true, writeDecryptionSecretsBlock(c.f, e.data)
ACK: true,
Window: 14600,
Seq: *nextRemoteSequence,
Ack: *nextLocalSequence,
}
c.writePacket(e.ntp, remote, local, tcpFlags, e.data)
*nextRemoteSequence += uint32(len(e.data))
case dirWrite: c.expectingSecrets--
if c.expectingSecrets == 0 {
for _, e2 := range c.delayed {
c.processEntry(e2)
}
c.delayed = nil
}
case dirInbound:
tcpFlags := layers.TCP{ tcpFlags := layers.TCP{
PSH: true, PSH: true,
ACK: true, ACK: true,
Window: 14600, Window: 14600,
Seq: *nextLocalSequence, Seq: c.nextRemoteSequence,
Ack: *nextRemoteSequence, Ack: c.nextLocalSequence,
} }
c.writePacket(e.ntp, local, remote, tcpFlags, e.data) c.writePacket(e.ntp, c.remote, c.local, tcpFlags, e.data)
*nextLocalSequence += uint32(len(e.data)) c.nextRemoteSequence += uint32(len(e.data))
case dirOutbound:
tcpFlags := layers.TCP{
PSH: true,
ACK: true,
Window: 14600,
Seq: c.nextLocalSequence,
Ack: c.nextRemoteSequence,
}
c.writePacket(e.ntp, c.local, c.remote, tcpFlags, e.data)
c.nextLocalSequence += uint32(len(e.data))
} }
} }
func (c *Conn) writePacket( func (c *conn) writePacket(
ntp time.Time, ntp time.Time,
src, dst *net.TCPAddr, src, dst *net.TCPAddr,
tcpFlags layers.TCP, tcpFlags layers.TCP,
@@ -207,35 +267,35 @@ func (c *Conn) writePacket(
}, raw) }, raw)
} }
func (c *Conn) enqueue(e dumpEntry) { func (c *conn) enqueue(e dumpEntry) {
select { select {
case c.queue <- e: case c.queue <- e:
case <-c.terminated: case <-c.terminated:
} }
} }
func (c *Conn) Read(p []byte) (n int, err error) { func (c *conn) Read(p []byte) (n int, err error) {
n, err = c.Conn.Read(p) n, err = c.Conn.Read(p)
if n != 0 { if n != 0 {
c.enqueue(dumpEntry{ c.enqueue(dumpEntry{
ntp: time.Now(), ntp: time.Now(),
data: append([]byte(nil), p[:n]...), data: append([]byte(nil), p[:n]...),
direction: dirRead, direction: dirInbound,
}) })
} }
return n, err return n, err
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *conn) Write(p []byte) (n int, err error) {
n, err = c.Conn.Write(p) n, err = c.Conn.Write(p)
if err == nil { if err == nil {
c.enqueue(dumpEntry{ c.enqueue(dumpEntry{
ntp: time.Now(), ntp: time.Now(),
data: append([]byte(nil), p...), data: append([]byte(nil), p...),
direction: dirWrite, direction: dirOutbound,
}) })
} }
@@ -243,16 +303,16 @@ func (c *Conn) Write(p []byte) (n int, err error) {
} }
// LocalAddr implements net.Conn. // LocalAddr implements net.Conn.
func (c *Conn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } func (c *conn) LocalAddr() net.Addr { return c.Conn.LocalAddr() }
// RemoteAddr implements net.Conn. // RemoteAddr implements net.Conn.
func (c *Conn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() } func (c *conn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() }
// SetDeadline implements net.Conn. // SetDeadline implements net.Conn.
func (c *Conn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) } func (c *conn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) }
// SetReadDeadline implements net.Conn. // SetReadDeadline implements net.Conn.
func (c *Conn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) } func (c *conn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) }
// SetWriteDeadline implements net.Conn. // SetWriteDeadline implements net.Conn.
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } func (c *conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) }
+7 -7
View File
@@ -57,7 +57,7 @@ func TestConnInitialize_CreatesFile(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -69,7 +69,7 @@ func TestConnWrite(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -90,7 +90,7 @@ func TestConnRead(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -110,7 +110,7 @@ func TestConnServerSide(t *testing.T) {
defer client.Close() defer client.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: server, ServerSide: true} c := &conn{Prefix: prefix, Conn: server, ServerSide: true}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -131,7 +131,7 @@ func TestConnMultipleWriteRead(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -164,7 +164,7 @@ func TestConnCloseIdempotent(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
@@ -178,7 +178,7 @@ func TestConnDelegatesAddrMethods(t *testing.T) {
defer server.Close() defer server.Close()
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &Conn{Prefix: prefix, Conn: client} c := &conn{Prefix: prefix, Conn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapng(t, prefix) defer cleanupPcapng(t, prefix)
+7 -8
View File
@@ -7,26 +7,25 @@ import (
// DialContext is a wrapper around net.Dialer.DialContext that dumps packets to disk. // DialContext is a wrapper around net.Dialer.DialContext that dumps packets to disk.
type DialContext struct { type DialContext struct {
Prefix string Prefix string
DialContext func(ctx context.Context, network, address string) (net.Conn, error)
} }
// Do mimics net.Dialer.DialContext. // Do mimics net.Dialer.DialContext.
func (d *DialContext) Do(ctx context.Context, network, address string) (net.Conn, error) { func (d *DialContext) Do(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.DialContext(ctx, network, address) netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := &Conn{ pdConn := &conn{
Prefix: d.Prefix, Prefix: d.Prefix,
Conn: conn, Conn: netConn,
} }
err = c.Initialize() err = pdConn.Initialize()
if err != nil { if err != nil {
conn.Close() netConn.Close()
return nil, err return nil, err
} }
return c, nil return pdConn, nil
} }
+43
View File
@@ -0,0 +1,43 @@
package packetdumper
import (
"context"
"crypto/tls"
"net"
)
// DialTLSContext provides the DialTLSContext function.
type DialTLSContext struct {
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
TLSConfig *tls.Config
}
// Do provides DialTLSContext.
func (t *DialTLSContext) Do(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := t.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
// clone TLS config and fill ServerName if empty.
// this is the same behavior of http.Client.
// https://cs.opensource.google/go/go/+/master:src/net/http/transport.go;l=1754;drc=a4b534f5e42fe58d58c0ff0562d76680cedb0466
tlsConfig := t.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
} else {
tlsConfig = tlsConfig.Clone()
}
if tlsConfig.ServerName == "" {
host, _, _ := net.SplitHostPort(addr)
tlsConfig.ServerName = host
}
pdConn := netConn.(*conn)
pdConn.expectingSecrets = 4
tlsConfig.KeyLogWriter = &connKeyLogWriter{c: pdConn}
return tls.Client(netConn, tlsConfig), nil
}
+3 -4
View File
@@ -7,18 +7,17 @@ import (
// Listen is a wrapper around net.Listen that dumps packets to disk. // Listen is a wrapper around net.Listen that dumps packets to disk.
type Listen struct { type Listen struct {
Prefix string Prefix string
Listen func(network, address string) (net.Listener, error)
} }
// Do mimics net.Listen. // Do mimics net.Listen.
func (l *Listen) Do(network, address string) (net.Listener, error) { func (l *Listen) Do(network, address string) (net.Listener, error) {
ln, err := l.Listen(network, address) netListener, err := net.Listen(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Listener{ return &listener{
Prefix: l.Prefix, Prefix: l.Prefix,
Listener: ln, Listener: netListener,
}, nil }, nil
} }
+8 -8
View File
@@ -6,25 +6,25 @@ import (
// ListenPacket is a wrapper around net.ListenPacket that dumps packets to disk. // ListenPacket is a wrapper around net.ListenPacket that dumps packets to disk.
type ListenPacket struct { type ListenPacket struct {
Prefix string Prefix string
ListenPacket func(network, address string) (net.PacketConn, error)
} }
// Do mimics net.ListenPacket // Do mimics net.ListenPacket.
func (l *ListenPacket) Do(network, address string) (net.PacketConn, error) { func (l *ListenPacket) Do(network, address string) (net.PacketConn, error) {
pc, err := l.ListenPacket(network, address) netPacketConn, err := net.ListenPacket(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &PacketConn{ pdPacketConn := &packetConn{
Prefix: l.Prefix, Prefix: l.Prefix,
PacketConn: pc, PacketConn: netPacketConn,
} }
err = d.Initialize() err = pdPacketConn.Initialize()
if err != nil { if err != nil {
netPacketConn.Close()
return nil, err return nil, err
} }
return d, nil return pdPacketConn, nil
} }
+12 -12
View File
@@ -2,41 +2,41 @@ package packetdumper
import "net" import "net"
var _ net.Listener = (*Listener)(nil) var _ net.Listener = (*listener)(nil)
// Listener is a wrapper around net.Listener that dumps packets to disk. // listener is a wrapper around a net.Listener that dumps packets to disk.
type Listener struct { type listener struct {
Prefix string Prefix string
Listener net.Listener Listener net.Listener
} }
// Accept implements net.Listener. // Accept implements net.Listener.
func (l *Listener) Accept() (net.Conn, error) { func (l *listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept() netConn, err := l.Listener.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cd := &Conn{ pdConn := &conn{
Prefix: l.Prefix, Prefix: l.Prefix,
Conn: conn, Conn: netConn,
ServerSide: true, ServerSide: true,
} }
err = cd.Initialize() err = pdConn.Initialize()
if err != nil { if err != nil {
conn.Close() //nolint:errcheck netConn.Close() //nolint:errcheck
return nil, err return nil, err
} }
return cd, nil return pdConn, nil
} }
// Close implements net.Listener. // Close implements net.Listener.
func (l *Listener) Close() error { func (l *listener) Close() error {
return l.Listener.Close() return l.Listener.Close()
} }
// Addr implements net.Listener. // Addr implements net.Listener.
func (l *Listener) Addr() net.Addr { func (l *listener) Addr() net.Addr {
return l.Listener.Addr() return l.Listener.Addr()
} }
+17 -17
View File
@@ -14,7 +14,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
var _ net.PacketConn = (*PacketConn)(nil) var _ net.PacketConn = (*packetConn)(nil)
type extendedPacketConn interface { type extendedPacketConn interface {
net.PacketConn net.PacketConn
@@ -28,8 +28,8 @@ type packetDumpEntry struct {
src, dst *net.UDPAddr src, dst *net.UDPAddr
} }
// PacketConn is a wrapper around net.PacketConn that dumps packets to disk. // packetConn is a wrapper around net.PacketConn that dumps packets to disk.
type PacketConn struct { type packetConn struct {
Prefix string Prefix string
PacketConn net.PacketConn PacketConn net.PacketConn
@@ -42,8 +42,8 @@ type PacketConn struct {
done chan struct{} done chan struct{}
} }
// Initialize initializes PacketConn. // Initialize initializes packetConn.
func (c *PacketConn) Initialize() error { func (c *packetConn) Initialize() error {
var err error var err error
c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String())) c.f, err = os.Create(fmt.Sprintf("%s_%d_%s.pcapng", c.Prefix, time.Now().UnixNano(), uuid.New().String()))
if err != nil { if err != nil {
@@ -66,7 +66,7 @@ func (c *PacketConn) Initialize() error {
} }
// Close implements net.PacketConn. // Close implements net.PacketConn.
func (c *PacketConn) Close() error { func (c *packetConn) Close() error {
c.once.Do(func() { c.once.Do(func() {
close(c.terminated) close(c.terminated)
}) })
@@ -74,7 +74,7 @@ func (c *PacketConn) Close() error {
return c.PacketConn.Close() return c.PacketConn.Close()
} }
func (c *PacketConn) run() { func (c *packetConn) run() {
defer close(c.done) defer close(c.done)
defer c.f.Close() defer c.f.Close()
@@ -97,7 +97,7 @@ func (c *PacketConn) run() {
} }
} }
func (c *PacketConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload []byte) { func (c *packetConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload []byte) {
eth := &layers.Ethernet{ eth := &layers.Ethernet{
SrcMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0}, SrcMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0},
DstMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0}, DstMAC: net.HardwareAddr{0, 0, 0, 0, 0, 0},
@@ -130,7 +130,7 @@ func (c *PacketConn) writePacket(ntp time.Time, src, dst *net.UDPAddr, payload [
}, raw) }, raw)
} }
func (c *PacketConn) enqueue(e packetDumpEntry) { func (c *packetConn) enqueue(e packetDumpEntry) {
select { select {
case c.queue <- e: case c.queue <- e:
case <-c.terminated: case <-c.terminated:
@@ -138,7 +138,7 @@ func (c *PacketConn) enqueue(e packetDumpEntry) {
} }
// ReadFrom implements net.PacketConn. // ReadFrom implements net.PacketConn.
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p) n, addr, err = c.PacketConn.ReadFrom(p)
if n != 0 { if n != 0 {
@@ -157,7 +157,7 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
// WriteTo implements net.PacketConn. // WriteTo implements net.PacketConn.
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
n, err = c.PacketConn.WriteTo(p, addr) n, err = c.PacketConn.WriteTo(p, addr)
if err == nil { if err == nil {
@@ -176,23 +176,23 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
} }
// LocalAddr implements net.PacketConn. // LocalAddr implements net.PacketConn.
func (c *PacketConn) LocalAddr() net.Addr { return c.PacketConn.LocalAddr() } func (c *packetConn) LocalAddr() net.Addr { return c.PacketConn.LocalAddr() }
// SetDeadline implements net.PacketConn. // SetDeadline implements net.PacketConn.
func (c *PacketConn) SetDeadline(t time.Time) error { return c.PacketConn.SetDeadline(t) } func (c *packetConn) SetDeadline(t time.Time) error { return c.PacketConn.SetDeadline(t) }
// SetReadDeadline implements net.PacketConn. // SetReadDeadline implements net.PacketConn.
func (c *PacketConn) SetReadDeadline(t time.Time) error { return c.PacketConn.SetReadDeadline(t) } func (c *packetConn) SetReadDeadline(t time.Time) error { return c.PacketConn.SetReadDeadline(t) }
// SetWriteDeadline implements net.PacketConn. // SetWriteDeadline implements net.PacketConn.
func (c *PacketConn) SetWriteDeadline(t time.Time) error { return c.PacketConn.SetWriteDeadline(t) } func (c *packetConn) SetWriteDeadline(t time.Time) error { return c.PacketConn.SetWriteDeadline(t) }
// SetReadBuffer implements extendedPacketConn. // SetReadBuffer implements extendedPacketConn.
func (c *PacketConn) SetReadBuffer(bytes int) error { func (c *packetConn) SetReadBuffer(bytes int) error {
return c.PacketConn.(extendedPacketConn).SetReadBuffer(bytes) return c.PacketConn.(extendedPacketConn).SetReadBuffer(bytes)
} }
// SyscallConn implements extendedPacketConn. // SyscallConn implements extendedPacketConn.
func (c *PacketConn) SyscallConn() (syscall.RawConn, error) { func (c *packetConn) SyscallConn() (syscall.RawConn, error) {
return c.PacketConn.(extendedPacketConn).SyscallConn() return c.PacketConn.(extendedPacketConn).SyscallConn()
} }
+7 -7
View File
@@ -51,7 +51,7 @@ func TestPacketConnInitialize_CreatesFile(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -63,7 +63,7 @@ func TestPacketConnWriteTo(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -85,7 +85,7 @@ func TestPacketConnReadFrom(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -107,7 +107,7 @@ func TestPacketConnMultipleWriteRead(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -150,7 +150,7 @@ func TestPacketConnCloseIdempotent(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -164,7 +164,7 @@ func TestPacketConnDelegatesAddrMethods(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
@@ -182,7 +182,7 @@ func TestPacketConnReadFromRecordsSource(t *testing.T) {
defer server.Close() //nolint:errcheck defer server.Close() //nolint:errcheck
prefix := filepath.Join(t.TempDir(), "capture") prefix := filepath.Join(t.TempDir(), "capture")
c := &PacketConn{Prefix: prefix, PacketConn: client} c := &packetConn{Prefix: prefix, PacketConn: client}
require.NoError(t, c.Initialize()) require.NoError(t, c.Initialize())
defer cleanupPcapngPacket(t, prefix) defer cleanupPcapngPacket(t, prefix)
+24
View File
@@ -0,0 +1,24 @@
package packetdumper
import (
"crypto/tls"
"net"
)
// TLSListen provides a tls.Listen that also dumps TLS master secrets to disk.
type TLSListen struct {
Listen func(network string, address string) (net.Listener, error)
}
// Do mimics tls.Listen.
func (l *TLSListen) Do(network string, laddr string, tlsConfig *tls.Config) (net.Listener, error) {
netListener, err := l.Listen(network, laddr)
if err != nil {
return nil, err
}
return &tlsListener{
Listener: netListener,
TLSConfig: tlsConfig,
}, nil
}
+48
View File
@@ -0,0 +1,48 @@
package packetdumper
import (
"crypto/tls"
"net"
"time"
)
type connKeyLogWriter struct {
c *conn
}
func (w *connKeyLogWriter) Write(p []byte) (int, error) {
w.c.enqueue(dumpEntry{
ntp: time.Now(),
data: append([]byte(nil), p...),
direction: dirSecret,
})
return len(p), nil
}
type tlsListener struct {
Listener net.Listener
TLSConfig *tls.Config
}
func (l *tlsListener) Close() error {
return l.Listener.Close()
}
func (l *tlsListener) Addr() net.Addr {
return l.Listener.Addr()
}
func (l *tlsListener) Accept() (net.Conn, error) {
netConn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
tlsConfig := l.TLSConfig.Clone()
pdConn := netConn.(*conn)
pdConn.expectingSecrets = 4
tlsConfig.KeyLogWriter = &connKeyLogWriter{c: pdConn}
return tls.Server(netConn, tlsConfig), nil
}
+42 -20
View File
@@ -15,6 +15,7 @@ import (
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/packetdumper"
"github.com/bluenviron/mediamtx/internal/restrictnetwork" "github.com/bluenviron/mediamtx/internal/restrictnetwork"
"golang.org/x/net/http2"
) )
type nilWriter struct{} type nilWriter struct{}
@@ -59,6 +60,7 @@ func (s *Server) Initialize() error {
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
if s.Encryption { if s.Encryption {
if s.ServerCert == "" { if s.ServerCert == "" {
return fmt.Errorf("server cert is missing") return fmt.Errorf("server cert is missing")
@@ -93,23 +95,6 @@ func (s *Server) Initialize() error {
os.Remove(address) os.Remove(address)
} }
var err error
s.ln, err = net.Listen(network, address)
if err != nil {
return err
}
if s.DumpPackets {
s.ln = &packetdumper.Listener{
Prefix: s.DumpPacketsPrefix,
Listener: s.ln,
}
}
if network == "unix" {
os.Chmod(address, 0o755) //nolint:errcheck
}
h := s.Handler h := s.Handler
h = &handlerOrigin{h, s.AllowOrigins} h = &handlerOrigin{h, s.AllowOrigins}
h = &handlerServerHeader{h} h = &handlerServerHeader{h}
@@ -134,11 +119,48 @@ func (s *Server) Initialize() error {
} }
if tlsConfig != nil { if tlsConfig != nil {
go s.inner.ServeTLS(s.ln, "", "") err := http2.ConfigureServer(s.inner, &http2.Server{})
} else { if err != nil {
go s.inner.Serve(s.ln) return err
}
} }
var listen func(network string, address string) (net.Listener, error)
var tlsListen func(network string, laddr string, config *tls.Config) (net.Listener, error)
if s.DumpPackets {
listen = (&packetdumper.Listen{
Prefix: s.DumpPacketsPrefix,
}).Do
tlsListen = (&packetdumper.TLSListen{
Listen: listen,
}).Do
} else {
listen = net.Listen
tlsListen = tls.Listen
}
if tlsConfig != nil {
var err error
s.ln, err = tlsListen(network, address, tlsConfig)
if err != nil {
return err
}
} else {
var err error
s.ln, err = listen(network, address)
if err != nil {
return err
}
}
if network == "unix" {
os.Chmod(address, 0o755) //nolint:errcheck
}
go s.inner.Serve(s.ln)
return nil return nil
} }
+2 -1
View File
@@ -26,8 +26,9 @@ func TestUnixSocket(t *testing.T) {
err := s.Initialize() err := s.Initialize()
require.NoError(t, err) require.NoError(t, err)
_, err = os.Stat("http.sock") info, err := os.Stat("http.sock")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
conn, err := net.Dial("unix", "http.sock") conn, err := net.Dial("unix", "http.sock")
require.NoError(t, err) require.NoError(t, err)
+7 -9
View File
@@ -9,15 +9,11 @@ import (
"strings" "strings"
) )
// MakeConfig returns a tls.Config with: // MakeConfig returns a tls.Config with fingerprint support.
// - server name indicator (SNI) support func MakeConfig(fingerprint string) *tls.Config {
// - fingerprint support
func MakeConfig(serverName string, fingerprint string) *tls.Config {
conf := &tls.Config{
ServerName: serverName,
}
if fingerprint != "" { if fingerprint != "" {
conf := &tls.Config{}
fingerprintLower := strings.ToLower(fingerprint) fingerprintLower := strings.ToLower(fingerprint)
conf.InsecureSkipVerify = true conf.InsecureSkipVerify = true
conf.VerifyConnection = func(cs tls.ConnectionState) error { conf.VerifyConnection = func(cs tls.ConnectionState) error {
@@ -32,7 +28,9 @@ func MakeConfig(serverName string, fingerprint string) *tls.Config {
return nil return nil
} }
return conf
} }
return conf return nil
} }
+1 -43
View File
@@ -60,44 +60,6 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----
`) `)
func TestMakeConfigSNI(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8556")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
cert, err2 := tls.X509KeyPair(testTLSCertPub, testTLSCertKey)
require.NoError(t, err2)
tnconn := tls.Server(nconn, &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
require.Equal(t, "myhost", cs.ServerName)
return nil
},
})
err2 = tnconn.Handshake()
require.EqualError(t, err2, "remote error: tls: bad certificate")
}()
conf := MakeConfig("myhost", "")
_, err = tls.Dial("tcp", "localhost:8556", conf)
require.EqualError(t, err, "tls: failed to verify certificate: x509: "+
"certificate is not valid for any names, but wanted to match myhost")
}
func TestMakeConfigFingerprint(t *testing.T) { func TestMakeConfigFingerprint(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8556") l, err := net.Listen("tcp", "localhost:8556")
require.NoError(t, err) require.NoError(t, err)
@@ -119,17 +81,13 @@ func TestMakeConfigFingerprint(t *testing.T) {
tnconn := tls.Server(nconn, &tls.Config{ tnconn := tls.Server(nconn, &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true, InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
require.Equal(t, "myhost", cs.ServerName)
return nil
},
}) })
err2 = tnconn.Handshake() err2 = tnconn.Handshake()
require.NoError(t, err2) require.NoError(t, err2)
}() }()
conf := MakeConfig("myhost", "33949e05fffb5ff3e8aa16f8213a6251b4d9363804ba53233c4da9a46d6f2739") conf := MakeConfig("33949e05fffb5ff3e8aa16f8213a6251b4d9363804ba53233c4da9a46d6f2739")
conn, err := tls.Dial("tcp", "localhost:8556", conf) conn, err := tls.Dial("tcp", "localhost:8556", conf)
require.NoError(t, err) require.NoError(t, err)
+6 -2
View File
@@ -11,7 +11,8 @@ import (
// Listener is a listener on a Unix socket. // Listener is a listener on a Unix socket.
type Listener struct { type Listener struct {
Path string Path string
Listen func(network string, address string) (net.Listener, error)
l net.Listener l net.Listener
c net.Conn c net.Conn
@@ -25,11 +26,14 @@ func (l *Listener) Initialize() error {
if l.Path == "" { if l.Path == "" {
return fmt.Errorf("invalid unix path") return fmt.Errorf("invalid unix path")
} }
if l.Listen == nil {
l.Listen = net.Listen
}
os.Remove(l.Path) os.Remove(l.Path)
var err error var err error
l.l, err = net.Listen("unix", l.Path) l.l, err = l.Listen("unix", l.Path)
if err != nil { if err != nil {
return err return err
} }
+8 -1
View File
@@ -58,11 +58,18 @@ func (s *httpServer) initialize() error {
router.Use(s.onRequest) router.Use(s.onRequest)
var proto string
if s.encryption {
proto = "hlss"
} else {
proto = "hls"
}
s.inner = &httpp.Server{ s.inner = &httpp.Server{
Address: s.address, Address: s.address,
AllowOrigins: s.allowOrigins, AllowOrigins: s.allowOrigins,
DumpPackets: s.dumpPackets, DumpPackets: s.dumpPackets,
DumpPacketsPrefix: "hls_server_conn", DumpPacketsPrefix: proto + "_server_conn",
ReadTimeout: time.Duration(s.readTimeout), ReadTimeout: time.Duration(s.readTimeout),
WriteTimeout: time.Duration(s.writeTimeout), WriteTimeout: time.Duration(s.writeTimeout),
Encryption: s.encryption, Encryption: s.encryption,
+3 -3
View File
@@ -25,7 +25,7 @@ import (
type conn struct { type conn struct {
parentCtx context.Context parentCtx context.Context
isTLS bool encryption bool
rtspAddress string rtspAddress string
readTimeout conf.Duration readTimeout conf.Duration
writeTimeout conf.Duration writeTimeout conf.Duration
@@ -297,7 +297,7 @@ func (c *conn) runPublish() error {
func (c *conn) APIReaderDescribe() *defs.APIPathReader { func (c *conn) APIReaderDescribe() *defs.APIPathReader {
return &defs.APIPathReader{ return &defs.APIPathReader{
Type: func() defs.APIPathReaderType { Type: func() defs.APIPathReaderType {
if c.isTLS { if c.encryption {
return defs.APIPathReaderTypeRTMPSConn return defs.APIPathReaderTypeRTMPSConn
} }
return defs.APIPathReaderTypeRTMPConn return defs.APIPathReaderTypeRTMPConn
@@ -310,7 +310,7 @@ func (c *conn) APIReaderDescribe() *defs.APIPathReader {
func (c *conn) APISourceDescribe() *defs.APIPathSource { func (c *conn) APISourceDescribe() *defs.APIPathSource {
return &defs.APIPathSource{ return &defs.APIPathSource{
Type: func() defs.APIPathSourceType { Type: func() defs.APIPathSourceType {
if c.isTLS { if c.encryption {
return defs.APIPathSourceTypeRTMPSConn return defs.APIPathSourceTypeRTMPSConn
} }
return defs.APIPathSourceTypeRTMPConn return defs.APIPathSourceTypeRTMPConn
+37 -18
View File
@@ -77,7 +77,7 @@ type Server struct {
DumpPackets bool DumpPackets bool
ReadTimeout conf.Duration ReadTimeout conf.Duration
WriteTimeout conf.Duration WriteTimeout conf.Duration
IsTLS bool Encryption bool
ServerCert string ServerCert string
ServerKey string ServerKey string
RTSPAddress string RTSPAddress string
@@ -107,32 +107,51 @@ type Server struct {
// Initialize initializes the server. // Initialize initializes the server.
func (s *Server) Initialize() error { func (s *Server) Initialize() error {
var err error var listen func(network string, address string) (net.Listener, error)
s.ln, err = net.Listen(restrictnetwork.Restrict("tcp", s.Address)) var tlsListen func(network string, laddr string, config *tls.Config) (net.Listener, error)
if err != nil {
return err
}
if s.DumpPackets { if s.DumpPackets {
s.ln = &packetdumper.Listener{ var proto string
Prefix: "rtmp_server_conn", if s.Encryption {
Listener: s.ln, proto = "rtmps"
} else {
proto = "rtmp"
} }
listen = (&packetdumper.Listen{
Prefix: proto + "_server_conn",
}).Do
tlsListen = (&packetdumper.TLSListen{
Listen: listen,
}).Do
} else {
listen = net.Listen
tlsListen = tls.Listen
} }
if s.IsTLS { if s.Encryption {
s.loader = &certloader.CertLoader{ s.loader = &certloader.CertLoader{
CertPath: s.ServerCert, CertPath: s.ServerCert,
KeyPath: s.ServerKey, KeyPath: s.ServerKey,
Parent: s.Parent, Parent: s.Parent,
} }
err = s.loader.Initialize() err := s.loader.Initialize()
if err != nil { if err != nil {
s.ln.Close()
return err return err
} }
s.ln = tls.NewListener(s.ln, &tls.Config{GetCertificate: s.loader.GetCertificate()}) net, addr := restrictnetwork.Restrict("tcp", s.Address)
s.ln, err = tlsListen(net, addr, &tls.Config{GetCertificate: s.loader.GetCertificate()})
if err != nil {
return err
}
} else {
var err error
s.ln, err = listen(restrictnetwork.Restrict("tcp", s.Address))
if err != nil {
return err
}
} }
s.ctx, s.ctxCancel = context.WithCancel(context.Background()) s.ctx, s.ctxCancel = context.WithCancel(context.Background())
@@ -146,7 +165,7 @@ func (s *Server) Initialize() error {
s.chAPIConnsKick = make(chan serverAPIConnsKickReq) s.chAPIConnsKick = make(chan serverAPIConnsKickReq)
str := "listener opened on " + s.Address str := "listener opened on " + s.Address
if s.IsTLS { if s.Encryption {
str += " (TCP/RTMPS)" str += " (TCP/RTMPS)"
} else { } else {
str += " (TCP/RTMP)" str += " (TCP/RTMP)"
@@ -164,7 +183,7 @@ func (s *Server) Initialize() error {
go s.run() go s.run()
if !interfaceIsEmpty(s.Metrics) { if !interfaceIsEmpty(s.Metrics) {
if s.IsTLS { if s.Encryption {
s.Metrics.SetRTMPSServer(s) s.Metrics.SetRTMPSServer(s)
} else { } else {
s.Metrics.SetRTMPServer(s) s.Metrics.SetRTMPServer(s)
@@ -177,7 +196,7 @@ func (s *Server) Initialize() error {
// Log implements logger.Writer. // Log implements logger.Writer.
func (s *Server) Log(level logger.Level, format string, args ...any) { func (s *Server) Log(level logger.Level, format string, args ...any) {
label := func() string { label := func() string {
if s.IsTLS { if s.Encryption {
return "RTMPS" return "RTMPS"
} }
return "RTMP" return "RTMP"
@@ -190,7 +209,7 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing") s.Log(logger.Info, "listener is closing")
if !interfaceIsEmpty((s.Metrics)) { if !interfaceIsEmpty((s.Metrics)) {
if s.IsTLS { if s.Encryption {
s.Metrics.SetRTMPSServer(nil) s.Metrics.SetRTMPSServer(nil)
} else { } else {
s.Metrics.SetRTMPServer(nil) s.Metrics.SetRTMPServer(nil)
@@ -218,7 +237,7 @@ outer:
case nconn := <-s.chNewConn: case nconn := <-s.chNewConn:
c := &conn{ c := &conn{
parentCtx: s.ctx, parentCtx: s.ctx,
isTLS: s.IsTLS, encryption: s.Encryption,
rtspAddress: s.RTSPAddress, rtspAddress: s.RTSPAddress,
readTimeout: s.ReadTimeout, readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout, writeTimeout: s.WriteTimeout,
+2 -2
View File
@@ -130,7 +130,7 @@ func TestServerPublish(t *testing.T) {
Address: "127.0.0.1:1939", Address: "127.0.0.1:1939",
ReadTimeout: conf.Duration(10 * time.Second), ReadTimeout: conf.Duration(10 * time.Second),
WriteTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second),
IsTLS: encrypt == "tls", Encryption: encrypt == "tls",
ServerCert: serverCertFpath, ServerCert: serverCertFpath,
ServerKey: serverKeyFpath, ServerKey: serverKeyFpath,
RTSPAddress: "", RTSPAddress: "",
@@ -266,7 +266,7 @@ func TestServerRead(t *testing.T) {
Address: "127.0.0.1:1939", Address: "127.0.0.1:1939",
ReadTimeout: conf.Duration(10 * time.Second), ReadTimeout: conf.Duration(10 * time.Second),
WriteTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second),
IsTLS: encrypt == "tls", Encryption: encrypt == "tls",
ServerCert: serverCertFpath, ServerCert: serverCertFpath,
ServerKey: serverKeyFpath, ServerKey: serverKeyFpath,
RTSPAddress: "", RTSPAddress: "",
+3 -3
View File
@@ -54,7 +54,7 @@ type connParent interface {
} }
type conn struct { type conn struct {
isTLS bool encryption bool
rtspAddress string rtspAddress string
authMethods []rtspauth.VerifyMethod authMethods []rtspauth.VerifyMethod
readTimeout conf.Duration readTimeout conf.Duration
@@ -87,7 +87,7 @@ func (c *conn) initialize() {
RTSPAddress: c.rtspAddress, RTSPAddress: c.rtspAddress,
Desc: defs.APIPathReader{ Desc: defs.APIPathReader{
Type: func() defs.APIPathReaderType { Type: func() defs.APIPathReaderType {
if c.isTLS { if c.encryption {
return defs.APIPathReaderTypeRTSPSConn return defs.APIPathReaderTypeRTSPSConn
} }
return defs.APIPathReaderTypeRTSPConn return defs.APIPathReaderTypeRTSPConn
@@ -192,7 +192,7 @@ func (c *conn) onDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx,
} }
var strm *gortsplib.ServerStream var strm *gortsplib.ServerStream
if !c.isTLS { if !c.encryption {
strm = res.Stream.RTSPStream(c.rserver) strm = res.Stream.RTSPStream(c.rserver)
} else { } else {
strm = res.Stream.RTSPSStream(c.rserver) strm = res.Stream.RTSPSStream(c.rserver)
+20 -12
View File
@@ -6,7 +6,6 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@@ -102,7 +101,7 @@ type Server struct {
MulticastIPRange string MulticastIPRange string
MulticastRTPPort int MulticastRTPPort int
MulticastRTCPPort int MulticastRTCPPort int
IsTLS bool Encryption bool
ServerCert string ServerCert string
ServerKey string ServerKey string
RTSPAddress string RTSPAddress string
@@ -153,7 +152,7 @@ func (s *Server) Initialize() error {
s.srv.MulticastRTCPPort = s.MulticastRTCPPort s.srv.MulticastRTCPPort = s.MulticastRTCPPort
} }
if s.IsTLS { if s.Encryption {
s.loader = &certloader.CertLoader{ s.loader = &certloader.CertLoader{
CertPath: s.ServerCert, CertPath: s.ServerCert,
KeyPath: s.ServerKey, KeyPath: s.ServerKey,
@@ -168,14 +167,23 @@ func (s *Server) Initialize() error {
} }
if s.DumpPackets { if s.DumpPackets {
var proto string
if s.Encryption {
proto = "rtsps"
} else {
proto = "rtsp"
}
s.srv.Listen = (&packetdumper.Listen{ s.srv.Listen = (&packetdumper.Listen{
Prefix: "rtsp_server_conn", Prefix: proto + "_server_conn",
Listen: net.Listen,
}).Do }).Do
s.srv.ListenPacket = (&packetdumper.ListenPacket{ s.srv.ListenPacket = (&packetdumper.ListenPacket{
Prefix: "rtsp_server_packetconn", Prefix: proto + "_server_packet_conn",
ListenPacket: net.ListenPacket, }).Do
s.srv.TLSListen = (&packetdumper.TLSListen{
Listen: s.srv.Listen,
}).Do }).Do
} }
@@ -190,7 +198,7 @@ func (s *Server) Initialize() error {
go s.run() go s.run()
if !interfaceIsEmpty(s.Metrics) { if !interfaceIsEmpty(s.Metrics) {
if s.IsTLS { if s.Encryption {
s.Metrics.SetRTSPSServer(s) s.Metrics.SetRTSPSServer(s)
} else { } else {
s.Metrics.SetRTSPServer(s) s.Metrics.SetRTSPServer(s)
@@ -203,7 +211,7 @@ func (s *Server) Initialize() error {
// Log implements logger.Writer. // Log implements logger.Writer.
func (s *Server) Log(level logger.Level, format string, args ...any) { func (s *Server) Log(level logger.Level, format string, args ...any) {
label := func() string { label := func() string {
if s.IsTLS { if s.Encryption {
return "RTSPS" return "RTSPS"
} }
return "RTSP" return "RTSP"
@@ -216,7 +224,7 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing") s.Log(logger.Info, "listener is closing")
if !interfaceIsEmpty(s.Metrics) { if !interfaceIsEmpty(s.Metrics) {
if s.IsTLS { if s.Encryption {
s.Metrics.SetRTSPSServer(nil) s.Metrics.SetRTSPSServer(nil)
} else { } else {
s.Metrics.SetRTSPServer(nil) s.Metrics.SetRTSPServer(nil)
@@ -257,7 +265,7 @@ outer:
// OnConnOpen implements gortsplib.ServerHandlerOnConnOpen. // OnConnOpen implements gortsplib.ServerHandlerOnConnOpen.
func (s *Server) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) { func (s *Server) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) {
c := &conn{ c := &conn{
isTLS: s.IsTLS, encryption: s.Encryption,
rtspAddress: s.RTSPAddress, rtspAddress: s.RTSPAddress,
authMethods: s.AuthMethods, authMethods: s.AuthMethods,
readTimeout: s.ReadTimeout, readTimeout: s.ReadTimeout,
@@ -302,7 +310,7 @@ func (s *Server) OnResponse(sc *gortsplib.ServerConn, res *base.Response) {
// OnSessionOpen implements gortsplib.ServerHandlerOnSessionOpen. // OnSessionOpen implements gortsplib.ServerHandlerOnSessionOpen.
func (s *Server) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) { func (s *Server) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) {
se := &session{ se := &session{
isTLS: s.IsTLS, encryption: s.Encryption,
transports: s.Transports, transports: s.Transports,
rsession: ctx.Session, rsession: ctx.Session,
rconn: ctx.Conn, rconn: ctx.Conn,
+4 -4
View File
@@ -58,7 +58,7 @@ type sessionParent interface {
} }
type session struct { type session struct {
isTLS bool encryption bool
transports conf.RTSPTransports transports conf.RTSPTransports
rsession *gortsplib.ServerSession rsession *gortsplib.ServerSession
rconn *gortsplib.ServerConn rconn *gortsplib.ServerConn
@@ -229,7 +229,7 @@ func (s *session) onAnnounce(c *conn, ctx *gortsplib.ServerHandlerOnAnnounceCtx)
} }
func (s *session) rtspStream() *gortsplib.ServerStream { func (s *session) rtspStream() *gortsplib.ServerStream {
if !s.isTLS { if !s.encryption {
return s.stream.RTSPStream(s.rserver) return s.stream.RTSPStream(s.rserver)
} }
return s.stream.RTSPSStream(s.rserver) return s.stream.RTSPSStream(s.rserver)
@@ -439,7 +439,7 @@ func (s *session) onPause(_ *gortsplib.ServerHandlerOnPauseCtx) (*base.Response,
func (s *session) APIReaderDescribe() *defs.APIPathReader { func (s *session) APIReaderDescribe() *defs.APIPathReader {
return &defs.APIPathReader{ return &defs.APIPathReader{
Type: func() defs.APIPathReaderType { Type: func() defs.APIPathReaderType {
if s.isTLS { if s.encryption {
return defs.APIPathReaderTypeRTSPSSession return defs.APIPathReaderTypeRTSPSSession
} }
return defs.APIPathReaderTypeRTSPSession return defs.APIPathReaderTypeRTSPSession
@@ -452,7 +452,7 @@ func (s *session) APIReaderDescribe() *defs.APIPathReader {
func (s *session) APISourceDescribe() *defs.APIPathSource { func (s *session) APISourceDescribe() *defs.APIPathSource {
return &defs.APIPathSource{ return &defs.APIPathSource{
Type: func() defs.APIPathSourceType { Type: func() defs.APIPathSourceType {
if s.isTLS { if s.encryption {
return defs.APIPathSourceTypeRTSPSSession return defs.APIPathSourceTypeRTSPSSession
} }
return defs.APIPathSourceTypeRTSPSession return defs.APIPathSourceTypeRTSPSession
+8 -1
View File
@@ -96,11 +96,18 @@ func (s *httpServer) initialize() error {
router.Use(s.onRequest) router.Use(s.onRequest)
var proto string
if s.encryption {
proto = "webrtcs"
} else {
proto = "webrtc"
}
s.inner = &httpp.Server{ s.inner = &httpp.Server{
Address: s.address, Address: s.address,
AllowOrigins: s.allowOrigins, AllowOrigins: s.allowOrigins,
DumpPackets: s.dumpPackets, DumpPackets: s.dumpPackets,
DumpPacketsPrefix: "webrtc_server_conn", DumpPacketsPrefix: proto + "_server_conn",
ReadTimeout: time.Duration(s.readTimeout), ReadTimeout: time.Duration(s.readTimeout),
WriteTimeout: time.Duration(s.writeTimeout), WriteTimeout: time.Duration(s.writeTimeout),
Encryption: s.encryption, Encryption: s.encryption,
+22 -18
View File
@@ -2,10 +2,9 @@
package hls package hls
import ( import (
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "strings"
"time" "time"
"github.com/bluenviron/gohlslib/v2" "github.com/bluenviron/gohlslib/v2"
@@ -17,7 +16,7 @@ import (
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/packetdumper"
"github.com/bluenviron/mediamtx/internal/protocols/hls" "github.com/bluenviron/mediamtx/internal/protocols/hls"
"github.com/bluenviron/mediamtx/internal/protocols/tls" ptls "github.com/bluenviron/mediamtx/internal/protocols/tls"
"github.com/bluenviron/mediamtx/internal/stream" "github.com/bluenviron/mediamtx/internal/stream"
) )
@@ -62,25 +61,30 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
decodeErrors.Start() decodeErrors.Start()
defer decodeErrors.Stop() defer decodeErrors.Stop()
u, err := url.Parse(params.ResolvedSource) tr := &http.Transport{}
if err != nil { defer tr.CloseIdleConnections()
return err
}
dialContext := (&net.Dialer{}).DialContext tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint)
if s.DumpPackets { if s.DumpPackets {
dialContext = (&packetdumper.DialContext{ var proto string
Prefix: "hls_source_conn", if strings.HasPrefix(params.ResolvedSource, "https") {
DialContext: dialContext, proto = "hlss"
}).Do } else {
} proto = "hls"
}
tr := &http.Transport{ tr.DialContext = (&packetdumper.DialContext{
DialContext: dialContext, Prefix: proto + "_source_conn",
TLSClientConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint), }).Do
tr.DialTLSContext = (&packetdumper.DialTLSContext{
DialContext: tr.DialContext,
TLSConfig: tlsConfig,
}).Do
} else {
tr.TLSClientConfig = tlsConfig
} }
defer tr.CloseIdleConnections()
jar, _ := cookiejar.New(nil) jar, _ := cookiejar.New(nil)
@@ -128,7 +132,7 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
}, },
} }
err = c.Start() err := c.Start()
if err != nil { if err != nil {
return err return err
} }
+14 -23
View File
@@ -56,6 +56,13 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
l := &unix.Listener{ l := &unix.Listener{
Path: params.Path, Path: params.Path,
} }
if s.DumpPackets {
l.Listen = (&packetdumper.Listen{
Prefix: "mpegts_source_unix_conn",
}).Do
}
err = l.Initialize() err = l.Initialize()
if err != nil { if err != nil {
return err return err
@@ -68,36 +75,20 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
udpReadBufferSize = *params.Conf.MPEGTSUDPReadBufferSize udpReadBufferSize = *params.Conf.MPEGTSUDPReadBufferSize
} }
listenPacket := net.ListenPacket
if s.DumpPackets {
listenPacket = func(network, address string) (net.PacketConn, error) {
pc, err2 := net.ListenPacket(network, address)
if err2 != nil {
return nil, err2
}
d := &packetdumper.PacketConn{
Prefix: "mpegts_source_packetconn",
PacketConn: pc,
}
err2 = d.Initialize()
if err2 != nil {
return nil, err2
}
return d, nil
}
}
params := udp.URLToParams(u) params := udp.URLToParams(u)
l := &udp.Listener{ l := &udp.Listener{
Address: params.Address, Address: params.Address,
Source: params.Source, Source: params.Source,
IntfName: params.IntfName, IntfName: params.IntfName,
UDPReadBufferSize: int(udpReadBufferSize), UDPReadBufferSize: int(udpReadBufferSize),
ListenPacket: listenPacket,
} }
if s.DumpPackets {
l.ListenPacket = (&packetdumper.ListenPacket{
Prefix: "mpegts_source_packet_conn",
}).Do
}
err = l.Initialize() err = l.Initialize()
if err != nil { if err != nil {
return err return err
+19 -14
View File
@@ -16,7 +16,7 @@ import (
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/packetdumper"
"github.com/bluenviron/mediamtx/internal/protocols/rtmp" "github.com/bluenviron/mediamtx/internal/protocols/rtmp"
"github.com/bluenviron/mediamtx/internal/protocols/tls" ptls "github.com/bluenviron/mediamtx/internal/protocols/tls"
"github.com/bluenviron/mediamtx/internal/stream" "github.com/bluenviron/mediamtx/internal/stream"
) )
@@ -58,23 +58,28 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
} }
} }
dialContext := (&net.Dialer{}).DialContext
if s.DumpPackets {
dialContext = (&packetdumper.DialContext{
Prefix: "rtmp_source_conn",
DialContext: dialContext,
}).Do
}
connectCtx, connectCtxCancel := context.WithTimeout(params.Context, time.Duration(s.ReadTimeout)) connectCtx, connectCtxCancel := context.WithTimeout(params.Context, time.Duration(s.ReadTimeout))
conn := &gortmplib.Client{ conn := &gortmplib.Client{
URL: u, URL: u,
TLSConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint), Publish: false,
Publish: false,
DialContext: dialContext,
} }
tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint)
if s.DumpPackets {
conn.DialContext = (&packetdumper.DialContext{
Prefix: u.Scheme + "_source_conn",
}).Do
conn.DialTLSContext = (&packetdumper.DialTLSContext{
DialContext: conn.DialContext,
TLSConfig: tlsConfig,
}).Do
} else {
conn.TLSConfig = tlsConfig
}
err = conn.Initialize(connectCtx) err = conn.Initialize(connectCtx)
connectCtxCancel() connectCtxCancel()
if err != nil { if err != nil {
+10 -25
View File
@@ -66,10 +66,11 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
var nc net.Conn var nc net.Conn
switch u.Scheme { switch u.Scheme {
case "unix+rtp": case "unix+rtp": // deprecated
params := unix.URLToParams(u) params := unix.URLToParams(u)
l := &unix.Listener{ l := &unix.Listener{
Path: params.Path, Path: params.Path,
Listen: net.Listen,
} }
err = l.Initialize() err = l.Initialize()
if err != nil { if err != nil {
@@ -83,36 +84,20 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
udpReadBufferSize = *params.Conf.RTPUDPReadBufferSize udpReadBufferSize = *params.Conf.RTPUDPReadBufferSize
} }
listenPacket := net.ListenPacket
if s.DumpPackets {
listenPacket = func(network, address string) (net.PacketConn, error) {
pc, err2 := net.ListenPacket(network, address)
if err2 != nil {
return nil, err2
}
d := &packetdumper.PacketConn{
Prefix: "rtp_source_packetconn",
PacketConn: pc,
}
err2 = d.Initialize()
if err2 != nil {
return nil, err2
}
return d, nil
}
}
params := udp.URLToParams(u) params := udp.URLToParams(u)
l := &udp.Listener{ l := &udp.Listener{
Address: params.Address, Address: params.Address,
Source: params.Source, Source: params.Source,
IntfName: params.IntfName, IntfName: params.IntfName,
UDPReadBufferSize: int(udpReadBufferSize), UDPReadBufferSize: int(udpReadBufferSize),
ListenPacket: listenPacket,
} }
if s.DumpPackets {
l.ListenPacket = (&packetdumper.ListenPacket{
Prefix: "rtp_source_packet_conn",
}).Do
}
err = l.Initialize() err = l.Initialize()
if err != nil { if err != nil {
return err return err
+17 -12
View File
@@ -3,7 +3,6 @@ package rtsp
import ( import (
"fmt" "fmt"
"net"
"net/url" "net/url"
"time" "time"
@@ -19,7 +18,7 @@ import (
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/packetdumper"
"github.com/bluenviron/mediamtx/internal/protocols/rtsp" "github.com/bluenviron/mediamtx/internal/protocols/rtsp"
"github.com/bluenviron/mediamtx/internal/protocols/tls" ptls "github.com/bluenviron/mediamtx/internal/protocols/tls"
"github.com/bluenviron/mediamtx/internal/stream" "github.com/bluenviron/mediamtx/internal/stream"
) )
@@ -122,6 +121,11 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
decodeErrors.Start() decodeErrors.Start()
defer decodeErrors.Stop() defer decodeErrors.Stop()
u0, err := url.Parse(params.ResolvedSource)
if err != nil {
return err
}
c := &gortsplib.Client{ c := &gortsplib.Client{
Protocol: params.Conf.RTSPTransport.Protocol, Protocol: params.Conf.RTSPTransport.Protocol,
ReadTimeout: time.Duration(s.ReadTimeout), ReadTimeout: time.Duration(s.ReadTimeout),
@@ -150,11 +154,6 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
}, },
} }
u0, err := url.Parse(params.ResolvedSource)
if err != nil {
return err
}
switch u0.Scheme { switch u0.Scheme {
case "rtsp+http", "rtsps+http": case "rtsp+http", "rtsps+http":
c.Tunnel = gortsplib.TunnelHTTP c.Tunnel = gortsplib.TunnelHTTP
@@ -167,7 +166,6 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
u0.Scheme = "rtsp" u0.Scheme = "rtsp"
default: default:
u0.Scheme = "rtsps" u0.Scheme = "rtsps"
c.TLSConfig = tls.MakeConfig(u0.Hostname(), params.Conf.SourceFingerprint)
} }
u, err := base.ParseURL(u0.String()) u, err := base.ParseURL(u0.String())
@@ -182,16 +180,23 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
s.UDPReadBufferSize = *params.Conf.RTSPUDPReadBufferSize s.UDPReadBufferSize = *params.Conf.RTSPUDPReadBufferSize
} }
tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint)
if s.DumpPackets { if s.DumpPackets {
c.DialContext = (&packetdumper.DialContext{ c.DialContext = (&packetdumper.DialContext{
Prefix: "rtsp_source_conn", Prefix: u.Scheme + "_source_conn",
DialContext: (&net.Dialer{}).DialContext,
}).Do }).Do
c.ListenPacket = (&packetdumper.ListenPacket{ c.ListenPacket = (&packetdumper.ListenPacket{
Prefix: "rtsp_source_packetconn", Prefix: u.Scheme + "_source_packet_conn",
ListenPacket: net.ListenPacket,
}).Do }).Do
c.DialTLSContext = (&packetdumper.DialTLSContext{
DialContext: c.DialContext,
TLSConfig: tlsConfig,
}).Do
} else {
c.TLSConfig = tlsConfig
} }
err = c.Start() err = c.Start()
+14 -12
View File
@@ -3,7 +3,6 @@ package webrtc
import ( import (
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@@ -15,7 +14,7 @@ import (
"github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/packetdumper"
"github.com/bluenviron/mediamtx/internal/protocols/tls" ptls "github.com/bluenviron/mediamtx/internal/protocols/tls"
"github.com/bluenviron/mediamtx/internal/protocols/webrtc" "github.com/bluenviron/mediamtx/internal/protocols/webrtc"
"github.com/bluenviron/mediamtx/internal/protocols/whip" "github.com/bluenviron/mediamtx/internal/protocols/whip"
"github.com/bluenviron/mediamtx/internal/stream" "github.com/bluenviron/mediamtx/internal/stream"
@@ -49,22 +48,25 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
return err return err
} }
u.Scheme = strings.ReplaceAll(u.Scheme, "whep", "http") tr := &http.Transport{}
defer tr.CloseIdleConnections()
dialContext := (&net.Dialer{}).DialContext tlsConfig := ptls.MakeConfig(params.Conf.SourceFingerprint)
if s.DumpPackets { if s.DumpPackets {
dialContext = (&packetdumper.DialContext{ tr.DialContext = (&packetdumper.DialContext{
Prefix: "webrtc_source_conn", Prefix: u.Scheme + "_source_conn",
DialContext: dialContext,
}).Do }).Do
tr.DialTLSContext = (&packetdumper.DialTLSContext{
DialContext: tr.DialContext,
TLSConfig: tlsConfig,
}).Do
} else {
tr.TLSClientConfig = tlsConfig
} }
tr := &http.Transport{ u.Scheme = strings.ReplaceAll(u.Scheme, "whep", "http")
DialContext: dialContext,
TLSClientConfig: tls.MakeConfig(u.Hostname(), params.Conf.SourceFingerprint),
}
defer tr.CloseIdleConnections()
client := whip.Client{ client := whip.Client{
URL: u, URL: u,