diff --git a/client.go b/client.go index 0161a259..d22537a3 100644 --- a/client.go +++ b/client.go @@ -243,6 +243,16 @@ func generateAnnounceData( return data, nil } +func announceDataFormatSSRCs(formats map[uint8]*clientAnnounceDataFormat) []uint32 { + ssrcs := make([]uint32, len(formats)) + n := 0 + for _, af := range formats { + ssrcs[n] = af.localSSRC + n++ + } + return ssrcs +} + func prepareForAnnounce( desc *description.Session, announceData map[*description.Media]*clientAnnounceDataMedia, @@ -255,26 +265,19 @@ func prepareForAnnounce( m.Profile = headers.TransportProfileSAVP announceDataMedia := announceData[m] - ssrcs := make([]uint32, len(m.Formats)) - n := 0 - for _, af := range announceDataMedia.formats { - ssrcs[n] = af.localSSRC - n++ - } - // create a temporary Context. // Context is needed to extract ROC, but since client has not started streaming, // ROC is always zero, therefore a temporary Context can be used. srtpCtx := &wrappedSRTPContext{ key: announceDataMedia.srtpOutKey, - ssrcs: ssrcs, + ssrcs: announceDataFormatSSRCs(announceDataMedia.formats), } err := srtpCtx.initialize() if err != nil { return err } - mikeyMsg, err := mikeyGenerate(srtpCtx) + mikeyMsg, err := contextToMikey(srtpCtx) if err != nil { return err } @@ -543,42 +546,43 @@ type Client struct { receiverReportPeriod time.Duration checkTimeoutPeriod time.Duration - ctx context.Context - ctxCancel func() - propsMutex sync.RWMutex - state clientState - nconn net.Conn - conn *conn.Conn - session string - sender *auth.Sender - cseq int - optionsSent bool - useGetParameter bool - lastDescribeURL *base.URL - lastDescribeDesc *description.Session - baseURL *base.URL - announceData map[*description.Media]*clientAnnounceDataMedia // record - setuppedTransport *SessionTransport - backChannelSetupped bool - stdChannelSetupped bool - setuppedMedias map[*description.Media]*clientMedia - tcpCallbackByChannel map[int]readFunc - lastRange *headers.Range - checkTimeoutTimer *time.Timer - checkTimeoutInitial bool - tcpLastFrameTime atomic.Int64 - keepAlivePeriod time.Duration - keepAliveTimer *time.Timer - closeError error - writerMutex sync.RWMutex - writer *asyncprocessor.Processor - reader *clientReader - timeDecoder *rtptime.GlobalDecoder - mustClose bool - tcpFrame *base.InterleavedFrame - tcpBuffer []byte - bytesReceived atomic.Uint64 - bytesSent atomic.Uint64 + ctx context.Context + ctxCancel func() + propsMutex sync.RWMutex + state clientState + nconn net.Conn + conn *conn.Conn + session string + sender *auth.Sender + cseq int + optionsSent bool + useGetParameter bool + lastDescribeURL *base.URL + lastDescribeDesc *description.Session + baseURL *base.URL + announceData map[*description.Media]*clientAnnounceDataMedia // record + setuppedTransport *SessionTransport + axisClientManagedKeys bool + backChannelSetupped bool + stdChannelSetupped bool + setuppedMedias map[*description.Media]*clientMedia + tcpCallbackByChannel map[int]readFunc + lastRange *headers.Range + checkTimeoutTimer *time.Timer + checkTimeoutInitial bool + tcpLastFrameTime atomic.Int64 + keepAlivePeriod time.Duration + keepAliveTimer *time.Timer + closeError error + writerMutex sync.RWMutex + writer *asyncprocessor.Processor + reader *clientReader + timeDecoder *rtptime.GlobalDecoder + mustClose bool + tcpFrame *base.InterleavedFrame + tcpBuffer []byte + bytesReceived atomic.Uint64 + bytesSent atomic.Uint64 // in chOptions chan optionsReq @@ -1008,11 +1012,13 @@ func (c *Client) trySwitchingProtocol() error { prevBaseURL := c.baseURL prevMedias := c.setuppedMedias + prevProfile := c.setuppedTransport.Profile c.reset() c.setuppedTransport = &SessionTransport{ Protocol: ProtocolTCP, + Profile: prevProfile, } // some Hikvision cameras require a describe before a setup @@ -1792,8 +1798,19 @@ func (c *Client) doSetup( } } + var srtpOutMki []byte + + if c.axisClientManagedKeys { + srtpOutMki = make([]byte, mkiLength) + _, err = rand.Read(srtpOutMki) + if err != nil { + return nil, err + } + } + srtpOutCtx = &wrappedSRTPContext{ key: srtpOutKey, + mki: srtpOutMki, ssrcs: ssrcsMapToList(localSSRCs), } err = srtpOutCtx.initialize() @@ -1802,7 +1819,7 @@ func (c *Client) doSetup( } var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx) + mikeyMsg, err = contextToMikey(srtpOutCtx) if err != nil { return nil, err } @@ -1831,7 +1848,8 @@ func (c *Client) doSetup( if res.StatusCode != base.StatusOK { // switch transport automatically if res.StatusCode == base.StatusUnsupportedTransport && - c.setuppedTransport == nil && c.Protocol == nil { + c.setuppedTransport == nil && + c.Protocol == nil { c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) c.setuppedTransport = &SessionTransport{ Protocol: ProtocolTCP, @@ -1841,6 +1859,13 @@ func (c *Client) doSetup( return c.doSetup(baseURL, medi, 0, 0) } + if res.StatusCode == base.StatusKeyManagementFailure && + strings.ToLower(res.StatusMessage) == "key management failure" && + !c.axisClientManagedKeys { + c.axisClientManagedKeys = true + return c.doSetup(baseURL, medi, rtpPort, rtcpPort) + } + return nil, liberrors.ErrClientBadStatusCode{Code: res.StatusCode, Message: res.StatusMessage} } @@ -2033,34 +2058,45 @@ func (c *Client) doSetup( } if isSecure(th.Profile) { - var mikeyMsg *mikey.Message - - // extract key-mgmt from (in order of priority): - // - response - // - media SDP attributes - // - session SDP attributes - switch { - case res.Header["KeyMgmt"] != nil: - var keyMgmt headers.KeyMgmt - err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) + if c.axisClientManagedKeys { + srtpInCtx = &wrappedSRTPContext{ + key: srtpOutCtx.key, + mki: srtpOutCtx.mki, + } + err = srtpInCtx.initialize() if err != nil { return nil, err } - mikeyMsg = keyMgmt.MikeyMessage + } else { + var mikeyMsg *mikey.Message - case medi.KeyMgmtMikey != nil: - mikeyMsg = medi.KeyMgmtMikey + // extract key-mgmt from (in order of priority): + // - response + // - media SDP attributes + // - session SDP attributes + switch { + case res.Header["KeyMgmt"] != nil: + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) + if err != nil { + return nil, err + } + mikeyMsg = keyMgmt.MikeyMessage - case c.lastDescribeDesc.KeyMgmtMikey != nil: - mikeyMsg = c.lastDescribeDesc.KeyMgmtMikey + case medi.KeyMgmtMikey != nil: + mikeyMsg = medi.KeyMgmtMikey - default: - return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way") - } + case c.lastDescribeDesc.KeyMgmtMikey != nil: + mikeyMsg = c.lastDescribeDesc.KeyMgmtMikey - srtpInCtx, err = mikeyToContext(mikeyMsg) - if err != nil { - return nil, err + default: + return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way") + } + + srtpInCtx, err = mikeyToContext(mikeyMsg) + if err != nil { + return nil, err + } } } diff --git a/client_play_test.go b/client_play_test.go index 4b87e49a..aea43552 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -417,7 +417,7 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err2) var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx[i]) + mikeyMsg, err = contextToMikey(srtpOutCtx[i]) require.NoError(t, err) var enc base.HeaderValue @@ -747,7 +747,7 @@ func TestClientPlaySRTPVariants(t *testing.T) { err2 = srtpOutCtx.initialize() require.NoError(t, err2) - mikeyMsg, err2 := mikeyGenerate(srtpOutCtx) + mikeyMsg, err2 := contextToMikey(srtpOutCtx) require.NoError(t, err2) enc, err2 := mikeyMsg.Marshal() @@ -892,7 +892,15 @@ func TestClientPlaySRTPVariants(t *testing.T) { packetRecv := make(chan struct{}) - c.OnPacketRTPAny(func(_ *description.Media, _ format.Format, _ *rtp.Packet) { + c.OnPacketRTPAny(func(_ *description.Media, _ format.Format, pkt *rtp.Packet) { + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SSRC: pkt.SSRC, + }, + Payload: testRTPPacket.Payload, + }, pkt) close(packetRecv) }) @@ -904,6 +912,253 @@ func TestClientPlaySRTPVariants(t *testing.T) { } } +func TestClientPlayAxisClientManagedKeys(t *testing.T) { + cert, err := tls.X509KeyPair(serverCert, serverKey) + require.NoError(t, err) + + l, err := tls.Listen("tcp", "127.0.0.1:8554", &tls.Config{Certificates: []tls.Certificate{cert}}) + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + + rtcpRecv := make(chan struct{}) + + go func() { + defer close(serverDone) + + nconn, err2 := l.Accept() + require.NoError(t, err2) + defer nconn.Close() + conn := conn.NewConn(bufio.NewReader(nconn), nconn) + + req, err2 := conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Describe, req.Method) + + sdp := "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n" + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsps://127.0.0.1:8554/stream/"}, + }, + Body: []byte(sdp), + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Setup, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusKeyManagementFailure, + StatusMessage: "Key management failure", + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err2 = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err2) + require.Equal(t, (*headers.TransportMode)(nil), inTH.Mode) + + var reqKeyMgmt headers.KeyMgmt + err2 = reqKeyMgmt.Unmarshal(req.Header["KeyMgmt"]) + require.NoError(t, err2) + require.Equal(t, req.URL.String(), reqKeyMgmt.URL) + + reqKEMAC, ok := mikeyGetPayload[*mikey.PayloadKEMAC](reqKeyMgmt.MikeyMessage) + require.True(t, ok) + require.Len(t, reqKEMAC.SubPayloads, 1) + require.Len(t, reqKEMAC.SubPayloads[0].KeyData, srtpKeyLength) + require.Equal(t, mikey.SubPayloadKeyDataKVSPI, reqKEMAC.SubPayloads[0].KV) + require.Len(t, reqKEMAC.SubPayloads[0].SPI, mkiLength) + + serverOutCtx, err2 := mikeyToContext(reqKeyMgmt.MikeyMessage) + require.NoError(t, err2) + + responseKey := make([]byte, srtpKeyLength) + _, err2 = rand.Read(responseKey) + require.NoError(t, err2) + + responseMKI := make([]byte, mkiLength) + _, err2 = rand.Read(responseMKI) + require.NoError(t, err2) + + responseCtx := &wrappedSRTPContext{ + key: responseKey, + mki: responseMKI, + ssrcs: []uint32{845234432}, + } + err2 = responseCtx.initialize() + require.NoError(t, err2) + + responseMikey, err2 := contextToMikey(responseCtx) + require.NoError(t, err2) + + responseKEMAC, ok := mikeyGetPayload[*mikey.PayloadKEMAC](responseMikey) + require.True(t, ok) + require.Len(t, responseKEMAC.SubPayloads, 1) + require.NotEqual(t, reqKEMAC.SubPayloads[0].KeyData, responseKEMAC.SubPayloads[0].KeyData) + require.NotEqual(t, reqKEMAC.SubPayloads[0].SPI, responseKEMAC.SubPayloads[0].SPI) + + th := headers.Transport{ + Profile: headers.TransportProfileSAVP, + } + th.Delivery = ptrOf(headers.TransportDeliveryUnicast) + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{34556, 34557} + + responseKeyMgmt, err2 := headers.KeyMgmt{ + URL: req.URL.String(), + MikeyMessage: responseMikey, + }.Marshal() + require.NoError(t, err2) + + l1, err2 := net.ListenPacket( + "udp", net.JoinHostPort("127.0.0.1", strconv.FormatInt(int64(th.ServerPorts[0]), 10))) + require.NoError(t, err2) + defer l1.Close() + + l2, err2 := net.ListenPacket( + "udp", net.JoinHostPort("127.0.0.1", strconv.FormatInt(int64(th.ServerPorts[1]), 10))) + require.NoError(t, err2) + defer l2.Close() + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + "KeyMgmt": responseKeyMgmt, + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Play, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + buf := make([]byte, 2048) + _, _, err2 = l2.ReadFrom(buf) + require.NoError(t, err2) + + encr, err2 := serverOutCtx.encryptRTP(buf, testRTPPacketMarshaled, nil) + require.NoError(t, err2) + + _, err2 = l1.WriteTo(encr, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: inTH.ClientPorts[0], + }) + require.NoError(t, err2) + + buf = make([]byte, 2048) + n, _, err2 := l2.ReadFrom(buf) + require.NoError(t, err2) + + decr, err2 := serverOutCtx.decryptRTCP(buf[:n], buf[:n], nil) + require.NoError(t, err2) + + packets, err2 := rtcp.Unmarshal(decr) + require.NoError(t, err2) + + rr, ok := packets[0].(*rtcp.ReceiverReport) + require.True(t, ok) + + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: rr.SSRC, + Reports: []rtcp.ReceptionReport{{ + SSRC: rr.Reports[0].SSRC, + LastSequenceNumber: uint32(testRTPPacket.SequenceNumber), + Delay: rr.Reports[0].Delay, + }}, + ProfileExtensions: []uint8{}, + }, rr) + + close(rtcpRecv) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Teardown, req.Method) + }() + + u, err := base.ParseURL("rtsps://127.0.0.1:8554/stream") + require.NoError(t, err) + + c := Client{ + Scheme: u.Scheme, + Host: u.Host, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + receiverReportPeriod: 500 * time.Millisecond, + Protocol: ptrOf(ProtocolUDP), + } + + err = c.Start() + require.NoError(t, err) + defer c.Close() + + sd, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAll(sd.BaseURL, sd.Medias) + require.NoError(t, err) + + packetRecv := make(chan struct{}) + + c.OnPacketRTPAny(func(_ *description.Media, _ format.Format, pkt *rtp.Packet) { + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SSRC: pkt.SSRC, + }, + Payload: pkt.Payload, + }, pkt) + close(packetRecv) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + + <-packetRecv + + <-rtcpRecv +} + func TestClientPlayPartial(t *testing.T) { listenIP := multicastCapableIP(t) l, err := net.Listen("tcp", listenIP+":8554") @@ -1310,10 +1565,12 @@ func TestClientPlayAnyPort(t *testing.T) { var n int n, _, err2 = l1b.ReadFrom(buf) require.NoError(t, err2) + var packets []rtcp.Packet packets, err2 = rtcp.Unmarshal(buf[:n]) require.NoError(t, err2) require.Equal(t, &testRTCPPacket, packets[0]) + close(serverRecv) } @@ -2571,10 +2828,13 @@ func TestClientPlayRTCPReport(t *testing.T) { buf = make([]byte, 2048) n, _, err2 := l2.ReadFrom(buf) require.NoError(t, err2) + packets, err2 := rtcp.Unmarshal(buf[:n]) require.NoError(t, err2) + rr, ok := packets[0].(*rtcp.ReceiverReport) require.True(t, ok) + require.Equal(t, &rtcp.ReceiverReport{ SSRC: rr.SSRC, Reports: []rtcp.ReceptionReport{ @@ -4061,7 +4321,7 @@ func TestClientPlayDifferentSSRCs(t *testing.T) { require.NoError(t, err2) var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx) + mikeyMsg, err = contextToMikey(srtpOutCtx) require.NoError(t, err) var enc base.HeaderValue diff --git a/client_record_test.go b/client_record_test.go index a8b85c14..7b03982b 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -289,7 +289,7 @@ func TestClientRecord(t *testing.T) { require.NoError(t, err2) var mikeyMsg *mikey.Message - mikeyMsg, err2 = mikeyGenerate(srtpOutCtx) + mikeyMsg, err2 = contextToMikey(srtpOutCtx) require.NoError(t, err2) var enc base.HeaderValue diff --git a/client_test.go b/client_test.go index 3430ea3b..7aeb566c 100644 --- a/client_test.go +++ b/client_test.go @@ -684,9 +684,9 @@ func TestClientTunnelHTTP(t *testing.T) { var scheme string if ca == "http" { - scheme = "rtsp" + scheme = schemeRTSP } else { - scheme = "rtsps" + scheme = schemeRTSPS } serverDone := make(chan struct{}) @@ -1033,9 +1033,9 @@ func TestClientTunnelWebSocket(t *testing.T) { t.Run(ca, func(t *testing.T) { var scheme string if ca == "ws" { - scheme = "rtsp" + scheme = schemeRTSP } else { - scheme = "rtsps" + scheme = schemeRTSPS } s := &http.Server{ diff --git a/constants.go b/constants.go index 58eb6889..78698469 100644 --- a/constants.go +++ b/constants.go @@ -16,4 +16,7 @@ const ( // 10 (HMAC SHA1 authentication tag) + 4 (sequence number) srtcpOverhead = 14 + + // Axis requires a 4-byte MKI + mkiLength = 4 ) diff --git a/pkg/base/response.go b/pkg/base/response.go index 5cae7189..a0094e2d 100644 --- a/pkg/base/response.go +++ b/pkg/base/response.go @@ -44,6 +44,7 @@ const ( StatusUnsupportedTransport StatusCode = 461 StatusDestinationUnreachable StatusCode = 462 StatusDestinationProhibited StatusCode = 463 + StatusKeyManagementFailure StatusCode = 463 StatusDataTransportNotReadyYet StatusCode = 464 StatusNotificationReasonUnknown StatusCode = 465 StatusKeyManagementError StatusCode = 466 diff --git a/pkg/mikey/message_test.go b/pkg/mikey/message_test.go index 584a5ce9..1dde80ca 100644 --- a/pkg/mikey/message_test.go +++ b/pkg/mikey/message_test.go @@ -371,6 +371,42 @@ var cases = []struct { }, }, }, + { + "KEMAC key data with SPI", + []byte{ + 0x01, 0x00, 0x01, 0x00, 0x12, 0x34, 0x56, 0x78, + 0x01, 0x00, + 0x03, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + 0x00, 0x00, 0x00, 0x0a, + 0x00, 0x21, 0x00, 0x03, 0xaa, 0xbb, 0xcc, 0x02, 0x11, 0x22, + 0x00, + }, + Message{ + Header: Header{ + Version: 1, + CSBID: 0x12345678, + CSIDMapInfo: []SRTPIDEntry{ + { + PolicyNo: 0x03, + SSRC: 0x11223344, + ROC: 0x55667788, + }, + }, + }, + Payloads: []Payload{ + &PayloadKEMAC{ + SubPayloads: []*SubPayloadKeyData{ + { + Type: SubPayloadKeyDataTypeTEK, + KV: SubPayloadKeyDataKVSPI, + KeyData: []byte{0xaa, 0xbb, 0xcc}, + SPI: []byte{0x11, 0x22}, + }, + }, + }, + }, + }, + }, } func TestUnmarshal(t *testing.T) { diff --git a/pkg/mikey/sub_payload_key_data.go b/pkg/mikey/sub_payload_key_data.go index 9a6f0517..4bf582cb 100644 --- a/pkg/mikey/sub_payload_key_data.go +++ b/pkg/mikey/sub_payload_key_data.go @@ -2,19 +2,29 @@ package mikey import "fmt" -// SubPayloadKeyDataKeyType is a data key type. -type SubPayloadKeyDataKeyType uint8 +// SubPayloadKeyDataType is a key type. +type SubPayloadKeyDataType uint8 // RFC3830, table 6.13.a const ( - SubPayloadKeyDataKeyTypeTEK SubPayloadKeyDataKeyType = 2 + SubPayloadKeyDataTypeTEK SubPayloadKeyDataType = 2 +) + +// SubPayloadKeyDataKV is a KV (key validity) value. +type SubPayloadKeyDataKV uint8 + +// RFC3830, table 6.13.b +const ( + SubPayloadKeyDataKVNull SubPayloadKeyDataKV = 0 + SubPayloadKeyDataKVSPI SubPayloadKeyDataKV = 1 ) // SubPayloadKeyData is a key data sub-payload. type SubPayloadKeyData struct { - Type SubPayloadKeyDataKeyType - KV uint8 + Type SubPayloadKeyDataType + KV SubPayloadKeyDataKV KeyData []byte + SPI []byte } func (p *SubPayloadKeyData) unmarshal(buf []byte) (int, error) { @@ -23,15 +33,15 @@ func (p *SubPayloadKeyData) unmarshal(buf []byte) (int, error) { } n := 1 - p.Type = SubPayloadKeyDataKeyType(buf[n] >> 4) - p.KV = buf[n] & 0b00001111 + p.Type = SubPayloadKeyDataType(buf[n] >> 4) + p.KV = SubPayloadKeyDataKV(buf[n] & 0b1111) n++ - if p.Type != SubPayloadKeyDataKeyTypeTEK { + if p.Type != SubPayloadKeyDataTypeTEK { return 0, fmt.Errorf("unsupported key type: %v", p.Type) } - if p.KV != 0 { + if p.KV != SubPayloadKeyDataKVNull && p.KV != SubPayloadKeyDataKVSPI { return 0, fmt.Errorf("unsupported KV: %v", p.KV) } @@ -45,15 +55,35 @@ func (p *SubPayloadKeyData) unmarshal(buf []byte) (int, error) { p.KeyData = buf[n : n+keyDataLen] n += keyDataLen + if p.KV == SubPayloadKeyDataKVSPI { + if len(buf[n:]) < 1 { + return 0, fmt.Errorf("buffer too short") + } + + spiLen := int(buf[n]) + n++ + + if len(buf[n:]) < spiLen { + return 0, fmt.Errorf("buffer too short") + } + + p.SPI = buf[n : n+spiLen] + n += spiLen + } + return n, nil } func (p SubPayloadKeyData) marshalSize() int { - return 4 + len(p.KeyData) + n := 4 + len(p.KeyData) + if p.KV == SubPayloadKeyDataKVSPI { + n += 1 + len(p.SPI) + } + return n } func (p SubPayloadKeyData) marshalTo(buf []byte) (int, error) { - buf[1] = byte(p.Type)<<4 | p.KV + buf[1] = byte(p.Type)<<4 | byte(p.KV) keyDataLen := len(p.KeyData) buf[2] = byte(keyDataLen >> 8) @@ -62,5 +92,11 @@ func (p SubPayloadKeyData) marshalTo(buf []byte) (int, error) { n += copy(buf[n:], p.KeyData) + if p.KV == SubPayloadKeyDataKVSPI { + buf[n] = uint8(len(p.SPI)) + n++ + n += copy(buf[n:], p.SPI) + } + return n, nil } diff --git a/server_conn.go b/server_conn.go index a7475e3d..2eec4b67 100644 --- a/server_conn.go +++ b/server_conn.go @@ -72,7 +72,7 @@ func prepareForDescribe( sm := medias[medi] var err error - keyMgmtMikey, err = mikeyGenerate(sm.srtpOutCtx) + keyMgmtMikey, err = contextToMikey(sm.srtpOutCtx) if err != nil { return nil, err } diff --git a/server_play_test.go b/server_play_test.go index 7030c54b..d1c38c07 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -858,7 +858,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx) + mikeyMsg, err = contextToMikey(srtpOutCtx) require.NoError(t, err) var enc base.HeaderValue diff --git a/server_record_test.go b/server_record_test.go index fd8d4c19..6af259d6 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -811,7 +811,7 @@ func TestServerRecord(t *testing.T) { require.NoError(t, err) var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx[i]) + mikeyMsg, err = contextToMikey(srtpOutCtx[i]) require.NoError(t, err) var enc base.HeaderValue @@ -1944,7 +1944,7 @@ func TestServerRecordDifferentSSRCs(t *testing.T) { require.NoError(t, err) var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(srtpOutCtx) + mikeyMsg, err = contextToMikey(srtpOutCtx) require.NoError(t, err) var enc base.HeaderValue diff --git a/server_session.go b/server_session.go index 4242c514..be6a64e0 100644 --- a/server_session.go +++ b/server_session.go @@ -1168,7 +1168,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( if isSecure(inTH.Profile) { var mk *mikey.Message - mk, err = mikeyGenerate(sm.srtpOutCtx) + mk, err = contextToMikey(sm.srtpOutCtx) if err != nil { return &base.Response{ StatusCode: base.StatusInternalServerError, diff --git a/wrapped_srtp_context.go b/wrapped_srtp_context.go index 197844d7..37083deb 100644 --- a/wrapped_srtp_context.go +++ b/wrapped_srtp_context.go @@ -103,6 +103,7 @@ func mikeyToContext(mikeyMsg *mikey.Message) (*wrappedSRTPContext, error) { srtpCtx := &wrappedSRTPContext{ key: kemacPayload.SubPayloads[0].KeyData, + mki: kemacPayload.SubPayloads[0].SPI, ssrcs: ssrcs, startROCs: startROCs, } @@ -114,7 +115,7 @@ func mikeyToContext(mikeyMsg *mikey.Message) (*wrappedSRTPContext, error) { return srtpCtx, nil } -func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) { +func contextToMikey(ctx *wrappedSRTPContext) (*mikey.Message, error) { csbID, err := randUint32() if err != nil { return nil, err @@ -145,21 +146,30 @@ func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) { return nil, err } - keyLen, err := ctx.profile.KeyLen() + profile := srtp.ProtectionProfileAes128CmHmacSha1_80 + + keyLen, err := profile.KeyLen() if err != nil { return nil, err } - authKeyLen, err := ctx.profile.AuthKeyLen() + authKeyLen, err := profile.AuthKeyLen() if err != nil { return nil, err } - authTagLen, err := ctx.profile.AuthTagRTPLen() + authTagLen, err := profile.AuthTagRTPLen() if err != nil { return nil, err } + var kv mikey.SubPayloadKeyDataKV + if len(ctx.mki) != 0 { + kv = mikey.SubPayloadKeyDataKVSPI + } else { + kv = mikey.SubPayloadKeyDataKVNull + } + msg.Payloads = []mikey.Payload{ &mikey.PayloadT{ TSType: 0, @@ -207,8 +217,10 @@ func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) { &mikey.PayloadKEMAC{ SubPayloads: []*mikey.SubPayloadKeyData{ { - Type: mikey.SubPayloadKeyDataKeyTypeTEK, + Type: mikey.SubPayloadKeyDataTypeTEK, + KV: kv, KeyData: ctx.key, + SPI: ctx.mki, }, }, }, @@ -223,19 +235,22 @@ func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) { // - mutex around Encrypt*, ROC* type wrappedSRTPContext struct { key []byte + mki []byte ssrcs []uint32 startROCs []uint32 - profile srtp.ProtectionProfile - w *srtp.Context - mutex sync.RWMutex + w *srtp.Context + mutex sync.RWMutex } func (ctx *wrappedSRTPContext) initialize() error { - ctx.profile = srtp.ProtectionProfileAes128CmHmacSha1_80 + opts := make([]srtp.ContextOption, 0, 1) + if len(ctx.mki) != 0 { + opts = append(opts, srtp.MasterKeyIndicator(ctx.mki)) + } var err error - ctx.w, err = srtp.CreateContext(ctx.key[:16], ctx.key[16:], ctx.profile) + ctx.w, err = srtp.CreateContext(ctx.key[:16], ctx.key[16:], srtp.ProtectionProfileAes128CmHmacSha1_80, opts...) if err != nil { return err }