diff --git a/.gitignore b/.gitignore index 53a5781..70e3b07 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,7 @@ tables/ data/ cover/ *.pprof +*.go.yaml +*.test +*.jpg +__pycache__/ diff --git a/internal/app/app.go b/internal/app/app.go index c46bef0..056d4d5 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -35,7 +35,7 @@ func Run(bc *conf.Bootstrap) { go setupZLM(ctx, bc.ConfigDir) // 如果需要执行表迁移,递增此版本号和表更新说明 - versionapi.DBVersion = "0.0.17" + versionapi.DBVersion = "0.0.18" versionapi.DBRemark = "onvif device support" handler, cleanUp, err := wireApp(bc, log) diff --git a/internal/core/config/model.go b/internal/core/config/model.go index 86e847c..0f2de7e 100755 --- a/internal/core/config/model.go +++ b/internal/core/config/model.go @@ -9,7 +9,7 @@ type Ext struct { } // Scan implements orm.Scaner. -func (i *Ext) Scan(input interface{}) error { +func (i *Ext) Scan(input any) error { return orm.JSONUnmarshal(input, i) } diff --git a/internal/core/ipc/model.go b/internal/core/ipc/model.go index b46b36a..ddc6f2c 100755 --- a/internal/core/ipc/model.go +++ b/internal/core/ipc/model.go @@ -38,13 +38,21 @@ type DeviceExt struct { Firmware string `json:"firmware"` // 固件版本 Name string `json:"name"` // 设备名 GBVersion string `json:"gb_version"` // GB版本 + Zones []Zone `json:"zones"` // 区域 } // Scan implements orm.Scaner. -func (i *DeviceExt) Scan(input interface{}) error { +func (i *DeviceExt) Scan(input any) error { return orm.JSONUnmarshal(input, i) } func (i DeviceExt) Value() (driver.Value, error) { return json.Marshal(i) } + +type Zone struct { + Name string `json:"name"` // 区域名称 + Coordinates []float32 `json:"coordinates"` // 坐标 + Color string `json:"color"` // 颜色,支持 hex 颜色值,如 #FF0000 + Labels []string `json:"labels"` // 标签 +} diff --git a/internal/core/ipc/port/protocol.go b/internal/core/ipc/port/protocol.go index 5a708a1..b79710d 100644 --- a/internal/core/ipc/port/protocol.go +++ b/internal/core/ipc/port/protocol.go @@ -6,15 +6,10 @@ import ( // Device 设备接口 // 注意: 适配器实现时,参数类型为 *ipc.Device,满足此接口 -type Device interface { - // 这里不定义任何方法,让所有类型都能满足 - // 适配器实现时直接使用 *ipc.Device 类型 -} +type Device any // Channel 通道接口 -type Channel interface { - // 同上 -} +type Channel any // Protocol 协议抽象接口(端口) // diff --git a/internal/core/sms/driver_zlm.go b/internal/core/sms/driver_zlm.go index 6a5f5e2..caeb679 100644 --- a/internal/core/sms/driver_zlm.go +++ b/internal/core/sms/driver_zlm.go @@ -83,32 +83,32 @@ func (d *ZLMDriver) Setup(ctx context.Context, ms *MediaServer, webhookURL strin // 构造配置请求 req := zlm.SetServerConfigRequest{ - RtcExternIP: zlm.NewString(ms.IP), - GeneralMediaServerID: zlm.NewString(ms.ID), - HookEnable: zlm.NewString("1"), - HookOnFlowReport: zlm.NewString(""), - HookOnPlay: zlm.NewString(fmt.Sprintf("%s/on_play", webhookURL)), + RtcExternIP: new(ms.IP), + GeneralMediaServerID: new(ms.ID), + HookEnable: new("1"), + HookOnFlowReport: new(""), + HookOnPlay: new(fmt.Sprintf("%s/on_play", webhookURL)), - ProtocolEnableTs: zlm.NewString("0"), - ProtocolEnableFmp4: zlm.NewString("0"), - ProtocolEnableHls: zlm.NewString("0"), - ProtocolEnableHlsFmp4: zlm.NewString("1"), + ProtocolEnableTs: new("0"), + ProtocolEnableFmp4: new("0"), + ProtocolEnableHls: new("0"), + ProtocolEnableHlsFmp4: new("1"), - HookOnPublish: zlm.NewString(fmt.Sprintf("%s/on_publish", webhookURL)), - HookOnStreamNoneReader: zlm.NewString(fmt.Sprintf("%s/on_stream_none_reader", webhookURL)), - GeneralStreamNoneReaderDelayMS: zlm.NewString("30000"), - HookOnStreamNotFound: zlm.NewString(fmt.Sprintf("%s/on_stream_not_found", webhookURL)), - HookOnRecordTs: zlm.NewString(""), - HookOnRtspAuth: zlm.NewString(""), - HookOnRtspRealm: zlm.NewString(""), - HookOnShellLogin: zlm.NewString(""), - HookOnStreamChanged: zlm.NewString(fmt.Sprintf("%s/on_stream_changed", webhookURL)), - HookOnServerKeepalive: zlm.NewString(fmt.Sprintf("%s/on_server_keepalive", webhookURL)), - HookTimeoutSec: zlm.NewString("10"), - HookAliveInterval: zlm.NewString(fmt.Sprint(ms.HookAliveInterval)), - ProtocolContinuePushMs: zlm.NewString("3000"), + HookOnPublish: new(fmt.Sprintf("%s/on_publish", webhookURL)), + HookOnStreamNoneReader: new(fmt.Sprintf("%s/on_stream_none_reader", webhookURL)), + GeneralStreamNoneReaderDelayMS: new("30000"), + HookOnStreamNotFound: new(fmt.Sprintf("%s/on_stream_not_found", webhookURL)), + HookOnRecordTs: new(""), + HookOnRtspAuth: new(""), + HookOnRtspRealm: new(""), + HookOnShellLogin: new(""), + HookOnStreamChanged: new(fmt.Sprintf("%s/on_stream_changed", webhookURL)), + HookOnServerKeepalive: new(fmt.Sprintf("%s/on_server_keepalive", webhookURL)), + HookTimeoutSec: new("10"), + HookAliveInterval: new(fmt.Sprint(ms.HookAliveInterval)), + ProtocolContinuePushMs: new("3000"), RtpProxyPortRange: &ms.RTPPortRange, - FfmpegLog: zlm.NewString("./fflogs/ffmpeg.log"), + FfmpegLog: new("./fflogs/ffmpeg.log"), } resp, err := engine.SetServerConfig(&req) @@ -147,12 +147,12 @@ func (d *ZLMDriver) AddStreamProxy(ctx context.Context, ms *MediaServer, req *Ad RTPType: req.RTPType, RetryCount: 3, TimeoutSec: PullTimeoutMs / 1000, - EnableHLSFMP4: zlm.NewBool(true), - EnableAudio: zlm.NewBool(true), - EnableRTSP: zlm.NewBool(true), - EnableRTMP: zlm.NewBool(true), - AddMuteAudio: zlm.NewBool(true), - AutoClose: zlm.NewBool(true), + EnableHLSFMP4: new(true), + EnableAudio: new(true), + EnableRTSP: new(true), + EnableRTMP: new(true), + AddMuteAudio: new(true), + AutoClose: new(true), }) } diff --git a/internal/core/sms/model.go b/internal/core/sms/model.go index ea63f12..6b9de8a 100755 --- a/internal/core/sms/model.go +++ b/internal/core/sms/model.go @@ -24,7 +24,7 @@ type MediaServerPorts struct { } // Scan implements orm.Scaner. -func (i *MediaServerPorts) Scan(input interface{}) error { +func (i *MediaServerPorts) Scan(input any) error { return orm.JSONUnmarshal(input, i) } diff --git a/internal/core/sms/service.go b/internal/core/sms/service.go index 949e958..c6bf313 100644 --- a/internal/core/sms/service.go +++ b/internal/core/sms/service.go @@ -1,3 +1,3 @@ package sms -type NodeServicer interface{} +type NodeServicer any diff --git a/internal/web/api/ipc.go b/internal/web/api/ipc.go index d587fb5..d946203 100755 --- a/internal/web/api/ipc.go +++ b/internal/web/api/ipc.go @@ -103,6 +103,8 @@ func registerGB28181(g gin.IRouter, api IPCAPI, handler ...gin.HandlerFunc) { group.POST("/:id/play", web.WrapH(api.play)) // 播放(所有协议) group.POST("/:id/snapshot", web.WrapH(api.refreshSnapshot)) // 图像抓拍(所有协议) group.GET("/:id/snapshot", api.getSnapshot) // 获取图像(所有协议) + group.POST("/:id/zones", web.WrapH(api.addZone)) // 添加区域(所有协议) + group.GET("/:id/zones", web.WrapH(api.getZones)) // 获取区域(所有协议) } } @@ -285,6 +287,7 @@ func (a IPCAPI) play(c *gin.Context, _ *struct{}) (*playOutput, error) { for range 2 { time.Sleep(3 * time.Second) rtsp := fmt.Sprintf("rtsp://%s:%d/%s", "127.0.0.1", svr.Ports.RTSP, stream) + "?" + session + body, err := a.uc.SMSAPI.smsCore.GetSnapshot(svr, sms.GetSnapRequest{ GetSnapRequest: zlm.GetSnapRequest{ URL: rtsp, @@ -362,6 +365,20 @@ func (a IPCAPI) refreshSnapshot(c *gin.Context, in *refreshSnapshotInput) (any, return gin.H{"link": fmt.Sprintf("%s/channels/%s/snapshot?token=%s", prefix, channelID, token)}, nil } +func (a IPCAPI) addZone(c *gin.Context, in *ipc.AddZoneInput) (gin.H, error) { + channelID := c.Param("id") + if len(in.Labels) == 0 { + in.Labels = []string{"person", "car", "cat", "dog"} + } + zones, err := a.ipc.AddZone(c.Request.Context(), in, channelID) + return gin.H{"items": zones}, err +} + +func (a IPCAPI) getZones(c *gin.Context, _ *struct{}) (any, error) { + channelID := c.Param("id") + return a.ipc.GetZones(c.Request.Context(), channelID) +} + func (a IPCAPI) getSnapshot(c *gin.Context) { channelID := c.Param("id") body, err := readCover(a.uc.Conf.ConfigDir, channelID) diff --git a/internal/web/api/user.go b/internal/web/api/user.go index 269618a..1a88483 100644 --- a/internal/web/api/user.go +++ b/internal/web/api/user.go @@ -1,6 +1,15 @@ package api import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "sync" "time" "github.com/gin-gonic/gin" @@ -10,25 +19,85 @@ import ( ) type UserAPI struct { - conf *conf.Bootstrap + conf *conf.Bootstrap + secret *Secret +} + +type Secret struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + expiredAt time.Time + m sync.RWMutex +} + +// TODO: 有概率存在过期导致登录解密识别 +func (s *Secret) GetOrCreatePublicKey() (*rsa.PublicKey, error) { + s.m.RLock() + if s.publicKey != nil && time.Now().Before(s.expiredAt) { + s.m.RUnlock() + return s.publicKey, nil + } + s.m.RUnlock() + + s.m.Lock() + defer s.m.Unlock() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + s.privateKey = privateKey + s.publicKey = &privateKey.PublicKey + s.expiredAt = time.Now().Add(1 * time.Hour) + return s.publicKey, nil +} + +func (s *Secret) MarshalPKIXPublicKey(key *rsa.PublicKey) []byte { + publicKeyBytes, _ := x509.MarshalPKIXPublicKey(key) + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }) +} + +func (s *Secret) Decrypt(ciphertext string) ([]byte, error) { + s.m.RLock() + pri := s.privateKey + s.m.RUnlock() + if pri == nil { + return nil, fmt.Errorf("请刷新页面后重试") + } + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return nil, err + } + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, pri, data, nil) + if err != nil { + return nil, err + } + return plaintext, nil } func NewUserAPI(conf *conf.Bootstrap) UserAPI { return UserAPI{ - conf: conf, + conf: conf, + secret: &Secret{}, } } func RegisterUser(r gin.IRouter, api UserAPI, mid ...gin.HandlerFunc) { - group := r.Group("/user") - group.POST("/login", web.WrapH(api.login)) - group.PUT("/user", web.WrapHs(api.updateCredentials, mid...)...) + r.POST("/login", web.WrapH(api.login)) + r.GET("/login/key", web.WrapH(api.getPublicKey)) + + group := r.Group("/users", mid...) + group.PUT("", web.WrapHs(api.updateCredentials, mid...)...) } // 登录请求结构体 type loginInput struct { - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` + // Username string `json:"username" binding:"required"` + // Password string `json:"password" binding:"required"` + Data string `json:"data" binding:"required"` } // 登录响应结构体 @@ -39,16 +108,28 @@ type loginOutput struct { // 登录接口 func (api UserAPI) login(_ *gin.Context, in *loginInput) (*loginOutput, error) { + body, err := api.secret.Decrypt(in.Data) + if err != nil { + return nil, reason.ErrServer.SetMsg(err.Error()) + } + var credentials struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.Unmarshal(body, &credentials); err != nil { + return nil, reason.ErrServer.SetMsg(err.Error()) + } + // 验证用户名和密码 if api.conf.Server.Username == "" && api.conf.Server.Password == "" { api.conf.Server.Username = "admin" api.conf.Server.Password = "admin" } - if in.Username != api.conf.Server.Username || in.Password != api.conf.Server.Password { + if credentials.Username != api.conf.Server.Username || credentials.Password != api.conf.Server.Password { return nil, reason.ErrNameOrPasswd } - data := web.NewClaimsData().SetUsername(in.Username) + data := web.NewClaimsData().SetUsername(credentials.Username) token, err := web.NewToken(data, api.conf.Server.HTTP.JwtSecret, web.WithExpiresAt(time.Now().Add(3*24*time.Hour))) if err != nil { @@ -57,7 +138,7 @@ func (api UserAPI) login(_ *gin.Context, in *loginInput) (*loginOutput, error) { return &loginOutput{ Token: token, - User: in.Username, + User: credentials.Username, }, nil } @@ -80,3 +161,12 @@ func (api UserAPI) updateCredentials(_ *gin.Context, in *updateCredentialsInput) return gin.H{"msg": "凭据更新成功"}, nil } + +func (api UserAPI) getPublicKey(_ *gin.Context, _ *struct{}) (gin.H, error) { + publicKey, err := api.secret.GetOrCreatePublicKey() + if err != nil { + return nil, reason.ErrServer.SetMsg(err.Error()) + } + result := api.secret.MarshalPKIXPublicKey(publicKey) + return gin.H{"key": base64.StdEncoding.EncodeToString(result)}, nil +} diff --git a/pkg/gbs/files.go b/pkg/gbs/files.go index bf815aa..aa3d9a1 100644 --- a/pkg/gbs/files.go +++ b/pkg/gbs/files.go @@ -50,7 +50,7 @@ type apiRecordItem struct { id string } -func (ri *apiRecordItem) Start() (string, interface{}) { +func (ri *apiRecordItem) Start() (string, any) { if config.Record.Recordmax <= 0 { return m.StatusSysERR, errors.New("config record max time invalid.") } @@ -86,7 +86,7 @@ func (ri *apiRecordItem) Start() (string, interface{}) { return m.StatusSucc, ri.id } -func (ri *apiRecordItem) Stop() (string, interface{}) { +func (ri *apiRecordItem) Stop() (string, any) { err := zlmStopRecord(ri.params) if err != nil { return m.StatusSysERR, "" diff --git a/pkg/gbs/notify.go b/pkg/gbs/notify.go index 28ffcac..8cb2aff 100644 --- a/pkg/gbs/notify.go +++ b/pkg/gbs/notify.go @@ -22,8 +22,8 @@ const ( // Notify 消息通知结构 type Notify struct { - Method string `json:"method"` - Data interface{} `json:"data"` + Method string `json:"method"` + Data any `json:"data"` } func notify(data *Notify) { @@ -45,7 +45,7 @@ func notify(data *Notify) { func notifyDevicesAcitve(id, status string) *Notify { return &Notify{ Method: NotifyMethodDevicesActive, - Data: map[string]interface{}{ + Data: map[string]any{ "deviceid": id, "status": status, "time": time.Now().Unix(), @@ -64,7 +64,7 @@ func notifyDevicesRegister(u Devices) *Notify { func notifyChannelsActive(d Channels) *Notify { return &Notify{ Method: NotifyMethodChannelsActive, - Data: map[string]interface{}{ + Data: map[string]any{ "channelid": d.ChannelID, "status": d.Status, "time": time.Now().Unix(), @@ -73,7 +73,7 @@ func notifyChannelsActive(d Channels) *Notify { } func notifyRecordStop(url string, req url.Values) *Notify { - d := map[string]interface{}{ + d := map[string]any{ "url": fmt.Sprintf("%s/%s", config.Media.HTTP, url), } for k, v := range req { diff --git a/pkg/gbs/sip/auth.go b/pkg/gbs/sip/auth.go index 789398f..329242c 100644 --- a/pkg/gbs/sip/auth.go +++ b/pkg/gbs/sip/auth.go @@ -50,7 +50,7 @@ func AuthFromValue(value string) *Authorization { case "response": auth.response = match[2] case "qop": - for _, v := range strings.Split(match[2], ",") { + for v := range strings.SplitSeq(match[2], ",") { v = strings.Trim(v, " ") if v == "auth" || v == "auth-int" { auth.qop = "auth" diff --git a/pkg/gbs/sip/header.go b/pkg/gbs/sip/header.go index f33a839..70c5746 100644 --- a/pkg/gbs/sip/header.go +++ b/pkg/gbs/sip/header.go @@ -232,7 +232,7 @@ type Params interface { Get(key string) (MaybeString, bool) Add(key string, val MaybeString) Params Clone() Params - Equals(params interface{}) bool + Equals(params any) bool ToString(sep uint8) string String() string Length() int @@ -297,7 +297,7 @@ type Header interface { // Clone returns copy of header struct. Clone() Header String() string - Equals(other interface{}) bool + Equals(other any) bool } // headers is a struct with methods to work with SIP headers. @@ -629,7 +629,7 @@ func (params *headerParams) Length() int { // Check if two maps of parameters are equal in the sense of having the same keys with the same values. // This does not rely on any ordering of the keys of the map in memory. -func (params *headerParams) Equals(other interface{}) bool { +func (params *headerParams) Equals(other any) bool { q, ok := other.(*headerParams) if !ok { return false @@ -679,7 +679,7 @@ func (contentLength *ContentLength) Name() string { return "Content-Length" } func (contentLength *ContentLength) Clone() Header { return contentLength } // Equals Equals -func (contentLength *ContentLength) Equals(other interface{}) bool { +func (contentLength *ContentLength) Equals(other any) bool { if h, ok := other.(ContentLength); ok { if contentLength == nil { return false @@ -737,7 +737,7 @@ func (via ViaHeader) Clone() Header { } // Equals Equals -func (via ViaHeader) Equals(other interface{}) bool { +func (via ViaHeader) Equals(other any) bool { if h, ok := other.(ViaHeader); ok { if len(via) != len(h) { return false @@ -827,7 +827,7 @@ func (hop *ViaHop) Clone() *ViaHop { } // Equals Equals -func (hop *ViaHop) Equals(other interface{}) bool { +func (hop *ViaHop) Equals(other any) bool { if h, ok := other.(*ViaHop); ok { if hop == h { return true @@ -874,7 +874,7 @@ func (callId *CallID) Clone() Header { } // Equals Equals -func (callId *CallID) Equals(other interface{}) bool { +func (callId *CallID) Equals(other any) bool { if h, ok := other.(CallID); ok { if callId == nil { return false @@ -925,7 +925,7 @@ func (cseq *CSeq) Clone() Header { } // Equals Equals -func (cseq *CSeq) Equals(other interface{}) bool { +func (cseq *CSeq) Equals(other any) bool { if h, ok := other.(*CSeq); ok { if cseq == h { return true @@ -991,7 +991,7 @@ func (to *ToHeader) Clone() Header { } // Equals Equals -func (to *ToHeader) Equals(other interface{}) bool { +func (to *ToHeader) Equals(other any) bool { if h, ok := other.(*ToHeader); ok { if to == h { return true @@ -1087,7 +1087,7 @@ func (from *FromHeader) Clone() Header { } // Equals Equals -func (from *FromHeader) Equals(other interface{}) bool { +func (from *FromHeader) Equals(other any) bool { if h, ok := other.(*FromHeader); ok { if from == h { return true @@ -1142,7 +1142,7 @@ func (ct *ContentType) Name() string { return "Content-Type" } func (ct *ContentType) Clone() Header { return ct } // Equals Equals -func (ct *ContentType) Equals(other interface{}) bool { +func (ct *ContentType) Equals(other any) bool { if h, ok := other.(ContentType); ok { if ct == nil { return false @@ -1215,7 +1215,7 @@ func (contact *ContactHeader) Clone() Header { } // Equals Equals -func (contact *ContactHeader) Equals(other interface{}) bool { +func (contact *ContactHeader) Equals(other any) bool { if h, ok := other.(*ContactHeader); ok { if contact == h { return true @@ -1272,7 +1272,7 @@ func (maxForwards *MaxForwards) Name() string { return "Max-Forwards" } func (maxForwards *MaxForwards) Clone() Header { return maxForwards } // Equals Equals -func (maxForwards *MaxForwards) Equals(other interface{}) bool { +func (maxForwards *MaxForwards) Equals(other any) bool { if h, ok := other.(MaxForwards); ok { if maxForwards == nil { return false @@ -1310,7 +1310,7 @@ func (expires *Expires) Name() string { return "Expires" } func (expires *Expires) Clone() Header { return expires } // Equals Equals -func (expires *Expires) Equals(other interface{}) bool { +func (expires *Expires) Equals(other any) bool { if h, ok := other.(Expires); ok { if expires == nil { return false @@ -1348,7 +1348,7 @@ func (ua *UserAgentHeader) Name() string { return "User-Agent" } func (ua *UserAgentHeader) Clone() Header { return ua } // Equals equals -func (ua *UserAgentHeader) Equals(other interface{}) bool { +func (ua *UserAgentHeader) Equals(other any) bool { if h, ok := other.(UserAgentHeader); ok { if ua == nil { return false @@ -1403,7 +1403,7 @@ func (allow AllowHeader) Clone() Header { } // Equals equals -func (allow AllowHeader) Equals(other interface{}) bool { +func (allow AllowHeader) Equals(other any) bool { if h, ok := other.(AllowHeader); ok { if len(allow) != len(h) { return false @@ -1435,7 +1435,7 @@ func (ct *Accept) Name() string { return "Accept" } func (ct *Accept) Clone() Header { return ct } // Equals Equals -func (ct *Accept) Equals(other interface{}) bool { +func (ct *Accept) Equals(other any) bool { if h, ok := other.(Accept); ok { if ct == nil { return false @@ -1496,7 +1496,7 @@ func (route *RouteHeader) Clone() Header { } // Equals Equals -func (route *RouteHeader) Equals(other interface{}) bool { +func (route *RouteHeader) Equals(other any) bool { if h, ok := other.(*RouteHeader); ok { if route == h { return true @@ -1556,7 +1556,7 @@ func (route *RecordRouteHeader) Clone() Header { } // Equals Equals -func (route *RecordRouteHeader) Equals(other interface{}) bool { +func (route *RecordRouteHeader) Equals(other any) bool { if h, ok := other.(*RecordRouteHeader); ok { if route == h { return true @@ -1605,7 +1605,7 @@ func (support *SupportedHeader) Clone() Header { } // Equals Equals -func (support *SupportedHeader) Equals(other interface{}) bool { +func (support *SupportedHeader) Equals(other any) bool { if h, ok := other.(*SupportedHeader); ok { if support == h { return true @@ -1664,7 +1664,7 @@ func (header *GenericHeader) Clone() Header { } // Equals Equals -func (header *GenericHeader) Equals(other interface{}) bool { +func (header *GenericHeader) Equals(other any) bool { if h, ok := other.(*GenericHeader); ok { if header == h { return true @@ -1691,7 +1691,7 @@ func (ct XGBVer) Name() string { return "X-GB-Ver" } func (ct XGBVer) Clone() Header { return &ct } // Equals Equals -func (ct *XGBVer) Equals(other interface{}) bool { +func (ct *XGBVer) Equals(other any) bool { h, ok := other.(XGBVer) if !ok { return false diff --git a/pkg/gbs/sip/message.go b/pkg/gbs/sip/message.go index e6dfa8f..b010bc5 100644 --- a/pkg/gbs/sip/message.go +++ b/pkg/gbs/sip/message.go @@ -303,7 +303,7 @@ func (uri *URI) Clone() *URI { // Equals Determine if the SIP URI is equal to the specified URI according to the rules laid down in RFC 3261 s. 19.1.4. // TODO: The Equals method is not currently RFC-compliant; fix this! -func (uri *URI) Equals(val interface{}) bool { +func (uri *URI) Equals(val any) bool { otherPtr, ok := val.(*URI) if !ok { return false diff --git a/pkg/gbs/sip/models.go b/pkg/gbs/sip/models.go index 07da670..f1bc856 100644 --- a/pkg/gbs/sip/models.go +++ b/pkg/gbs/sip/models.go @@ -38,7 +38,7 @@ func (port *Port) String() string { } // Equals Equals -func (port *Port) Equals(other interface{}) bool { +func (port *Port) Equals(other any) bool { if p, ok := other.(*Port); ok { return Uint16PtrEq((*uint16)(port), (*uint16)(p)) } @@ -49,7 +49,7 @@ func (port *Port) Equals(other interface{}) bool { // MaybeString wrapper type MaybeString interface { String() string - Equals(other interface{}) bool + Equals(other any) bool } // String string @@ -62,7 +62,7 @@ func (str String) String() string { } // Equals Equals -func (str String) Equals(other interface{}) bool { +func (str String) Equals(other any) bool { if v, ok := other.(String); ok { return str.Str == v.Str } @@ -109,17 +109,17 @@ var ( // GetDeviceInfoXML 获取设备详情指令 func GetDeviceInfoXML(id string) []byte { - return []byte(fmt.Sprintf(DeviceInfoXML, RandInt(100000, 999999), id)) + return fmt.Appendf(nil, DeviceInfoXML, RandInt(100000, 999999), id) } // GetCatalogXML 获取NVR下设备列表指令 func GetCatalogXML(id string) []byte { - return []byte(fmt.Sprintf(CatalogXML, RandInt(100000, 999999), id)) + return fmt.Appendf(nil, CatalogXML, RandInt(100000, 999999), id) } // GetRecordInfoXML 获取录像文件列表指令 func GetRecordInfoXML(id string, sceqNo int, start, end int64) []byte { - return []byte(fmt.Sprintf(RecordInfoXML, sceqNo, id, time.Unix(start, 0).Format("2006-01-02T15:04:05"), time.Unix(end, 0).Format("2006-01-02T15:04:05"))) + return fmt.Appendf(nil, RecordInfoXML, sceqNo, id, time.Unix(start, 0).Format("2006-01-02T15:04:05"), time.Unix(end, 0).Format("2006-01-02T15:04:05")) } // RFC3261BranchMagicCookie RFC3261BranchMagicCookie diff --git a/pkg/gbs/sip/parser.go b/pkg/gbs/sip/parser.go index ab6cb1e..2f3f698 100644 --- a/pkg/gbs/sip/parser.go +++ b/pkg/gbs/sip/parser.go @@ -346,8 +346,8 @@ func parseAccept(headerName string, headerText string) (headers []Header, err er func parseAllow(headerName string, headerText string) (headers []Header, err error) { allow := make(AllowHeader, 0) - methods := strings.Split(headerText, ",") - for _, method := range methods { + methods := strings.SplitSeq(headerText, ",") + for method := range methods { allow = append(allow, string(strings.TrimSpace(method))) } headers = []Header{allow} @@ -358,8 +358,8 @@ func parseAllow(headerName string, headerText string) (headers []Header, err err func parseSupported(headerName string, headerText string) (headers []Header, err error) { var supported SupportedHeader supported.Options = make([]string, 0) - extensions := strings.Split(headerText, ",") - for _, ext := range extensions { + extensions := strings.SplitSeq(headerText, ",") + for ext := range extensions { supported.Options = append(supported.Options, strings.TrimSpace(ext)) } headers = []Header{&supported} @@ -411,6 +411,11 @@ func ParseAddressValues(addresses string) ( var params Params displayName, uri, params, err = ParseAddressValue(addresses[prevIdx:idx]) if err != nil { + // sip:1678@80.79.5.134;expires=3600 + arr := strings.Split(addresses, "@") + if len(arr) == 2 && len(arr[0]) < 20 { + err = nil + } return } prevIdx = idx + 1 @@ -782,19 +787,19 @@ func ParseURI(uriStr string) (uri *URI, err error) { return } - colonIdx := strings.Index(uriStr, ":") - if colonIdx == -1 { + before, _, ok := strings.Cut(uriStr, ":") + if !ok { err = fmt.Errorf("no ':' in URI %s", uriStr) return } - switch strings.ToLower(uriStr[:colonIdx]) { + switch strings.ToLower(before) { case "sip", "sips": var sipURI URI sipURI, err = ParseSipURI(uriStr) uri = &sipURI default: - err = fmt.Errorf("unsupported URI schema %s", uriStr[:colonIdx]) + err = fmt.Errorf("unsupported URI schema %s", before) } return @@ -908,8 +913,8 @@ func ParseSipURI(uriStr string) (uri URI, err error) { // The port may or may not be present, so we represent it with a *uint16, // and return 'nil' if no port was present. func ParseHostPort(rawText string) (host string, port *Port, err error) { - colonIdx := strings.Index(rawText, ":") - if colonIdx == -1 { + before, after, ok := strings.Cut(rawText, ":") + if !ok { host = rawText return } @@ -917,8 +922,8 @@ func ParseHostPort(rawText string) (host string, port *Port, err error) { // Surely there must be a better way..! var portRaw64 uint64 var portRaw16 uint16 - host = rawText[:colonIdx] - portRaw64, err = strconv.ParseUint(rawText[colonIdx+1:], 10, 16) + host = before + portRaw64, err = strconv.ParseUint(after, 10, 16) portRaw16 = uint16(portRaw64) port = (*Port)(&portRaw16) @@ -1079,15 +1084,15 @@ parseLoop: func ParseHeader(headerText string) (headers []Header, err error) { headers = make([]Header, 0) - colonIdx := strings.Index(headerText, ":") - if colonIdx == -1 { + before, after, ok := strings.Cut(headerText, ":") + if !ok { err = fmt.Errorf("field name with no value in header: %s", headerText) return } - fieldName := strings.TrimSpace(headerText[:colonIdx]) + fieldName := strings.TrimSpace(before) lowerFieldName := strings.ToLower(fieldName) - fieldText := strings.TrimSpace(headerText[colonIdx+1:]) + fieldText := strings.TrimSpace(after) if headerParser, ok := defaultHeaderParsers[lowerFieldName]; ok { // We have a registered parser for this header type - use it. headers, err = headerParser(lowerFieldName, fieldText) diff --git a/pkg/gbs/sip/response.go b/pkg/gbs/sip/response.go index 3496ecb..6b77616 100644 --- a/pkg/gbs/sip/response.go +++ b/pkg/gbs/sip/response.go @@ -39,12 +39,11 @@ func NewResponseFromRequest( if _, ok := to.Params.Get("tag"); !ok { to.Params.Add("tag", String{Str: RandString(32)}) } + res.AppendHeader(to) } CopyHeaders("CSeq", req, res) CopyHeaders("Call-ID", req, res) - res.AppendHeader(to) - if statusCode == 100 { CopyHeaders("Timestamp", req, res) } diff --git a/pkg/gbs/sip/sip_test.go b/pkg/gbs/sip/sip_test.go new file mode 100644 index 0000000..fd30acd --- /dev/null +++ b/pkg/gbs/sip/sip_test.go @@ -0,0 +1,14 @@ +package sip + +import ( + "testing" +) + +func TestParseHeader(t *testing.T) { + const s = `Contact: sip:1678@80.79.5.134;expires=3600` + h, err := ParseHeader(s) + if err != nil { + t.Fatalf("ParseHeader failed: %v", err) + } + t.Log(h) +} diff --git a/pkg/gbs/sip/utils.go b/pkg/gbs/sip/utils.go index 675ae76..051f9f4 100644 --- a/pkg/gbs/sip/utils.go +++ b/pkg/gbs/sip/utils.go @@ -8,7 +8,7 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "log/slog" "math/rand" "net" @@ -24,7 +24,7 @@ import ( // Error Error type Error struct { err error - params []interface{} + params []any } func (err *Error) Error() string { @@ -39,12 +39,12 @@ func (err *Error) Error() string { } // NewError NewError -func NewError(err error, params ...interface{}) error { +func NewError(err error, params ...any) error { return &Error{err, params} } // JSONEncode JSONEncode -func JSONEncode(data interface{}) []byte { +func JSONEncode(data any) []byte { d, err := json.Marshal(data) if err != nil { slog.Error("JSONEncode error:", "err", err) @@ -53,7 +53,7 @@ func JSONEncode(data interface{}) []byte { } // JSONDecode JSONDecode -func JSONDecode(data []byte, obj interface{}) error { +func JSONDecode(data []byte, obj any) error { return json.Unmarshal(data, obj) } @@ -131,7 +131,7 @@ func PostRequest(url string, bodyType string, body io.Reader) ([]byte, error) { } defer resp.Body.Close() - respbody, err := ioutil.ReadAll(resp.Body) + respbody, err := io.ReadAll(resp.Body) if err != nil { return nil, err } @@ -139,7 +139,7 @@ func PostRequest(url string, bodyType string, body io.Reader) ([]byte, error) { } // PostJSONRequest PostJSONRequest -func PostJSONRequest(url string, data interface{}) ([]byte, error) { +func PostJSONRequest(url string, data any) ([]byte, error) { bytesData, err := json.Marshal(data) if err != nil { return nil, err @@ -164,7 +164,7 @@ func GetRequest(url string) ([]byte, error) { } // XMLDecode 解码 xml -func XMLDecode(data []byte, v interface{}) error { +func XMLDecode(data []byte, v any) error { decoder := xml.NewDecoder(bytes.NewReader(data)) decoder.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { if utf8.Valid(data) { @@ -181,7 +181,7 @@ func XMLDecode(data []byte, v interface{}) error { return xmlDecode([]byte(value), v) } -func xmlDecode(data []byte, v interface{}) error { +func xmlDecode(data []byte, v any) error { decoder := xml.NewDecoder(bytes.NewReader(data)) decoder.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { if utf8.Valid(data) { diff --git a/pkg/gbs/zlm.go b/pkg/gbs/zlm.go index 4e2885b..2768b78 100644 --- a/pkg/gbs/zlm.go +++ b/pkg/gbs/zlm.go @@ -107,7 +107,7 @@ func zlmStartRecord(values url.Values) error { if err != nil { return err } - tmp := map[string]interface{}{} + tmp := map[string]any{} err = sip.JSONDecode(body, &tmp) if err != nil { return err @@ -124,7 +124,7 @@ func zlmStopRecord(values url.Values) error { if err != nil { return err } - tmp := map[string]interface{}{} + tmp := map[string]any{} err = sip.JSONDecode(body, &tmp) if err != nil { return err