diff --git a/api/openapi.yaml b/api/openapi.yaml index 44ecfba6..66b73db3 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -763,6 +763,8 @@ components: type: string query: type: string + user: + type: string bytesReceived: type: integer format: uint64 @@ -838,6 +840,8 @@ components: type: string query: type: string + user: + type: string transport: type: string nullable: true @@ -911,6 +915,8 @@ components: type: string query: type: string + user: + type: string packetsSent: type: integer format: uint64 @@ -1185,6 +1191,8 @@ components: type: string query: type: string + user: + type: string bytesReceived: type: integer format: uint64 diff --git a/internal/api/api.go b/internal/api/api.go index e44c16a7..a1b2dcbd 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -66,7 +66,7 @@ func recordingsOfPath( } type apiAuthManager interface { - Authenticate(req *auth.Request) *auth.Error + Authenticate(req *auth.Request) (string, *auth.Error) RefreshJWTJWKS() } @@ -250,7 +250,7 @@ func (a *API) middlewareAuth(ctx *gin.Context) { IP: net.ParseIP(ctx.ClientIP()), } - err := a.AuthManager.Authenticate(req) + _, err := a.AuthManager.Authenticate(req) if err != nil { if err.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) diff --git a/internal/api/api_config_global_test.go b/internal/api/api_config_global_test.go index a6c5cbe1..e23fe54a 100644 --- a/internal/api/api_config_global_test.go +++ b/internal/api/api_config_global_test.go @@ -23,12 +23,12 @@ func TestConfigGlobalGet(t *testing.T) { WriteTimeout: conf.Duration(10 * time.Second), Conf: cnf, AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { require.Equal(t, conf.AuthActionAPI, req.Action) require.Equal(t, "myuser", req.Credentials.User) require.Equal(t, "mypass", req.Credentials.Pass) checked = true - return nil + return req.Credentials.User, nil }, }, Parent: &testParent{}, diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 454e53af..70074618 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -162,8 +162,8 @@ func TestAuthJWKSRefresh(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(_ *auth.Request) *auth.Error { - return nil + AuthenticateImpl: func(_ *auth.Request) (string, *auth.Error) { + return "", nil }, RefreshJWTJWKSImpl: func() { ok = true @@ -197,11 +197,11 @@ func TestAuthError(t *testing.T) { WriteTimeout: conf.Duration(10 * time.Second), Conf: cnf, AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { if req.Credentials.User == "" { - return &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } - return &auth.Error{Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{Wrapped: fmt.Errorf("auth error")} }, }, Parent: &testParent{ diff --git a/internal/auth/manager.go b/internal/auth/manager.go index 71a67ab5..5dfb4da2 100644 --- a/internal/auth/manager.go +++ b/internal/auth/manager.go @@ -92,41 +92,43 @@ func (m *Manager) ReloadInternalUsers(u []conf.AuthInternalUser) { } // Authenticate authenticates a request. -func (m *Manager) Authenticate(req *Request) *Error { +// It returns the user name. +func (m *Manager) Authenticate(req *Request) (string, *Error) { + var user string var err error switch m.Method { case conf.AuthMethodInternal: - err = m.authenticateInternal(req) + user, err = m.authenticateInternal(req) case conf.AuthMethodHTTP: - err = m.authenticateHTTP(req) + user, err = m.authenticateHTTP(req) default: - err = m.authenticateJWT(req) + user, err = m.authenticateJWT(req) } if err != nil { - return &Error{ + return "", &Error{ Wrapped: err, AskCredentials: (req.Credentials.User == "" && req.Credentials.Pass == "" && req.Credentials.Token == ""), } } - return nil + return user, nil } -func (m *Manager) authenticateInternal(req *Request) error { +func (m *Manager) authenticateInternal(req *Request) (string, error) { m.mutex.RLock() defer m.mutex.RUnlock() for _, u := range m.InternalUsers { if ok := m.authenticateWithUser(req, &u); ok { - return nil + return req.Credentials.User, nil } } - return fmt.Errorf("authentication failed") + return "", fmt.Errorf("authentication failed") } func (m *Manager) authenticateWithUser( @@ -156,9 +158,9 @@ func (m *Manager) authenticateWithUser( return true } -func (m *Manager) authenticateHTTP(req *Request) error { +func (m *Manager) authenticateHTTP(req *Request) (string, error) { if matchesPermission(m.HTTPExclude, req) { - return nil + return "", nil } enc, _ := json.Marshal(struct { @@ -185,7 +187,7 @@ func (m *Manager) authenticateHTTP(req *Request) error { u, err := url.Parse(m.HTTPAddress) if err != nil { - return err + return "", err } tr := &http.Transport{ @@ -200,29 +202,29 @@ func (m *Manager) authenticateHTTP(req *Request) error { res, err := httpClient.Post(m.HTTPAddress, "application/json", bytes.NewReader(enc)) if err != nil { - return fmt.Errorf("HTTP request failed: %w", err) + return "", fmt.Errorf("HTTP request failed: %w", err) } defer res.Body.Close() if res.StatusCode < 200 || res.StatusCode > 299 { if resBody, err2 := io.ReadAll(res.Body); err2 == nil && len(resBody) != 0 { - return fmt.Errorf("server replied with code %d: %s", res.StatusCode, string(resBody)) + return "", fmt.Errorf("server replied with code %d: %s", res.StatusCode, string(resBody)) } - return fmt.Errorf("server replied with code %d", res.StatusCode) + return "", fmt.Errorf("server replied with code %d", res.StatusCode) } - return nil + return req.Credentials.User, nil } -func (m *Manager) authenticateJWT(req *Request) error { +func (m *Manager) authenticateJWT(req *Request) (string, error) { if matchesPermission(m.JWTExclude, req) { - return nil + return "", nil } keyfunc, err := m.pullJWTJWKS() if err != nil { - return err + return "", err } var encodedJWT string @@ -239,17 +241,17 @@ func (m *Manager) authenticateJWT(req *Request) error { var v url.Values v, err = url.ParseQuery(req.Query) if err != nil { - return err + return "", err } if len(v["jwt"]) != 1 || len(v["jwt"][0]) == 0 { - return fmt.Errorf("JWT not provided") + return "", fmt.Errorf("JWT not provided") } encodedJWT = v["jwt"][0] default: - return fmt.Errorf("JWT not provided") + return "", fmt.Errorf("JWT not provided") } var opts []jwt.ParserOption @@ -264,14 +266,14 @@ func (m *Manager) authenticateJWT(req *Request) error { cc.permissionsKey = m.JWTClaimKey _, err = jwt.ParseWithClaims(encodedJWT, &cc, keyfunc, opts...) if err != nil { - return err + return "", err } if !matchesPermission(cc.permissions, req) { - return fmt.Errorf("user doesn't have permission to perform action") + return "", fmt.Errorf("user doesn't have permission to perform action") } - return nil + return cc.Subject, nil } func (m *Manager) pullJWTJWKS() (jwt.Keyfunc, error) { diff --git a/internal/auth/manager_test.go b/internal/auth/manager_test.go index 651d31d9..0816b885 100644 --- a/internal/auth/manager_test.go +++ b/internal/auth/manager_test.go @@ -195,7 +195,7 @@ func TestAuthInternal(t *testing.T) { } // first request with empty credentials - err := m.Authenticate(&Request{ + _, err := m.Authenticate(&Request{ Action: req.Action, Path: req.Path, Credentials: &Credentials{}, @@ -207,9 +207,10 @@ func TestAuthInternal(t *testing.T) { }, err) // second request - err = m.Authenticate(req) + user, err := m.Authenticate(req) if outcome == "ok" { require.Nil(t, err) + require.Equal(t, "testuser", user) } else { require.EqualError(t, err.Wrapped, "authentication failed") require.False(t, err.AskCredentials) @@ -238,19 +239,23 @@ func TestAuthInternalCustomVerifyFunc(t *testing.T) { } req1 := &Request{ - Action: conf.AuthActionPublish, - Path: "mypath", - Credentials: &Credentials{}, - IP: net.ParseIP("127.1.1.1"), + Action: conf.AuthActionPublish, + Path: "mypath", + Credentials: &Credentials{ + User: "myuser", + }, + IP: net.ParseIP("127.1.1.1"), CustomVerifyFunc: func(expectedUser, expectedPass string) bool { require.Equal(t, "myuser", expectedUser) require.Equal(t, "mypass", expectedPass) return (ca == "ok") }, } - err := m.Authenticate(req1) + + user, err := m.Authenticate(req1) if ca == "ok" { require.Nil(t, err) + require.Equal(t, "myuser", user) } else { require.EqualError(t, err.Wrapped, "authentication failed") } @@ -339,7 +344,7 @@ func TestAuthHTTP(t *testing.T) { } // first request with empty credentials - err2 := m.Authenticate(&Request{ + _, err2 := m.Authenticate(&Request{ Action: req.Action, Path: req.Path, Credentials: &Credentials{}, @@ -351,9 +356,10 @@ func TestAuthHTTP(t *testing.T) { }, err2) // second request - err2 = m.Authenticate(req) + user, err2 := m.Authenticate(req) if outcome == "ok" { require.Nil(t, err2) + require.Equal(t, "testpublisher", user) } else { require.EqualError(t, err2.Wrapped, "server replied with code 400") require.False(t, err2.AskCredentials) @@ -405,7 +411,7 @@ func TestAuthHTTPFingerprint(t *testing.T) { HTTPFingerprint: "33949e05fffb5ff3e8aa16f8213a6251b4d9363804ba53233c4da9a46d6f2739", } - err2 := m.Authenticate(&Request{ + user, err2 := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "teststream", Protocol: ProtocolRTSP, @@ -416,6 +422,7 @@ func TestAuthHTTPFingerprint(t *testing.T) { IP: net.ParseIP("127.0.0.1"), }) require.Nil(t, err2) + require.Equal(t, "testuser", user) } func TestAuthHTTPExclude(t *testing.T) { @@ -427,7 +434,7 @@ func TestAuthHTTPExclude(t *testing.T) { }}, } - err := m.Authenticate(&Request{ + user, err := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "teststream", Query: "param=value", @@ -439,6 +446,7 @@ func TestAuthHTTPExclude(t *testing.T) { IP: net.ParseIP("127.0.0.1"), }) require.Nil(t, err) + require.Equal(t, "", user) } func TestAuthJWT(t *testing.T) { @@ -568,7 +576,7 @@ func TestAuthJWT(t *testing.T) { } // first request with empty credentials - err2 := m.Authenticate(&Request{ + _, err2 := m.Authenticate(&Request{ Action: req.Action, Path: req.Path, Credentials: &Credentials{}, @@ -580,8 +588,9 @@ func TestAuthJWT(t *testing.T) { }, err2) // second request - err2 = m.Authenticate(req) + user, err2 := m.Authenticate(req) require.Nil(t, err2) + require.Equal(t, "somebody", user) }) } } @@ -596,7 +605,7 @@ func TestAuthJWTExclude(t *testing.T) { }}, } - err := m.Authenticate(&Request{ + user, err := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "teststream", Query: "param=value", @@ -604,6 +613,7 @@ func TestAuthJWTExclude(t *testing.T) { IP: net.ParseIP("127.0.0.1"), }) require.Nil(t, err) + require.Equal(t, "", user) } func TestAuthJWTIssuer(t *testing.T) { @@ -640,32 +650,6 @@ func TestAuthJWTIssuer(t *testing.T) { go httpServ.Serve(ln) defer httpServ.Shutdown(context.Background()) - signToken := func(issuer string) string { - type customClaims struct { - jwt.RegisteredClaims - MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` - } - - claims := customClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Issuer: issuer, - }, - MediaMTXPermissions: []conf.AuthInternalUserPermission{{ - Action: conf.AuthActionPublish, - Path: "mypath", - }}, - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header[jwkset.HeaderKID] = "test-key-id" - ss, err2 := token.SignedString(key) - require.NoError(t, err2) - return ss - } - for _, ca := range []struct { name string jwtIssuer string @@ -686,6 +670,32 @@ func TestAuthJWTIssuer(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { + signToken := func(issuer string) string { + type customClaims struct { + jwt.RegisteredClaims + MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` + } + + claims := customClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: issuer, + }, + MediaMTXPermissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header[jwkset.HeaderKID] = "test-key-id" + var ss string + ss, err = token.SignedString(key) + require.NoError(t, err) + return ss + } ss := signToken(ca.tokenIss) m := Manager{ @@ -695,7 +705,7 @@ func TestAuthJWTIssuer(t *testing.T) { JWTIssuer: ca.jwtIssuer, } - err2 := m.Authenticate(&Request{ + _, err := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "mypath", Protocol: ProtocolRTSP, @@ -706,9 +716,9 @@ func TestAuthJWTIssuer(t *testing.T) { }) if ca.expectErr { - require.NotNil(t, err2) + require.NotNil(t, err) } else { - require.Nil(t, err2) + require.Nil(t, err) } }) } @@ -748,32 +758,6 @@ func TestAuthJWTAudience(t *testing.T) { go httpServ.Serve(ln) defer httpServ.Shutdown(context.Background()) - signToken := func(audience jwt.ClaimStrings) string { - type customClaims struct { - jwt.RegisteredClaims - MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` - } - - claims := customClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Audience: audience, - }, - MediaMTXPermissions: []conf.AuthInternalUserPermission{{ - Action: conf.AuthActionPublish, - Path: "mypath", - }}, - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header[jwkset.HeaderKID] = "test-key-id" - ss, err2 := token.SignedString(key) - require.NoError(t, err2) - return ss - } - for _, ca := range []struct { name string jwtAudience string @@ -800,6 +784,32 @@ func TestAuthJWTAudience(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { + signToken := func(audience jwt.ClaimStrings) string { + type customClaims struct { + jwt.RegisteredClaims + MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"` + } + + claims := customClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Audience: audience, + }, + MediaMTXPermissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header[jwkset.HeaderKID] = "test-key-id" + var ss string + ss, err = token.SignedString(key) + require.NoError(t, err) + return ss + } ss := signToken(ca.tokenAud) m := Manager{ @@ -809,7 +819,7 @@ func TestAuthJWTAudience(t *testing.T) { JWTAudience: ca.jwtAudience, } - err2 := m.Authenticate(&Request{ + _, err := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "mypath", Protocol: ProtocolRTSP, @@ -820,9 +830,9 @@ func TestAuthJWTAudience(t *testing.T) { }) if ca.expectErr { - require.NotNil(t, err2) + require.NotNil(t, err) } else { - require.Nil(t, err2) + require.Nil(t, err) } }) } @@ -900,7 +910,7 @@ func TestAuthJWTRefresh(t *testing.T) { ss, err = token.SignedString(key) require.NoError(t, err) - err2 := m.Authenticate(&Request{ + user, err2 := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "mypath", Query: "param=value", @@ -911,6 +921,7 @@ func TestAuthJWTRefresh(t *testing.T) { IP: net.ParseIP("127.0.0.1"), }) require.Nil(t, err2) + require.Equal(t, "somebody", user) m.RefreshJWTJWKS() } @@ -987,7 +998,7 @@ func TestAuthJWTFingerprint(t *testing.T) { JWTClaimKey: "my_permission_key", } - err2 := m.Authenticate(&Request{ + user, err2 := m.Authenticate(&Request{ Action: conf.AuthActionPublish, Path: "mypath", Protocol: ProtocolRTSP, @@ -997,4 +1008,5 @@ func TestAuthJWTFingerprint(t *testing.T) { IP: net.ParseIP("127.0.0.1"), }) require.Nil(t, err2) + require.Equal(t, "somebody", user) } diff --git a/internal/core/api_test.go b/internal/core/api_test.go index 6932cdf4..72acf564 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -645,6 +645,7 @@ func TestAPIProtocolListGet(t *testing.T) { "id": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["id"], "path": "mypath", "query": "key=val", + "user": "", "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "state": "publish", "transport": "UDP", @@ -691,6 +692,7 @@ func TestAPIProtocolListGet(t *testing.T) { "id": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["id"], "path": "mypath", "query": "key=val", + "user": "", "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "state": "publish", "transport": "UDP", @@ -720,6 +722,7 @@ func TestAPIProtocolListGet(t *testing.T) { "id": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["id"], "path": "mypath", "query": "key=val", + "user": "", "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "state": "publish", }, @@ -738,6 +741,7 @@ func TestAPIProtocolListGet(t *testing.T) { "id": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["id"], "path": "mypath", "query": "key=val", + "user": "", "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "state": "publish", }, @@ -775,6 +779,7 @@ func TestAPIProtocolListGet(t *testing.T) { "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "remoteCandidate": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteCandidate"], "state": "read", + "user": "", "rtcpPacketsReceived": float64(0), "rtcpPacketsSent": float64(2), "rtpPacketsJitter": float64(0), @@ -848,6 +853,7 @@ func TestAPIProtocolListGet(t *testing.T) { "query": "key=val", "remoteAddr": out1.(map[string]any)["items"].([]any)[0].(map[string]any)["remoteAddr"], "state": "publish", + "user": "", "usPacketsSendPeriod": float64(10.967254638671875), "usSndDuration": float64(0), }, diff --git a/internal/core/path.go b/internal/core/path.go index 5de44e5c..e5a77afc 100644 --- a/internal/core/path.go +++ b/internal/core/path.go @@ -33,7 +33,7 @@ type pathParent interface { setPathNotReady(*path) closePathIfIdle(*path) removePath(*path) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type pathOnDemandState int @@ -572,10 +572,7 @@ func (pa *path) doAddPublisher(req defs.PathAddPublisherReq) { pa.consumeOnHoldRequests() - req.Res <- defs.PathAddPublisherRes{ - Path: pa, - SubStream: subStream, - } + req.Res <- defs.PathAddPublisherRes{SubStream: subStream} } func (pa *path) doAddReader(req defs.PathAddReaderReq) { @@ -933,10 +930,7 @@ func (pa *path) executeRemovePublisher() { func (pa *path) addReaderPost(req defs.PathAddReaderReq) { if _, ok := pa.readers[req.Author]; ok { - req.Res <- defs.PathAddReaderRes{ - Path: pa, - Stream: pa.stream, - } + req.Res <- defs.PathAddReaderRes{Stream: pa.stream} return } @@ -961,10 +955,7 @@ func (pa *path) addReaderPost(req defs.PathAddReaderReq) { } } - req.Res <- defs.PathAddReaderRes{ - Path: pa, - Stream: pa.stream, - } + req.Res <- defs.PathAddReaderRes{Stream: pa.stream} } // reloadConf is called by pathManager. @@ -1022,13 +1013,13 @@ func (pa *path) describe(req defs.PathDescribeReq) defs.PathDescribeRes { } // addPublisher is called by a publisher through pathManager. -func (pa *path) addPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { +func (pa *path) addPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { select { case pa.chAddPublisher <- req: res := <-req.Res - return res.Path, res.SubStream, res.Err + return &res, res.Err case <-pa.ctx.Done(): - return nil, nil, fmt.Errorf("terminated") + return nil, fmt.Errorf("terminated") } } @@ -1043,13 +1034,13 @@ func (pa *path) RemovePublisher(req defs.PathRemovePublisherReq) { } // addReader is called by a reader through pathManager. -func (pa *path) addReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { +func (pa *path) addReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { select { case pa.chAddReader <- req: res := <-req.Res - return res.Path, res.Stream, res.Err + return &res, res.Err case <-pa.ctx.Done(): - return nil, nil, fmt.Errorf("terminated") + return nil, fmt.Errorf("terminated") } } diff --git a/internal/core/path_manager.go b/internal/core/path_manager.go index d9800957..459c006c 100644 --- a/internal/core/path_manager.go +++ b/internal/core/path_manager.go @@ -15,7 +15,6 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/metrics" "github.com/bluenviron/mediamtx/internal/servers/hls" - "github.com/bluenviron/mediamtx/internal/stream" ) func pathConfCanBeUpdated(oldPathConf *conf.Path, newPathConf *conf.Path) bool { @@ -62,7 +61,7 @@ type pathSetHLSServerReq struct { } type pathManagerAuthManager interface { - Authenticate(req *auth.Request) *auth.Error + Authenticate(req *auth.Request) (string, *auth.Error) } type pathManagerParent interface { @@ -326,13 +325,16 @@ func (pm *pathManager) doFindPathConf(req defs.PathFindPathConfReq) { return } - err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) + user, err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) if err2 != nil { req.Res <- defs.PathFindPathConfRes{Err: err2} return } - req.Res <- defs.PathFindPathConfRes{Conf: pathConf} + req.Res <- defs.PathFindPathConfRes{ + Conf: pathConf, + User: user, + } } func (pm *pathManager) doDescribe(req defs.PathDescribeReq) { @@ -342,7 +344,7 @@ func (pm *pathManager) doDescribe(req defs.PathDescribeReq) { return } - err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) + _, err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) if err2 != nil { req.Res <- defs.PathDescribeRes{Err: err2} return @@ -367,10 +369,13 @@ func (pm *pathManager) doAddReader(req defs.PathAddReaderReq) { return } + var user string + if !req.AccessRequest.SkipAuth { - err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) - if err2 != nil { - req.Res <- defs.PathAddReaderRes{Err: err2} + var authErr *auth.Error + user, authErr = pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) + if authErr != nil { + req.Res <- defs.PathAddReaderRes{Err: authErr} return } } @@ -384,7 +389,10 @@ func (pm *pathManager) doAddReader(req defs.PathAddReaderReq) { atomic.AddInt64(pa.pendingRequests, 1) - req.Res <- defs.PathAddReaderRes{Path: pa} + req.Res <- defs.PathAddReaderRes{ + Path: pa, + User: user, + } } func (pm *pathManager) doAddPublisher(req defs.PathAddPublisherReq) { @@ -399,10 +407,13 @@ func (pm *pathManager) doAddPublisher(req defs.PathAddPublisherReq) { return } + var user string + if !req.AccessRequest.SkipAuth { - err2 := pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) - if err2 != nil { - req.Res <- defs.PathAddPublisherRes{Err: err2} + var authErr *auth.Error + user, authErr = pm.authManager.Authenticate(req.AccessRequest.ToAuthRequest()) + if authErr != nil { + req.Res <- defs.PathAddPublisherRes{Err: authErr} return } } @@ -416,7 +427,10 @@ func (pm *pathManager) doAddPublisher(req defs.PathAddPublisherReq) { atomic.AddInt64(pa.pendingRequests, 1) - req.Res <- defs.PathAddPublisherRes{Path: pa} + req.Res <- defs.PathAddPublisherRes{ + Path: pa, + User: user, + } } func (pm *pathManager) doAPIPathsList(req pathAPIPathsListReq) { @@ -507,12 +521,12 @@ func (pm *pathManager) closePathIfIdle(pa *path) { } // FindPathConf is called by a reader or publisher. -func (pm *pathManager) FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) { +func (pm *pathManager) FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { req.Res = make(chan defs.PathFindPathConfRes) select { case pm.chFindPathConf <- req: res := <-req.Res - return res.Conf, res.Err + return &res, res.Err case <-pm.ctx.Done(): return nil, fmt.Errorf("terminated") @@ -543,36 +557,52 @@ func (pm *pathManager) Describe(req defs.PathDescribeReq) defs.PathDescribeRes { } // AddPublisher is called by a publisher. -func (pm *pathManager) AddPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { +func (pm *pathManager) AddPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { req.Res = make(chan defs.PathAddPublisherRes) select { case pm.chAddPublisher <- req: - res := <-req.Res - if res.Err != nil { - return nil, nil, res.Err + res1 := <-req.Res + if res1.Err != nil { + return nil, res1.Err } - return res.Path.(*path).addPublisher(req) + res2, err := res1.Path.(*path).addPublisher(req) + if err != nil { + return nil, err + } + + res2.Path = res1.Path + res2.User = res1.User + + return res2, nil case <-pm.ctx.Done(): - return nil, nil, fmt.Errorf("terminated") + return nil, fmt.Errorf("terminated") } } // AddReader is called by a reader. -func (pm *pathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { +func (pm *pathManager) AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { req.Res = make(chan defs.PathAddReaderRes) select { case pm.chAddReader <- req: - res := <-req.Res - if res.Err != nil { - return nil, nil, res.Err + res1 := <-req.Res + if res1.Err != nil { + return nil, res1.Err } - return res.Path.(*path).addReader(req) + res2, err := res1.Path.(*path).addReader(req) + if err != nil { + return nil, err + } + + res2.Path = res1.Path + res2.User = res1.User + + return res2, nil case <-pm.ctx.Done(): - return nil, nil, fmt.Errorf("terminated") + return nil, fmt.Errorf("terminated") } } diff --git a/internal/core/path_manager_test.go b/internal/core/path_manager_test.go index 9cb11b6c..591d1b0d 100644 --- a/internal/core/path_manager_test.go +++ b/internal/core/path_manager_test.go @@ -65,7 +65,7 @@ func TestPathManagerDynamicPathAutoDeletion(t *testing.T) { }) require.EqualError(t, res.Err, "no stream is available on path 'mypath'") } else { - _, _, err := pm.AddReader(defs.PathAddReaderReq{ + _, err := pm.AddReader(defs.PathAddReaderReq{ Author: &dummyReader{}, AccessRequest: defs.PathAccessRequest{ Name: "mypath", @@ -112,7 +112,7 @@ func TestPathManagerDynamicPathDescribeAndPublish(t *testing.T) { } }() - _, _, err := pm.AddPublisher(defs.PathAddPublisherReq{ + _, err := pm.AddPublisher(defs.PathAddPublisherReq{ Author: &dummyPublisher{}, Desc: &description.Session{}, AccessRequest: defs.PathAccessRequest{ diff --git a/internal/defs/api.go b/internal/defs/api.go index 627d7ed4..a0f153d2 100644 --- a/internal/defs/api.go +++ b/internal/defs/api.go @@ -143,6 +143,7 @@ type APIRTMPConn struct { State APIRTMPConnState `json:"state"` Path string `json:"path"` Query string `json:"query"` + User string `json:"user"` BytesReceived uint64 `json:"bytesReceived"` BytesSent uint64 `json:"bytesSent"` } @@ -190,6 +191,7 @@ type APIRTSPSession struct { State APIRTSPSessionState `json:"state"` Path string `json:"path"` Query string `json:"query"` + User string `json:"user"` Transport *string `json:"transport"` Profile *string `json:"profile"` Conns []uuid.UUID `json:"conns"` @@ -230,6 +232,7 @@ type APISRTConn struct { State APISRTConnState `json:"state"` Path string `json:"path"` Query string `json:"query"` + User string `json:"user"` // The metric names/comments are pulled from GoSRT @@ -380,6 +383,7 @@ type APIWebRTCSession struct { State APIWebRTCSessionState `json:"state"` Path string `json:"path"` Query string `json:"query"` + User string `json:"user"` BytesReceived uint64 `json:"bytesReceived"` BytesSent uint64 `json:"bytesSent"` RTPPacketsReceived uint64 `json:"rtpPacketsReceived"` diff --git a/internal/defs/path.go b/internal/defs/path.go index 51202128..605a8ee0 100644 --- a/internal/defs/path.go +++ b/internal/defs/path.go @@ -32,6 +32,7 @@ type Path interface { // PathFindPathConfRes contains the response of FindPathConf(). type PathFindPathConfRes struct { Conf *conf.Path + User string Err error } @@ -58,6 +59,7 @@ type PathDescribeReq struct { // PathAddPublisherRes contains the response of AddPublisher(). type PathAddPublisherRes struct { Path Path + User string SubStream *stream.SubStream Err error } @@ -82,6 +84,7 @@ type PathRemovePublisherReq struct { // PathAddReaderRes contains the response of AddReader(). type PathAddReaderRes struct { Path Path + User string Stream *stream.Stream Err error } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index e6ed0875..77d9929e 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -61,7 +61,7 @@ func metricFloat(key string, tags string, value float64) string { } type metricsAuthManager interface { - Authenticate(req *auth.Request) *auth.Error + Authenticate(req *auth.Request) (string, *auth.Error) } type metricsParent interface { @@ -156,7 +156,7 @@ func (m *Metrics) middlewareAuth(ctx *gin.Context) { IP: net.ParseIP(ctx.ClientIP()), } - err := m.AuthManager.Authenticate(req) + _, err := m.AuthManager.Authenticate(req) if err != nil { if err.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index e50c6a6b..e1560c48 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -238,12 +238,12 @@ func TestMetrics(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { require.Equal(t, conf.AuthActionMetrics, req.Action) require.Equal(t, "myuser", req.Credentials.User) require.Equal(t, "mypass", req.Credentials.Pass) checked = true - return nil + return req.Credentials.User, nil }, }, Parent: test.NilLogger, @@ -374,11 +374,11 @@ func TestAuthError(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { if req.Credentials.User == "" { - return &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } - return &auth.Error{Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{Wrapped: fmt.Errorf("auth error")} }, }, Parent: test.Logger(func(l logger.Level, s string, i ...any) { diff --git a/internal/playback/on_list_test.go b/internal/playback/on_list_test.go index ca90e623..057c981b 100644 --- a/internal/playback/on_list_test.go +++ b/internal/playback/on_list_test.go @@ -70,12 +70,12 @@ func TestOnList(t *testing.T) { }, }, AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { require.Equal(t, conf.AuthActionPlayback, req.Action) require.Equal(t, "myuser", req.Credentials.User) require.Equal(t, "mypass", req.Credentials.Pass) checked = true - return nil + return req.Credentials.User, nil }, }, Parent: test.NilLogger, diff --git a/internal/playback/server.go b/internal/playback/server.go index bdc7180c..3c606a45 100644 --- a/internal/playback/server.go +++ b/internal/playback/server.go @@ -16,7 +16,7 @@ import ( ) type serverAuthManager interface { - Authenticate(req *auth.Request) *auth.Error + Authenticate(req *auth.Request) (string, *auth.Error) } // Server is the playback server. @@ -124,7 +124,7 @@ func (s *Server) doAuth(ctx *gin.Context, pathName string) bool { IP: net.ParseIP(ctx.ClientIP()), } - err := s.AuthManager.Authenticate(req) + _, err := s.AuthManager.Authenticate(req) if err != nil { if err.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) diff --git a/internal/playback/server_test.go b/internal/playback/server_test.go index 4d11649c..7d719835 100644 --- a/internal/playback/server_test.go +++ b/internal/playback/server_test.go @@ -60,11 +60,11 @@ func TestAuthError(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { if req.Credentials.User == "" { - return &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } - return &auth.Error{Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{Wrapped: fmt.Errorf("auth error")} }, }, Parent: test.Logger(func(l logger.Level, s string, i ...any) { diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index 12b03650..ef5354b8 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -17,7 +17,7 @@ import ( ) type pprofAuthManager interface { - Authenticate(req *auth.Request) *auth.Error + Authenticate(req *auth.Request) (string, *auth.Error) } type pprofParent interface { @@ -103,7 +103,7 @@ func (pp *PPROF) middlewareAuth(ctx *gin.Context) { IP: net.ParseIP(ctx.ClientIP()), } - err := pp.AuthManager.Authenticate(req) + _, err := pp.AuthManager.Authenticate(req) if err != nil { if err.AskCredentials { ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) diff --git a/internal/pprof/pprof_test.go b/internal/pprof/pprof_test.go index c58658f0..9d9d0ba3 100644 --- a/internal/pprof/pprof_test.go +++ b/internal/pprof/pprof_test.go @@ -60,12 +60,12 @@ func TestPprof(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { require.Equal(t, conf.AuthActionPprof, req.Action) require.Equal(t, "myuser", req.Credentials.User) require.Equal(t, "mypass", req.Credentials.Pass) checked = true - return nil + return req.Credentials.User, nil }, }, Parent: test.NilLogger, @@ -103,11 +103,11 @@ func TestAuthError(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ - AuthenticateImpl: func(req *auth.Request) *auth.Error { + AuthenticateImpl: func(req *auth.Request) (string, *auth.Error) { if req.Credentials.User == "" { - return &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } - return &auth.Error{Wrapped: fmt.Errorf("auth error")} + return "", &auth.Error{Wrapped: fmt.Errorf("auth error")} }, }, Parent: test.Logger(func(l logger.Level, s string, i ...any) { diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index fb48b80e..2282eef2 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -145,7 +145,7 @@ func (s *httpServer) onRequest(ctx *gin.Context) { return } - pathConf, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ + res, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ Name: dir, Query: ctx.Request.URL.RawQuery, @@ -196,7 +196,7 @@ func (s *httpServer) onRequest(ctx *gin.Context) { path: dir, remoteAddr: httpp.RemoteAddr(ctx), query: ctx.Request.URL.RawQuery, - sourceOnDemand: pathConf.SourceOnDemand, + sourceOnDemand: res.Conf.SourceOnDemand, }) if err != nil { ctx.Writer.WriteHeader(http.StatusNotFound) diff --git a/internal/servers/hls/muxer.go b/internal/servers/hls/muxer.go index eda01250..98159c1e 100644 --- a/internal/servers/hls/muxer.go +++ b/internal/servers/hls/muxer.go @@ -122,7 +122,7 @@ func (m *muxer) run() { } func (m *muxer) runInner() error { - path, stream, err := m.pathManager.AddReader(defs.PathAddReaderReq{ + res, err := m.pathManager.AddReader(defs.PathAddReaderReq{ Author: m, AccessRequest: defs.PathAccessRequest{ Name: m.pathName, @@ -134,7 +134,7 @@ func (m *muxer) runInner() error { return err } - m.path = path + m.path = res.Path defer m.path.RemoveReader(defs.PathRemoveReaderReq{Author: m}) @@ -149,7 +149,7 @@ func (m *muxer) runInner() error { segmentMaxSize: m.segmentMaxSize, directory: m.directory, pathName: m.pathName, - stream: stream, + stream: res.Stream, bytesSent: m.bytesSent, parent: m, } @@ -206,7 +206,7 @@ func (m *muxer) runInner() error { segmentMaxSize: m.segmentMaxSize, directory: m.directory, pathName: m.pathName, - stream: stream, + stream: res.Stream, bytesSent: m.bytesSent, parent: m, } diff --git a/internal/servers/hls/server.go b/internal/servers/hls/server.go index fcbce238..57596bec 100644 --- a/internal/servers/hls/server.go +++ b/internal/servers/hls/server.go @@ -12,7 +12,6 @@ import ( "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/logger" - "github.com/bluenviron/mediamtx/internal/stream" ) // ErrMuxerNotFound is returned when a muxer is not found. @@ -60,8 +59,8 @@ type serverMetrics interface { type serverPathManager interface { SetHLSServer(*Server) []defs.Path - FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type serverParent interface { diff --git a/internal/servers/hls/server_test.go b/internal/servers/hls/server_test.go index 7258b790..fca4e23f 100644 --- a/internal/servers/hls/server_test.go +++ b/internal/servers/hls/server_test.go @@ -26,8 +26,8 @@ import ( type dummyPathManager struct { setHLSServerImpl func() []defs.Path - findPathConfImpl func(req defs.PathFindPathConfReq) (*conf.Path, error) - addReaderImpl func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + findPathConfImpl func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) + addReaderImpl func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } func (pm *dummyPathManager) SetHLSServer(*Server) []defs.Path { @@ -37,11 +37,11 @@ func (pm *dummyPathManager) SetHLSServer(*Server) []defs.Path { return nil } -func (pm *dummyPathManager) FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) { +func (pm *dummyPathManager) FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { return pm.findPathConfImpl(req) } -func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { +func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { return pm.addReaderImpl(req) } @@ -110,13 +110,13 @@ func TestServerNotFound(t *testing.T) { } { t.Run(ca, func(t *testing.T) { pm := &dummyPathManager{ - findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + findPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "nonexisting", req.AccessRequest.Name) - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - addReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + addReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "nonexisting", req.AccessRequest.Name) - return nil, nil, fmt.Errorf("not found") + return nil, fmt.Errorf("not found") }, } @@ -202,21 +202,21 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) pm := &dummyPathManager{ - findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + findPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - addReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + addReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) if ca == "always remux off" { require.Equal(t, "param=value", req.AccessRequest.Query) } else { require.Equal(t, "", req.AccessRequest.Query) } - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, Stream: strm}, nil }, } @@ -431,8 +431,8 @@ func TestServerDirectory(t *testing.T) { require.NoError(t, err) pm := &dummyPathManager{ - addReaderImpl: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { - return &dummyPath{}, strm, nil + addReaderImpl: func(_ defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { + return &defs.PathAddReaderRes{Path: &dummyPath{}, Stream: strm}, nil }, } @@ -491,9 +491,9 @@ func TestServerDynamicAlwaysRemux(t *testing.T) { setHLSServerImpl: func() []defs.Path { return []defs.Path{&dummyPath{}} }, - addReaderImpl: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + addReaderImpl: func(_ defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { close(done) - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, Stream: strm}, nil }, } @@ -537,7 +537,7 @@ func TestAuthError(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: &dummyPathManager{ - findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + findPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { if req.AccessRequest.Credentials.User == "" && req.AccessRequest.Credentials.Pass == "" { return nil, &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } diff --git a/internal/servers/rtmp/conn.go b/internal/servers/rtmp/conn.go index 18f94d65..bc41a737 100644 --- a/internal/servers/rtmp/conn.go +++ b/internal/servers/rtmp/conn.go @@ -47,6 +47,7 @@ type conn struct { state defs.APIRTMPConnState pathName string query string + user string } func (c *conn) initialize() { @@ -151,7 +152,7 @@ func (c *conn) runRead() error { pathName := strings.TrimLeft(c.rconn.URL.Path, "/") query := c.rconn.URL.Query() - path, strm, err := c.pathManager.AddReader(defs.PathAddReaderReq{ + res, err := c.pathManager.AddReader(defs.PathAddReaderReq{ Author: c, AccessRequest: defs.PathAccessRequest{ Name: pathName, @@ -175,29 +176,30 @@ func (c *conn) runRead() error { return err } - defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) + defer res.Path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) c.mutex.Lock() c.state = defs.APIRTMPConnStateRead c.pathName = pathName c.query = c.rconn.URL.RawQuery + c.user = res.User c.mutex.Unlock() r := &stream.Reader{Parent: c} - err = rtmp.FromStream(strm.Desc, r, c.rconn, c.nconn, time.Duration(c.writeTimeout)) + err = rtmp.FromStream(res.Stream.Desc, r, c.rconn, c.nconn, time.Duration(c.writeTimeout)) if err != nil { return err } c.Log(logger.Info, "is reading from path '%s', %s", - path.Name(), defs.FormatsInfo(r.Formats())) + res.Path.Name(), defs.FormatsInfo(r.Formats())) onUnreadHook := hooks.OnRead(hooks.OnReadParams{ Logger: c, ExternalCmdPool: c.externalCmdPool, - Conf: path.SafeConf(), - ExternalCmdEnv: path.ExternalCmdEnv(), + Conf: res.Path.SafeConf(), + ExternalCmdEnv: res.Path.ExternalCmdEnv(), Reader: *c.APIReaderDescribe(), Query: c.rconn.URL.RawQuery, }) @@ -205,8 +207,8 @@ func (c *conn) runRead() error { c.nconn.SetReadDeadline(time.Time{}) - strm.AddReader(r) - defer strm.RemoveReader(r) + res.Stream.AddReader(r) + defer res.Stream.RemoveReader(r) select { case <-c.ctx.Done(): @@ -236,8 +238,7 @@ func (c *conn) runPublish() error { return err } - var path defs.Path - path, subStream, err = c.pathManager.AddPublisher(defs.PathAddPublisherReq{ + res, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: c, Desc: &description.Session{Medias: medias}, UseRTPPackets: false, @@ -265,12 +266,15 @@ func (c *conn) runPublish() error { return err } - defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) + defer res.Path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) + + subStream = res.SubStream c.mutex.Lock() c.state = defs.APIRTMPConnStatePublish c.pathName = pathName c.query = c.rconn.URL.RawQuery + c.user = res.User c.mutex.Unlock() c.nconn.SetWriteDeadline(time.Time{}) @@ -329,6 +333,7 @@ func (c *conn) apiItem() *defs.APIRTMPConn { State: c.state, Path: c.pathName, Query: c.query, + User: c.user, BytesReceived: bytesReceived, BytesSent: bytesSent, } diff --git a/internal/servers/rtmp/server.go b/internal/servers/rtmp/server.go index 8d45f65c..a5f62113 100644 --- a/internal/servers/rtmp/server.go +++ b/internal/servers/rtmp/server.go @@ -20,7 +20,6 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" "github.com/bluenviron/mediamtx/internal/restrictnetwork" - "github.com/bluenviron/mediamtx/internal/stream" ) // ErrConnNotFound is returned when a connection is not found. @@ -64,8 +63,8 @@ type serverMetrics interface { } type serverPathManager interface { - AddPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type serverParent interface { diff --git a/internal/servers/rtmp/server_test.go b/internal/servers/rtmp/server_test.go index b249ea63..9d0fa847 100644 --- a/internal/servers/rtmp/server_test.go +++ b/internal/servers/rtmp/server_test.go @@ -69,7 +69,7 @@ func TestServerPublish(t *testing.T) { n := 0 pathManager := &test.PathManager{ - AddPublisherImpl: func(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { + AddPublisherImpl: func(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "user=myuser&pass=mypass¶m=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) @@ -118,7 +118,11 @@ func TestServerPublish(t *testing.T) { strm.AddReader(reader) - return &dummyPath{}, subStream, nil + return &defs.PathAddPublisherRes{ + Path: &dummyPath{}, + User: req.AccessRequest.Credentials.User, + SubStream: subStream, + }, nil }, } @@ -186,6 +190,24 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) <-dataReceived + + list, err := s.APIConnsList() + require.NoError(t, err) + require.Equal(t, &defs.APIRTMPConnList{ + Items: []defs.APIRTMPConn{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "publish", + Path: "teststream", + Query: "user=myuser&pass=mypass¶m=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + }, + }, + }, list) }) } } @@ -228,12 +250,12 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) pathManager := &test.PathManager{ - AddReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + AddReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "user=myuser&pass=mypass¶m=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, User: req.AccessRequest.Credentials.User, Stream: strm}, nil }, } @@ -326,6 +348,24 @@ func TestServerRead(t *testing.T) { err = r.Read() require.NoError(t, err) + + list, err := s.APIConnsList() + require.NoError(t, err) + require.Equal(t, &defs.APIRTMPConnList{ + Items: []defs.APIRTMPConn{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "read", + Path: "teststream", + Query: "user=myuser&pass=mypass¶m=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + }, + }, + }, list) }) } } diff --git a/internal/servers/rtsp/server.go b/internal/servers/rtsp/server.go index f2124d46..58b89c6a 100644 --- a/internal/servers/rtsp/server.go +++ b/internal/servers/rtsp/server.go @@ -25,7 +25,6 @@ import ( "github.com/bluenviron/mediamtx/internal/externalcmd" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/packetdumper" - "github.com/bluenviron/mediamtx/internal/stream" ) // ErrConnNotFound is returned when a connection is not found. @@ -78,10 +77,10 @@ type serverMetrics interface { } type serverPathManager interface { - FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) + FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) Describe(req defs.PathDescribeReq) defs.PathDescribeRes - AddPublisher(_ defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) - AddReader(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddPublisher(_ defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) + AddReader(_ defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type serverParent interface { diff --git a/internal/servers/rtsp/server_test.go b/internal/servers/rtsp/server_test.go index 846190d6..8e6b230e 100644 --- a/internal/servers/rtsp/server_test.go +++ b/internal/servers/rtsp/server_test.go @@ -23,6 +23,12 @@ import ( "github.com/stretchr/testify/require" ) +func ptrOf[T any](v T) *T { + p := new(T) + *p = v + return p +} + type dummyPath struct{} func (p *dummyPath) Name() string { @@ -56,7 +62,7 @@ func TestServerPublish(t *testing.T) { n := 0 pathManager := &test.PathManager{ - FindPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) @@ -79,9 +85,9 @@ func TestServerPublish(t *testing.T) { require.True(t, ok) } - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - AddPublisherImpl: func(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { + AddPublisherImpl: func(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.True(t, req.AccessRequest.SkipAuth) @@ -119,7 +125,7 @@ func TestServerPublish(t *testing.T) { strm.AddReader(reader) - return &dummyPath{}, subStream, nil + return &defs.PathAddPublisherRes{Path: &dummyPath{}, SubStream: subStream}, nil }, } @@ -171,6 +177,28 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) <-dataReceived + + list, err := s.APISessionsList() + require.NoError(t, err) + require.Equal(t, &defs.APIRTSPSessionList{ + Items: []defs.APIRTSPSession{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "publish", + Path: "teststream", + Query: "param=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + Conns: list.Items[0].Conns, + RTPPacketsReceived: list.Items[0].RTPPacketsReceived, + Transport: ptrOf("TCP"), + Profile: ptrOf("AVP"), + }, + }, + }, list) }) } } @@ -228,7 +256,7 @@ func TestServerRead(t *testing.T) { Err: nil, } }, - AddReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + AddReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) @@ -241,7 +269,7 @@ func TestServerRead(t *testing.T) { require.True(t, ok) } - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, User: req.AccessRequest.Credentials.User, Stream: strm}, nil }, } @@ -321,6 +349,29 @@ func TestServerRead(t *testing.T) { }) <-recv + + list, err := s.APISessionsList() + require.NoError(t, err) + require.Equal(t, &defs.APIRTSPSessionList{ + Items: []defs.APIRTSPSession{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "read", + Path: "teststream", + Query: "param=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + Conns: list.Items[0].Conns, + RTPPacketsReceived: list.Items[0].RTPPacketsReceived, + RTPPacketsSent: list.Items[0].RTPPacketsSent, + Transport: ptrOf("TCP"), + Profile: ptrOf("AVP"), + }, + }, + }, list) }) } } diff --git a/internal/servers/rtsp/session.go b/internal/servers/rtsp/session.go index 95f28943..57947e8e 100644 --- a/internal/servers/rtsp/session.go +++ b/internal/servers/rtsp/session.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "slices" + "sync" "time" "github.com/bluenviron/gortsplib/v5" @@ -61,6 +62,8 @@ type session struct { packetsLost *counterdumper.Dumper decodeErrors *errordumper.Dumper discardedFrames *counterdumper.Dumper + mutex sync.RWMutex + user string } func (s *session) initialize() { @@ -169,7 +172,7 @@ func (s *session) onAnnounce(c *conn, ctx *gortsplib.ServerHandlerOnAnnounceCtx) } } - pathConf, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ + res, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ Name: ctx.Path, Query: ctx.Query, @@ -192,7 +195,11 @@ func (s *session) onAnnounce(c *conn, ctx *gortsplib.ServerHandlerOnAnnounceCtx) }, err } - s.pathConf = pathConf + s.pathConf = res.Conf + + s.mutex.Lock() + s.user = res.User + s.mutex.Unlock() return &base.Response{ StatusCode: base.StatusOK, @@ -239,7 +246,7 @@ func (s *session) onSetup(c *conn, ctx *gortsplib.ServerHandlerOnSetupCtx, switch s.rsession.State() { case gortsplib.ServerSessionStateInitial: // play - path, stream, err := s.pathManager.AddReader(defs.PathAddReaderReq{ + res, err := s.pathManager.AddReader(defs.PathAddReaderReq{ Author: s, AccessRequest: defs.PathAccessRequest{ Name: ctx.Path, @@ -270,8 +277,12 @@ func (s *session) onSetup(c *conn, ctx *gortsplib.ServerHandlerOnSetupCtx, }, nil, err } - s.path = path - s.stream = stream + s.path = res.Path + s.stream = res.Stream + + s.mutex.Lock() + s.user = res.User + s.mutex.Unlock() return &base.Response{ StatusCode: base.StatusOK, @@ -317,7 +328,7 @@ func (s *session) onPlay(_ *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, e // onRecord is called by rtspServer. func (s *session) onRecord(_ *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { - path, subStream, err := s.pathManager.AddPublisher(defs.PathAddPublisherReq{ + res, err := s.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: s, Desc: s.rsession.AnnouncedDescription(), UseRTPPackets: true, @@ -339,12 +350,12 @@ func (s *session) onRecord(_ *gortsplib.ServerHandlerOnRecordCtx) (*base.Respons rtsp.ToStream( s.rsession, s.rsession.AnnouncedDescription().Medias, - path.SafeConf(), + res.Path.SafeConf(), &s.subStream, s) - s.path = path - s.subStream = subStream + s.path = res.Path + s.subStream = res.SubStream return &base.Response{ StatusCode: base.StatusOK, @@ -411,6 +422,9 @@ func (s *session) onStreamWriteError(_ *gortsplib.ServerHandlerOnStreamWriteErro func (s *session) apiItem() *defs.APIRTSPSession { stats := s.rsession.Stats() + s.mutex.RLock() + defer s.mutex.RUnlock() + return &defs.APIRTSPSession{ ID: s.uuid, Created: s.created, @@ -436,6 +450,7 @@ func (s *session) apiItem() *defs.APIRTSPSession { return "" }(), Query: s.rsession.Query(), + User: s.user, Transport: func() *string { transport := s.rsession.Transport() if transport == nil { diff --git a/internal/servers/srt/conn.go b/internal/servers/srt/conn.go index b4c74b64..1e5b3661 100644 --- a/internal/servers/srt/conn.go +++ b/internal/servers/srt/conn.go @@ -64,6 +64,7 @@ type conn struct { state defs.APISRTConnState pathName string query string + user string sconn srt.Conn } @@ -131,7 +132,7 @@ func (c *conn) runInner() error { } func (c *conn) runPublish(streamID *streamID) error { - pathConf, err := c.pathManager.FindPathConf(defs.PathFindPathConfReq{ + res, err := c.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ Name: streamID.path, Query: streamID.query, @@ -157,7 +158,11 @@ func (c *conn) runPublish(streamID *streamID) error { return err } - err = srtCheckPassphrase(c.connReq, pathConf.SRTPublishPassphrase) + c.mutex.Lock() + c.user = res.User + c.mutex.Unlock() + + err = srtCheckPassphrase(c.connReq, res.Conf.SRTPublishPassphrase) if err != nil { c.connReq.Reject(srt.REJ_PEER) return err @@ -170,7 +175,7 @@ func (c *conn) runPublish(streamID *streamID) error { readerErr := make(chan error) go func() { - readerErr <- c.runPublishReader(sconn, streamID, pathConf) + readerErr <- c.runPublishReader(sconn, streamID, res.Conf) }() select { @@ -217,8 +222,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, streamID *streamID, pathConf *co return err } - var path defs.Path - path, subStream, err = c.pathManager.AddPublisher(defs.PathAddPublisherReq{ + res, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: c, Desc: &description.Session{Medias: medias}, UseRTPPackets: false, @@ -235,7 +239,9 @@ func (c *conn) runPublishReader(sconn srt.Conn, streamID *streamID, pathConf *co return err } - defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) + defer res.Path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) + + subStream = res.SubStream c.mutex.Lock() c.state = defs.APISRTConnStatePublish @@ -253,7 +259,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, streamID *streamID, pathConf *co } func (c *conn) runRead(streamID *streamID) error { - path, strm, err := c.pathManager.AddReader(defs.PathAddReaderReq{ + res, err := c.pathManager.AddReader(defs.PathAddReaderReq{ Author: c, AccessRequest: defs.PathAccessRequest{ Name: streamID.path, @@ -279,9 +285,9 @@ func (c *conn) runRead(streamID *streamID) error { return err } - defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) + defer res.Path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) - err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTReadPassphrase) + err = srtCheckPassphrase(c.connReq, res.Path.SafeConf().SRTReadPassphrase) if err != nil { c.connReq.Reject(srt.REJ_PEER) return err @@ -297,7 +303,7 @@ func (c *conn) runRead(streamID *streamID) error { r := &stream.Reader{Parent: c} - err = mpegts.FromStream(strm.Desc, r, bw, sconn, time.Duration(c.writeTimeout)) + err = mpegts.FromStream(res.Stream.Desc, r, bw, sconn, time.Duration(c.writeTimeout)) if err != nil { return err } @@ -306,17 +312,18 @@ func (c *conn) runRead(streamID *streamID) error { c.state = defs.APISRTConnStateRead c.pathName = streamID.path c.query = streamID.query + c.user = res.User c.sconn = sconn c.mutex.Unlock() c.Log(logger.Info, "is reading from path '%s', %s", - path.Name(), defs.FormatsInfo(r.Formats())) + res.Path.Name(), defs.FormatsInfo(r.Formats())) onUnreadHook := hooks.OnRead(hooks.OnReadParams{ Logger: c, ExternalCmdPool: c.externalCmdPool, - Conf: path.SafeConf(), - ExternalCmdEnv: path.ExternalCmdEnv(), + Conf: res.Path.SafeConf(), + ExternalCmdEnv: res.Path.ExternalCmdEnv(), Reader: *c.APIReaderDescribe(), Query: streamID.query, }) @@ -325,8 +332,8 @@ func (c *conn) runRead(streamID *streamID) error { // disable read deadline sconn.SetReadDeadline(time.Time{}) - strm.AddReader(r) - defer strm.RemoveReader(r) + res.Stream.AddReader(r) + defer res.Stream.RemoveReader(r) select { case <-c.ctx.Done(): @@ -364,6 +371,7 @@ func (c *conn) apiItem() *defs.APISRTConn { State: c.state, Path: c.pathName, Query: c.query, + User: c.user, } if c.sconn != nil { diff --git a/internal/servers/srt/server.go b/internal/servers/srt/server.go index c2e9bc52..f69a0ad0 100644 --- a/internal/servers/srt/server.go +++ b/internal/servers/srt/server.go @@ -17,7 +17,6 @@ import ( "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/externalcmd" "github.com/bluenviron/mediamtx/internal/logger" - "github.com/bluenviron/mediamtx/internal/stream" ) // ErrConnNotFound is returned when a connection is not found. @@ -64,9 +63,9 @@ type serverMetrics interface { } type serverPathManager interface { - FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) - AddPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) + AddPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type serverParent interface { diff --git a/internal/servers/srt/server_test.go b/internal/servers/srt/server_test.go index 1bae58cf..11389dfa 100644 --- a/internal/servers/srt/server_test.go +++ b/internal/servers/srt/server_test.go @@ -53,14 +53,14 @@ func TestServerPublish(t *testing.T) { n := 0 pathManager := &test.PathManager{ - FindPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - AddPublisherImpl: func(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { + AddPublisherImpl: func(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.True(t, req.AccessRequest.SkipAuth) @@ -113,7 +113,7 @@ func TestServerPublish(t *testing.T) { strm.AddReader(reader) - return &dummyPath{}, subStream, nil + return &defs.PathAddPublisherRes{Path: &dummyPath{}, SubStream: subStream}, nil }, } @@ -174,6 +174,75 @@ func TestServerPublish(t *testing.T) { <-dataReceived + list, err := s.APIConnsList() + require.NoError(t, err) + require.Equal(t, &defs.APISRTConnList{ //nolint:dupl + Items: []defs.APISRTConn{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "publish", + Path: "teststream", + Query: "param=value", + User: "myuser", + PacketsSent: list.Items[0].PacketsSent, + PacketsReceived: list.Items[0].PacketsReceived, + PacketsSentUnique: list.Items[0].PacketsSentUnique, + PacketsReceivedUnique: list.Items[0].PacketsReceivedUnique, + PacketsSendLoss: list.Items[0].PacketsSendLoss, + PacketsReceivedLoss: list.Items[0].PacketsReceivedLoss, + PacketsRetrans: list.Items[0].PacketsRetrans, + PacketsReceivedRetrans: list.Items[0].PacketsReceivedRetrans, + PacketsSentACK: list.Items[0].PacketsSentACK, + PacketsReceivedACK: list.Items[0].PacketsReceivedACK, + PacketsSentNAK: list.Items[0].PacketsSentNAK, + PacketsReceivedNAK: list.Items[0].PacketsReceivedNAK, + PacketsSentKM: list.Items[0].PacketsSentKM, + PacketsReceivedKM: list.Items[0].PacketsReceivedKM, + UsSndDuration: list.Items[0].UsSndDuration, + PacketsReceivedBelated: list.Items[0].PacketsReceivedBelated, + PacketsSendDrop: list.Items[0].PacketsSendDrop, + PacketsReceivedDrop: list.Items[0].PacketsReceivedDrop, + PacketsReceivedUndecrypt: list.Items[0].PacketsReceivedUndecrypt, + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + BytesSentUnique: list.Items[0].BytesSentUnique, + BytesReceivedUnique: list.Items[0].BytesReceivedUnique, + BytesReceivedLoss: list.Items[0].BytesReceivedLoss, + BytesRetrans: list.Items[0].BytesRetrans, + BytesReceivedRetrans: list.Items[0].BytesReceivedRetrans, + BytesReceivedBelated: list.Items[0].BytesReceivedBelated, + BytesSendDrop: list.Items[0].BytesSendDrop, + BytesReceivedDrop: list.Items[0].BytesReceivedDrop, + BytesReceivedUndecrypt: list.Items[0].BytesReceivedUndecrypt, + UsPacketsSendPeriod: list.Items[0].UsPacketsSendPeriod, + PacketsFlowWindow: list.Items[0].PacketsFlowWindow, + PacketsFlightSize: list.Items[0].PacketsFlightSize, + MsRTT: list.Items[0].MsRTT, + MbpsSendRate: list.Items[0].MbpsSendRate, + MbpsReceiveRate: list.Items[0].MbpsReceiveRate, + MbpsLinkCapacity: list.Items[0].MbpsLinkCapacity, + BytesAvailSendBuf: list.Items[0].BytesAvailSendBuf, + BytesAvailReceiveBuf: list.Items[0].BytesAvailReceiveBuf, + MbpsMaxBW: list.Items[0].MbpsMaxBW, + ByteMSS: list.Items[0].ByteMSS, + PacketsSendBuf: list.Items[0].PacketsSendBuf, + BytesSendBuf: list.Items[0].BytesSendBuf, + MsSendBuf: list.Items[0].MsSendBuf, + MsSendTsbPdDelay: list.Items[0].MsSendTsbPdDelay, + PacketsReceiveBuf: list.Items[0].PacketsReceiveBuf, + BytesReceiveBuf: list.Items[0].BytesReceiveBuf, + MsReceiveBuf: list.Items[0].MsReceiveBuf, + MsReceiveTsbPdDelay: list.Items[0].MsReceiveTsbPdDelay, + PacketsReorderTolerance: list.Items[0].PacketsReorderTolerance, + PacketsReceivedAvgBelatedTime: list.Items[0].PacketsReceivedAvgBelatedTime, + PacketsSendLossRate: list.Items[0].PacketsSendLossRate, + PacketsReceivedLossRate: list.Items[0].PacketsReceivedLossRate, + }, + }, + }, list) + // the second PES is written after writer is closed publisher.Close() <-dataReceived2 @@ -203,12 +272,12 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) pathManager := &test.PathManager{ - AddReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + AddReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, User: req.AccessRequest.Credentials.User, Stream: strm}, nil }, } @@ -288,4 +357,73 @@ func TestServerRead(t *testing.T) { break } } + + list, err := s.APIConnsList() + require.NoError(t, err) + require.Equal(t, &defs.APISRTConnList{ //nolint:dupl + Items: []defs.APISRTConn{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "read", + Path: "teststream", + Query: "param=value", + User: "myuser", + PacketsSent: list.Items[0].PacketsSent, + PacketsReceived: list.Items[0].PacketsReceived, + PacketsSentUnique: list.Items[0].PacketsSentUnique, + PacketsReceivedUnique: list.Items[0].PacketsReceivedUnique, + PacketsSendLoss: list.Items[0].PacketsSendLoss, + PacketsReceivedLoss: list.Items[0].PacketsReceivedLoss, + PacketsRetrans: list.Items[0].PacketsRetrans, + PacketsReceivedRetrans: list.Items[0].PacketsReceivedRetrans, + PacketsSentACK: list.Items[0].PacketsSentACK, + PacketsReceivedACK: list.Items[0].PacketsReceivedACK, + PacketsSentNAK: list.Items[0].PacketsSentNAK, + PacketsReceivedNAK: list.Items[0].PacketsReceivedNAK, + PacketsSentKM: list.Items[0].PacketsSentKM, + PacketsReceivedKM: list.Items[0].PacketsReceivedKM, + UsSndDuration: list.Items[0].UsSndDuration, + PacketsReceivedBelated: list.Items[0].PacketsReceivedBelated, + PacketsSendDrop: list.Items[0].PacketsSendDrop, + PacketsReceivedDrop: list.Items[0].PacketsReceivedDrop, + PacketsReceivedUndecrypt: list.Items[0].PacketsReceivedUndecrypt, + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + BytesSentUnique: list.Items[0].BytesSentUnique, + BytesReceivedUnique: list.Items[0].BytesReceivedUnique, + BytesReceivedLoss: list.Items[0].BytesReceivedLoss, + BytesRetrans: list.Items[0].BytesRetrans, + BytesReceivedRetrans: list.Items[0].BytesReceivedRetrans, + BytesReceivedBelated: list.Items[0].BytesReceivedBelated, + BytesSendDrop: list.Items[0].BytesSendDrop, + BytesReceivedDrop: list.Items[0].BytesReceivedDrop, + BytesReceivedUndecrypt: list.Items[0].BytesReceivedUndecrypt, + UsPacketsSendPeriod: list.Items[0].UsPacketsSendPeriod, + PacketsFlowWindow: list.Items[0].PacketsFlowWindow, + PacketsFlightSize: list.Items[0].PacketsFlightSize, + MsRTT: list.Items[0].MsRTT, + MbpsSendRate: list.Items[0].MbpsSendRate, + MbpsReceiveRate: list.Items[0].MbpsReceiveRate, + MbpsLinkCapacity: list.Items[0].MbpsLinkCapacity, + BytesAvailSendBuf: list.Items[0].BytesAvailSendBuf, + BytesAvailReceiveBuf: list.Items[0].BytesAvailReceiveBuf, + MbpsMaxBW: list.Items[0].MbpsMaxBW, + ByteMSS: list.Items[0].ByteMSS, + PacketsSendBuf: list.Items[0].PacketsSendBuf, + BytesSendBuf: list.Items[0].BytesSendBuf, + MsSendBuf: list.Items[0].MsSendBuf, + MsSendTsbPdDelay: list.Items[0].MsSendTsbPdDelay, + PacketsReceiveBuf: list.Items[0].PacketsReceiveBuf, + BytesReceiveBuf: list.Items[0].BytesReceiveBuf, + MsReceiveBuf: list.Items[0].MsReceiveBuf, + MsReceiveTsbPdDelay: list.Items[0].MsReceiveTsbPdDelay, + PacketsReorderTolerance: list.Items[0].PacketsReorderTolerance, + PacketsReceivedAvgBelatedTime: list.Items[0].PacketsReceivedAvgBelatedTime, + PacketsSendLossRate: list.Items[0].PacketsSendLossRate, + PacketsReceivedLossRate: list.Items[0].PacketsReceivedLossRate, + }, + }, + }, list) } diff --git a/internal/servers/webrtc/server.go b/internal/servers/webrtc/server.go index 979d9260..9758bdd2 100644 --- a/internal/servers/webrtc/server.go +++ b/internal/servers/webrtc/server.go @@ -29,7 +29,6 @@ import ( "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/protocols/webrtc" "github.com/bluenviron/mediamtx/internal/restrictnetwork" - "github.com/bluenviron/mediamtx/internal/stream" ) const ( @@ -175,9 +174,9 @@ type serverMetrics interface { } type serverPathManager interface { - FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) - AddPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) + AddPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type serverParent interface { diff --git a/internal/servers/webrtc/server_test.go b/internal/servers/webrtc/server_test.go index 97ae77cd..a8e49da3 100644 --- a/internal/servers/webrtc/server_test.go +++ b/internal/servers/webrtc/server_test.go @@ -62,8 +62,8 @@ func (p *dummyPath) RemoveReader(_ defs.PathRemoveReaderReq) { func initializeTestServer(t *testing.T) *Server { pm := &test.PathManager{ - FindPathConfImpl: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { - return &conf.Path{}, nil + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, } @@ -144,8 +144,8 @@ func TestPreflightRequest(t *testing.T) { func TestServerOptionsICEServer(t *testing.T) { pathManager := &test.PathManager{ - FindPathConfImpl: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { - return &conf.Path{}, nil + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, } @@ -207,14 +207,14 @@ func TestServerPublish(t *testing.T) { dataReceived := make(chan struct{}) pathManager := &test.PathManager{ - FindPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - AddPublisherImpl: func(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { + AddPublisherImpl: func(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.True(t, req.AccessRequest.SkipAuth) @@ -255,7 +255,7 @@ func TestServerPublish(t *testing.T) { strm.AddReader(reader) - return &dummyPath{}, subStream, nil + return &defs.PathAddPublisherRes{Path: &dummyPath{}, SubStream: subStream}, nil }, } @@ -321,6 +321,31 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) <-dataReceived + + list, err := s.APISessionsList() + require.NoError(t, err) + require.Equal(t, &defs.APIWebRTCSessionList{ + Items: []defs.APIWebRTCSession{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "publish", + Path: "teststream", + Query: "param=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + RTPPacketsReceived: list.Items[0].RTPPacketsReceived, + RTPPacketsSent: list.Items[0].RTPPacketsSent, + RTCPPacketsReceived: list.Items[0].RTCPPacketsReceived, + RTCPPacketsSent: list.Items[0].RTCPPacketsSent, + PeerConnectionEstablished: true, + LocalCandidate: list.Items[0].LocalCandidate, + RemoteCandidate: list.Items[0].RemoteCandidate, + }, + }, + }, list) } func TestServerRead(t *testing.T) { @@ -491,19 +516,19 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) pathManager := &test.PathManager{ - FindPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &conf.Path{}, nil + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - AddReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { + AddReaderImpl: func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) - return &dummyPath{}, strm, nil + return &defs.PathAddReaderRes{Path: &dummyPath{}, User: req.AccessRequest.Credentials.User, Stream: strm}, nil }, } @@ -581,17 +606,42 @@ func TestServerRead(t *testing.T) { <-writerDone <-done + + list, err := s.APISessionsList() + require.NoError(t, err) + require.Equal(t, &defs.APIWebRTCSessionList{ + Items: []defs.APIWebRTCSession{ + { + ID: list.Items[0].ID, + Created: list.Items[0].Created, + RemoteAddr: list.Items[0].RemoteAddr, + State: "read", + Path: "teststream", + Query: "param=value", + User: "myuser", + BytesReceived: list.Items[0].BytesReceived, + BytesSent: list.Items[0].BytesSent, + RTPPacketsReceived: list.Items[0].RTPPacketsReceived, + RTPPacketsSent: list.Items[0].RTPPacketsSent, + RTCPPacketsReceived: list.Items[0].RTCPPacketsReceived, + RTCPPacketsSent: list.Items[0].RTCPPacketsSent, + PeerConnectionEstablished: true, + LocalCandidate: list.Items[0].LocalCandidate, + RemoteCandidate: list.Items[0].RemoteCandidate, + }, + }, + }, list) }) } } func TestServerReadNotFound(t *testing.T) { pm := &test.PathManager{ - FindPathConfImpl: func(_ defs.PathFindPathConfReq) (*conf.Path, error) { - return &conf.Path{}, nil + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { + return &defs.PathFindPathConfRes{Conf: &conf.Path{}, User: req.AccessRequest.Credentials.User}, nil }, - AddReaderImpl: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { - return nil, nil, defs.PathNoStreamAvailableError{} + AddReaderImpl: func(_ defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { + return nil, defs.PathNoStreamAvailableError{} }, } @@ -750,7 +800,7 @@ func TestAuthError(t *testing.T) { ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: &test.PathManager{ - FindPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { + FindPathConfImpl: func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { if req.AccessRequest.Credentials.User == "" && req.AccessRequest.Credentials.Pass == "" { return nil, &auth.Error{AskCredentials: true, Wrapped: fmt.Errorf("auth error")} } diff --git a/internal/servers/webrtc/session.go b/internal/servers/webrtc/session.go index db9b5141..2743043a 100644 --- a/internal/servers/webrtc/session.go +++ b/internal/servers/webrtc/session.go @@ -64,6 +64,7 @@ type session struct { secret uuid.UUID mutex sync.RWMutex pc *webrtc.PeerConnection + user string chNew chan webRTCNewSessionReq chAddCandidates chan webRTCAddSessionCandidatesReq @@ -138,7 +139,7 @@ func (s *session) runInner2() (int, error) { func (s *session) runPublish() (int, error) { ip, _, _ := net.SplitHostPort(s.req.remoteAddr) - pathConf, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ + res1, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ Name: s.req.pathName, Query: s.req.httpRequest.URL.RawQuery, @@ -153,6 +154,10 @@ func (s *session) runPublish() (int, error) { return http.StatusBadRequest, err } + s.mutex.Lock() + s.user = res1.User + s.mutex.Unlock() + iceServers, err := s.parent.generateICEServers(false) if err != nil { return http.StatusInternalServerError, err @@ -233,18 +238,17 @@ func (s *session) runPublish() (int, error) { var subStream *stream.SubStream - medias, err := webrtc.ToStream(pc, pathConf, &subStream, s) + medias, err := webrtc.ToStream(pc, res1.Conf, &subStream, s) if err != nil { return 0, err } - var path defs.Path - path, subStream, err = s.pathManager.AddPublisher(defs.PathAddPublisherReq{ + res2, err := s.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: s, Desc: &description.Session{Medias: medias}, UseRTPPackets: true, - ReplaceNTP: !pathConf.UseAbsoluteTimestamp, - ConfToCompare: pathConf, + ReplaceNTP: !res1.Conf.UseAbsoluteTimestamp, + ConfToCompare: res1.Conf, AccessRequest: defs.PathAccessRequest{ Name: s.req.pathName, Query: s.req.httpRequest.URL.RawQuery, @@ -256,7 +260,9 @@ func (s *session) runPublish() (int, error) { return 0, err } - defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: s}) + defer res2.Path.RemovePublisher(defs.PathRemovePublisherReq{Author: s}) + + subStream = res2.SubStream pc.StartReading() @@ -281,7 +287,7 @@ func (s *session) runRead() (int, error) { IP: net.ParseIP(ip), } - path, strm, err := s.pathManager.AddReader(defs.PathAddReaderReq{ + res, err := s.pathManager.AddReader(defs.PathAddReaderReq{ Author: s, AccessRequest: req, }) @@ -294,7 +300,11 @@ func (s *session) runRead() (int, error) { return http.StatusBadRequest, err } - defer path.RemoveReader(defs.PathRemoveReaderReq{Author: s}) + defer res.Path.RemoveReader(defs.PathRemoveReaderReq{Author: s}) + + s.mutex.Lock() + s.user = res.User + s.mutex.Unlock() iceServers, err := s.parent.generateICEServers(false) if err != nil { @@ -316,7 +326,7 @@ func (s *session) runRead() (int, error) { r := &stream.Reader{Parent: s} - err = webrtc.FromStream(strm.Desc, r, pc) + err = webrtc.FromStream(res.Stream.Desc, r, pc) if err != nil { return http.StatusBadRequest, err } @@ -362,20 +372,20 @@ func (s *session) runRead() (int, error) { s.mutex.Unlock() s.Log(logger.Info, "is reading from path '%s', %s", - path.Name(), defs.FormatsInfo(r.Formats())) + res.Path.Name(), defs.FormatsInfo(r.Formats())) onUnreadHook := hooks.OnRead(hooks.OnReadParams{ Logger: s, ExternalCmdPool: s.externalCmdPool, - Conf: path.SafeConf(), - ExternalCmdEnv: path.ExternalCmdEnv(), + Conf: res.Path.SafeConf(), + ExternalCmdEnv: res.Path.ExternalCmdEnv(), Reader: *s.APIReaderDescribe(), Query: s.req.httpRequest.URL.RawQuery, }) defer onUnreadHook() - strm.AddReader(r) - defer strm.RemoveReader(r) + res.Stream.AddReader(r) + defer res.Stream.RemoveReader(r) select { case <-pc.Failed(): @@ -500,6 +510,7 @@ func (s *session) apiItem() *defs.APIWebRTCSession { }(), Path: s.req.pathName, Query: s.req.httpRequest.URL.RawQuery, + User: s.user, BytesReceived: bytesReceived, BytesSent: bytesSent, RTPPacketsReceived: rtpPacketsReceived, diff --git a/internal/staticsources/handler.go b/internal/staticsources/handler.go index 64600f2b..a8ee24ab 100644 --- a/internal/staticsources/handler.go +++ b/internal/staticsources/handler.go @@ -19,7 +19,6 @@ import ( ssrtsp "github.com/bluenviron/mediamtx/internal/staticsources/rtsp" sssrt "github.com/bluenviron/mediamtx/internal/staticsources/srt" sswebrtc "github.com/bluenviron/mediamtx/internal/staticsources/webrtc" - "github.com/bluenviron/mediamtx/internal/stream" ) const ( @@ -51,7 +50,7 @@ type staticSource interface { } type handlerPathManager interface { - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } type handlerParent interface { @@ -334,6 +333,6 @@ func (s *Handler) SetNotReady(req defs.PathSourceStaticSetNotReadyReq) { } // AddReader is called by a staticSource. -func (s *Handler) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { +func (s *Handler) AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { return s.PathManager.AddReader(req) } diff --git a/internal/staticsources/rpicamera/source.go b/internal/staticsources/rpicamera/source.go index c288eb60..70bb2468 100644 --- a/internal/staticsources/rpicamera/source.go +++ b/internal/staticsources/rpicamera/source.go @@ -104,7 +104,7 @@ type parent interface { logger.Writer SetReady(req defs.PathSourceStaticSetReadyReq) defs.PathSourceStaticSetReadyRes SetNotReady(req defs.PathSourceStaticSetNotReadyReq) - AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } // Source is a Raspberry Pi Camera static source. @@ -333,7 +333,7 @@ func (s *Source) waitForPrimary( params defs.StaticSourceRunParams, ) (defs.Path, *stream.Stream, error) { for { - path, primaryStream, err := s.Parent.AddReader(defs.PathAddReaderReq{ + res, err := s.Parent.AddReader(defs.PathAddReaderReq{ Author: r, AccessRequest: defs.PathAccessRequest{ Name: params.Conf.RPICameraPrimaryName, @@ -354,7 +354,7 @@ func (s *Source) waitForPrimary( return nil, nil, err } - return path, primaryStream, nil + return res.Path, res.Stream, nil } } diff --git a/internal/test/auth_manager.go b/internal/test/auth_manager.go index 9f404898..e2ffe722 100644 --- a/internal/test/auth_manager.go +++ b/internal/test/auth_manager.go @@ -5,12 +5,12 @@ import "github.com/bluenviron/mediamtx/internal/auth" // AuthManager is a dummy auth manager. type AuthManager struct { - AuthenticateImpl func(req *auth.Request) *auth.Error + AuthenticateImpl func(req *auth.Request) (string, *auth.Error) RefreshJWTJWKSImpl func() } // Authenticate replicates auth.Manager.Authenticate. -func (m *AuthManager) Authenticate(req *auth.Request) *auth.Error { +func (m *AuthManager) Authenticate(req *auth.Request) (string, *auth.Error) { return m.AuthenticateImpl(req) } @@ -21,7 +21,7 @@ func (m *AuthManager) RefreshJWTJWKS() { // NilAuthManager is an auth manager that accepts everything. var NilAuthManager = &AuthManager{ - AuthenticateImpl: func(_ *auth.Request) *auth.Error { - return nil + AuthenticateImpl: func(_ *auth.Request) (string, *auth.Error) { + return "", nil }, } diff --git a/internal/test/path_manager.go b/internal/test/path_manager.go index 8deb8522..b4326910 100644 --- a/internal/test/path_manager.go +++ b/internal/test/path_manager.go @@ -1,21 +1,19 @@ package test import ( - "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" - "github.com/bluenviron/mediamtx/internal/stream" ) // PathManager is a dummy path manager. type PathManager struct { - FindPathConfImpl func(req defs.PathFindPathConfReq) (*conf.Path, error) + FindPathConfImpl func(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) DescribeImpl func(req defs.PathDescribeReq) defs.PathDescribeRes - AddPublisherImpl func(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) - AddReaderImpl func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) + AddPublisherImpl func(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) + AddReaderImpl func(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) } // FindPathConf implements PathManager. -func (pm *PathManager) FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) { +func (pm *PathManager) FindPathConf(req defs.PathFindPathConfReq) (*defs.PathFindPathConfRes, error) { return pm.FindPathConfImpl(req) } @@ -25,11 +23,11 @@ func (pm *PathManager) Describe(req defs.PathDescribeReq) defs.PathDescribeRes { } // AddPublisher implements PathManager. -func (pm *PathManager) AddPublisher(req defs.PathAddPublisherReq) (defs.Path, *stream.SubStream, error) { +func (pm *PathManager) AddPublisher(req defs.PathAddPublisherReq) (*defs.PathAddPublisherRes, error) { return pm.AddPublisherImpl(req) } // AddReader implements PathManager. -func (pm *PathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { +func (pm *PathManager) AddReader(req defs.PathAddReaderReq) (*defs.PathAddReaderRes, error) { return pm.AddReaderImpl(req) }