From 6c29e525e5d185e808396986efa931ddb4b38683 Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Sat, 12 Aug 2023 19:22:03 +0800 Subject: [PATCH] feat: api can return json fomart now feat: pull on subscribe event use InvitePublish to instead *Stream fix: pull remote stream publish cause bugs fix: Concurrency MarshalJSON Tracks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit desc: - API 现在可以返回{"data":xx,"code":0,"msg":"ok"}格式 - 按需拉流的事件的类型从*Stream修改为InvitePublish - 远程拉流频繁重连后导致多路远程流同时写入同一个track - 在并发序列化Tracks时导致错误,通过加lock解决 --- config/remote.go | 2 +- config/types.go | 2 +- events.go | 12 +++++ http.go | 123 ++++++++++++++++++------------------------- io.go | 37 ++++++------- lang/zh.yaml | 8 ++- main.go | 12 ++--- plugin.go | 133 +++++++++++------------------------------------ publisher.go | 20 ------- puller.go | 85 ++++++++++++++++++++++++++++++ pusher.go | 68 ++++++++++++++++++++++++ stream.go | 26 +++++++-- subscriber.go | 19 +------ track/h264.go | 2 +- track/h265.go | 2 +- util/socket.go | 121 +++++++++++++++++++++++++++++++++++------- 16 files changed, 405 insertions(+), 267 deletions(-) create mode 100644 puller.go create mode 100644 pusher.go diff --git a/config/remote.go b/config/remote.go index 15ee064..04b9966 100644 --- a/config/remote.go +++ b/config/remote.go @@ -67,7 +67,7 @@ func (cfg *Engine) Remote(ctx context.Context) (wasConnected bool, err error) { NextProtos: []string{"monibuca"}, } - conn, err := quic.DialAddr(cfg.Server, tlsConf, &quic.Config{ + conn, err := quic.DialAddr(ctx, cfg.Server, tlsConf, &quic.Config{ KeepAlivePeriod: time.Second * 10, EnableDatagrams: true, }) diff --git a/config/types.go b/config/types.go index b594d55..e18f716 100755 --- a/config/types.go +++ b/config/types.go @@ -155,7 +155,7 @@ var Global *Engine func (cfg *Engine) InitDefaultHttp() { Global = cfg - cfg.HTTP.mux = http.DefaultServeMux + cfg.HTTP.mux = http.NewServeMux() cfg.HTTP.ListenAddrTLS = ":8443" cfg.HTTP.ListenAddr = ":8080" } diff --git a/events.go b/events.go index 66c81da..3eb92ef 100644 --- a/events.go +++ b/events.go @@ -78,3 +78,15 @@ type UnsubscribeEvent struct { type AddTrackEvent struct { Event[common.Track] } + +// InvitePublishEvent 邀请推流事件(按需拉流) +type InvitePublish struct { + Event[string] +} + +func TryInvitePublish(streamPath string) { + s := Streams.Get(streamPath) + if s == nil || s.Publisher == nil { + EventBus <- InvitePublish{Event: CreateEvent(streamPath)} + } +} diff --git a/http.go b/http.go index 72ceb81..70177ec 100644 --- a/http.go +++ b/http.go @@ -2,13 +2,11 @@ package engine import ( "encoding/json" - "fmt" "io" "net/http" "os" "strconv" "strings" - "time" "go.uber.org/zap" "gopkg.in/yaml.v3" @@ -26,11 +24,6 @@ type GlobalConfig struct { config.Engine } -func ShouldYaml(r *http.Request) bool { - format := r.URL.Query().Get("format") - return r.URL.Query().Get("yaml") != "" || format == "yaml" -} - func (conf *GlobalConfig) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if r.URL.Path == "/favicon.ico" { http.ServeFile(rw, r, "favicon.ico") @@ -43,56 +36,39 @@ func (conf *GlobalConfig) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func (conf *GlobalConfig) API_summary(rw http.ResponseWriter, r *http.Request) { - y := ShouldYaml(r) - if y { - util.ReturnYaml(util.FetchValue(&summary), time.Second, rw, r) - } else { - util.ReturnJson(util.FetchValue(&summary), time.Second, rw, r) - } + util.ReturnValue(&summary, rw, r) } func (conf *GlobalConfig) API_plugins(rw http.ResponseWriter, r *http.Request) { - if ShouldYaml(r) { - if err := yaml.NewEncoder(rw).Encode(Plugins); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } - } else if err := json.NewEncoder(rw).Encode(Plugins); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } + util.ReturnValue(Plugins, rw, r) } func (conf *GlobalConfig) API_stream(rw http.ResponseWriter, r *http.Request) { if streamPath := r.URL.Query().Get("streamPath"); streamPath != "" { if s := Streams.Get(streamPath); s != nil { - if ShouldYaml(r) { - util.ReturnYaml(util.FetchValue(s), time.Second, rw, r) - } else { - util.ReturnJson(util.FetchValue(s), time.Second, rw, r) - } + util.ReturnValue(s, rw, r) } else { - http.Error(rw, NO_SUCH_STREAM, http.StatusNotFound) + util.ReturnError(util.APIErrorNoStream, NO_SUCH_STREAM, rw, r) } } else { - http.Error(rw, "no streamPath", http.StatusBadRequest) + util.ReturnError(util.APIErrorNoStream, "no streamPath", rw, r) } } func (conf *GlobalConfig) API_sysInfo(rw http.ResponseWriter, r *http.Request) { - if err := json.NewEncoder(rw).Encode(&SysInfo); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } + util.ReturnValue(&SysInfo, rw, r) } func (conf *GlobalConfig) API_closeStream(w http.ResponseWriter, r *http.Request) { if streamPath := r.URL.Query().Get("streamPath"); streamPath != "" { if s := Streams.Get(streamPath); s != nil { s.Close() - w.Write([]byte("ok")) + util.ReturnOK(w, r) } else { - http.Error(w, NO_SUCH_STREAM, http.StatusNotFound) + util.ReturnError(util.APIErrorNoStream, NO_SUCH_STREAM, w, r) } } else { - http.Error(w, "no streamPath", http.StatusBadRequest) + util.ReturnError(util.APIErrorNoStream, "no streamPath", w, r) } } @@ -104,27 +80,30 @@ func (conf *GlobalConfig) API_getConfig(w http.ResponseWriter, r *http.Request) if c, ok := Plugins[configName]; ok { p = c } else { - http.Error(w, NO_SUCH_CONIFG, http.StatusNotFound) + util.ReturnError(util.APIErrorNoConfig, NO_SUCH_CONIFG, w, r) return } } else { p = Engine } - if ShouldYaml(r) { + var data any + if q.Get("yaml") != "" { mm, err := yaml.Marshal(p.RawConfig) if err != nil { mm = []byte("") } - json.NewEncoder(w).Encode(struct { + data = struct { File string Modified string Merged string }{ p.Yaml, p.modifiedYaml, string(mm), - }) - } else if err := json.NewEncoder(w).Encode(p.RawConfig); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + } + + } else { + data = p.RawConfig } + util.ReturnValue(data, w, r) } // API_modifyConfig 修改并保存配置 @@ -136,28 +115,28 @@ func (conf *GlobalConfig) API_modifyConfig(w http.ResponseWriter, r *http.Reques if c, ok := Plugins[configName]; ok { p = c } else { - http.Error(w, NO_SUCH_CONIFG, http.StatusNotFound) + util.ReturnError(util.APIErrorNoConfig, NO_SUCH_CONIFG, w, r) return } } else { p = Engine } - if ShouldYaml(r) { + if q.Get("yaml") != "" { err = yaml.NewDecoder(r.Body).Decode(&p.Modified) } else { err = json.NewDecoder(r.Body).Decode(&p.Modified) } if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.ReturnError(util.APIErrorDecode, err.Error(), w, r) } else if err = p.Save(); err == nil { p.RawConfig.Assign(p.Modified) out, err := yaml.Marshal(p.Modified) if err == nil { p.modifiedYaml = string(out) } - w.Write([]byte("ok")) + util.ReturnOK(w, r) } else { - w.Write([]byte(err.Error())) + util.ReturnError(util.APIErrorSave, err.Error(), w, r) } } @@ -169,34 +148,34 @@ func (conf *GlobalConfig) API_updateConfig(w http.ResponseWriter, r *http.Reques if c, ok := Plugins[configName]; ok { p = c } else { - http.Error(w, NO_SUCH_CONIFG, http.StatusNotFound) + util.ReturnError(util.APIErrorNoConfig, NO_SUCH_CONIFG, w, r) return } } else { p = Engine } p.Update(p.Modified) - w.Write([]byte("ok")) + util.ReturnOK(w, r) } func (conf *GlobalConfig) API_list_pull(w http.ResponseWriter, r *http.Request) { - util.ReturnJson(func() (result []any) { + util.ReturnFetchValue(func() (result []any) { Pullers.Range(func(key, value any) bool { result = append(result, key) return true }) return - }, time.Second, w, r) + }, w, r) } func (conf *GlobalConfig) API_list_push(w http.ResponseWriter, r *http.Request) { - util.ReturnJson(func() (result []any) { + util.ReturnFetchValue(func() (result []any) { Pushers.Range(func(key, value any) bool { result = append(result, value) return true }) return - }, time.Second, w, r) + }, w, r) } func (conf *GlobalConfig) API_stop_push(w http.ResponseWriter, r *http.Request) { @@ -204,9 +183,9 @@ func (conf *GlobalConfig) API_stop_push(w http.ResponseWriter, r *http.Request) pusher, ok := Pushers.Load(q.Get("url")) if ok { pusher.(IPusher).Stop() - fmt.Fprintln(w, "ok") + util.ReturnOK(w, r) } else { - http.Error(w, "no such pusher", http.StatusNotFound) + util.ReturnError(util.APIErrorNoPusher, "no such pusher", w, r) } } @@ -216,16 +195,16 @@ func (conf *GlobalConfig) API_stop_subscribe(w http.ResponseWriter, r *http.Requ id := q.Get("id") s := Streams.Get(streamPath) if s == nil { - http.Error(w, NO_SUCH_STREAM, http.StatusNotFound) + util.ReturnError(util.APIErrorNoStream, NO_SUCH_STREAM, w, r) return } suber := s.Subscribers.Find(id) if suber == nil { - http.Error(w, "no such subscriber", http.StatusNotFound) + util.ReturnError(util.APIErrorNoSubscriber, "no such subscriber", w, r) return } suber.Stop(zap.String("reason", "stop by api")) - fmt.Fprintln(w, "ok") + util.ReturnOK(w, r) } func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Request) { @@ -264,29 +243,29 @@ func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Requ ss := strings.Split(dumpFile, ",") if len(ss) > 1 { if err := Engine.Publish(streamPath, &pub); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorPublish, err.Error(), w, r) } else { for _, s := range ss { f, err := os.Open(s) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorOpen, err.Error(), w, r) return } go pub.Feed(f) } - w.Write([]byte("ok")) + util.ReturnOK(w, r) } } else { f, err := os.Open(dumpFile) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorOpen, err.Error(), w, r) return } if err := Engine.Publish(streamPath, &pub); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorPublish, err.Error(), w, r) } else { pub.SetIO(f) - w.Write([]byte("ok")) + util.ReturnOK(w, r) go pub.Feed(f) } } @@ -304,12 +283,12 @@ func (conf *GlobalConfig) API_replay_ts(w http.ResponseWriter, r *http.Request) } f, err := os.Open(dumpFile) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorOpen, err.Error(), w, r) return } var pub TSPublisher if err := Engine.Publish(streamPath, &pub); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorPublish, err.Error(), w, r) } else { tsReader := NewTSReader(&pub) pub.SetIO(f) @@ -317,7 +296,7 @@ func (conf *GlobalConfig) API_replay_ts(w http.ResponseWriter, r *http.Request) tsReader.Feed(f) tsReader.Close() }() - w.Write([]byte("ok")) + util.ReturnOK(w, r) } } @@ -334,14 +313,14 @@ func (conf *GlobalConfig) API_replay_mp4(w http.ResponseWriter, r *http.Request) var pub MP4Publisher f, err := os.Open(dumpFile) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorOpen, err.Error(), w, r) return } if err := Engine.Publish(streamPath, &pub); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnError(util.APIErrorPublish, err.Error(), w, r) } else { pub.SetIO(f) - w.Write([]byte("ok")) + util.ReturnOK(w, r) go pub.ReadMP4Data(f) } } @@ -351,7 +330,7 @@ func (conf *GlobalConfig) API_insertSEI(w http.ResponseWriter, r *http.Request) streamPath := q.Get("streamPath") s := Streams.Get(streamPath) if s == nil { - http.Error(w, NO_SUCH_STREAM, http.StatusNotFound) + util.ReturnError(util.APIErrorNoStream, NO_SUCH_STREAM, w, r) return } t := q.Get("type") @@ -360,18 +339,18 @@ func (conf *GlobalConfig) API_insertSEI(w http.ResponseWriter, r *http.Request) if t == "" { tb = 5 } else { - http.Error(w, err.Error(), http.StatusBadRequest) + util.ReturnError(util.APIErrorQueryParse, "type must a number", w, r) return } } sei, err := io.ReadAll(r.Body) if err == nil { if s.Tracks.AddSEI(byte(tb), sei) { - w.Write([]byte("ok")) + util.ReturnOK(w, r) } else { - http.Error(w, "no sei track", http.StatusBadRequest) + util.ReturnError(util.APIErrorNoSEI, "no sei track", w, r) } } else { - http.Error(w, err.Error(), http.StatusBadRequest) + util.ReturnError(util.APIErrorNoBody, err.Error(), w, r) } } diff --git a/io.go b/io.go index 7a6c24a..5b398b4 100644 --- a/io.go +++ b/io.go @@ -99,6 +99,7 @@ type IIO interface { SetParentCtx(context.Context) SetLogger(*log.Logger) IsShutdown() bool + log.Zap } func (i *IO) close() bool { @@ -166,9 +167,7 @@ func (io *IO) receive(streamPath string, specific IIO) error { if v, ok := specific.(ISubscriber); ok { wt = v.GetSubscriber().Config.WaitTimeout } - if io.Context == nil { - io.Context, io.CancelFunc = context.WithCancel(Engine) - } + io.Context, io.CancelFunc = context.WithCancel(util.Conditoinal[context.Context](io.Context == nil, Engine, io.Context)) s, create := findOrCreateStream(u.Path, wt) if s == nil { return ErrBadStreamName @@ -179,9 +178,15 @@ func (io *IO) receive(streamPath string, specific IIO) error { if io.Type == "" { io.Type = reflect.TypeOf(specific).Elem().Name() } - io.Logger = s.With(zap.String("type", io.Type)) + logFeilds := []zapcore.Field{zap.String("type", io.Type)} if io.ID != "" { - io.Logger = io.Logger.With(zap.String("ID", io.ID)) + logFeilds = append(logFeilds, zap.String("ID", io.ID)) + } + if io.Logger == nil { + io.Logger = s.With(logFeilds...) + } else { + logFeilds = append(logFeilds, zap.String("streamPath", s.Path)) + io.Logger = io.Logger.With(logFeilds...) } if v, ok := specific.(IPublisher); ok { conf := v.GetPublisher().Config @@ -191,14 +196,14 @@ func (io *IO) receive(streamPath string, specific IIO) error { defer s.pubLocker.Unlock() oldPublisher := s.Publisher if oldPublisher != nil && !oldPublisher.IsClosed() { - // 根据配置是否剔出原来的发布者 - if conf.KickExist { - s.Warn("kick", zap.String("type", oldPublisher.GetPublisher().Type)) + zot := zap.String("old type", oldPublisher.GetPublisher().Type) + if oldPublisher == specific { // 断线重连 + s.Info("republish", zot) + } else if conf.KickExist { // 根据配置是否剔出原来的发布者 + s.Warn("kick", zot) oldPublisher.OnEvent(SEKick{}) - } else if oldPublisher == specific { - //断线重连 } else { - s.Warn("duplicate publish", zap.String("type", oldPublisher.GetPublisher().Type)) + s.Warn("duplicate publish", zot) return ErrDuplicatePublish } } @@ -208,11 +213,7 @@ func (io *IO) receive(streamPath string, specific IIO) error { s.PauseTimeout = conf.PauseTimeout defer func() { if err == nil { - if oldPublisher == nil { - specific.OnEvent(specific) - } else { - specific.OnEvent(oldPublisher) - } + specific.OnEvent(util.Conditoinal[IIO](oldPublisher == nil, specific, oldPublisher)) } }() if config.Global.EnableAuth { @@ -242,8 +243,8 @@ func (io *IO) receive(streamPath string, specific IIO) error { conf := specific.(ISubscriber).GetSubscriber().Config io.Type = strings.TrimSuffix(io.Type, "Subscriber") io.Info("subscribe") - if create || s.State != STATE_PUBLISHING { - EventBus <- s // 通知发布者按需拉流 + if create { + EventBus <- InvitePublish{CreateEvent(s.Path)} // 通知发布者按需拉流 } defer func() { if err == nil { diff --git a/lang/zh.yaml b/lang/zh.yaml index efaacf9..017a74d 100644 --- a/lang/zh.yaml +++ b/lang/zh.yaml @@ -19,8 +19,10 @@ state: 状态 initialize: 初始化 "start read": 开始读取 "start pull": 开始从远端拉流 +"stop pull": 停止从远端拉流 "restart pull": 重新拉流 -"pull failed": 拉取失败 +"pull interrupt": 拉流中断 +"pull publish": 拉流发布 "wait publisher": 等待发布者发布 "wait timeout": 等待超时 created: 已创建 @@ -33,6 +35,7 @@ track+1: 轨道+1 playblock: 阻塞式播放 "play neither video nor audio": 播放既没有视频也没有音频 "play before subscribe": 播放之前需要先订阅 +"play stop": 播放停止 "suber -1": 订阅者-1 "suber +1": 订阅者+1 "innersuber +1": 内部订阅者+1 @@ -59,3 +62,6 @@ skipSeq: 跳过序列号 skipTs: 跳过时间戳 "nalu type not supported": nalu类型不支持 "create file": 创建文件 +"duplicate publish": 重复发布 +"republish": 重新发布 +"stream already had a publisher": 流已经有发布者了 \ No newline at end of file diff --git a/main.go b/main.go index c0aca77..99b24b6 100755 --- a/main.go +++ b/main.go @@ -5,7 +5,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net" "net/http" "os" @@ -68,7 +68,7 @@ func Run(ctx context.Context, configFile string) (err error) { if err = util.CreateShutdownScript(); err != nil { log.Error("create shutdown script error:", err) } - if ConfigRaw, err = ioutil.ReadFile(configFile); err != nil { + if ConfigRaw, err = os.ReadFile(configFile); err != nil { log.Warn("read config file error:", err.Error()) } if err = os.MkdirAll(SettingDir, 0766); err != nil { @@ -159,7 +159,7 @@ func Run(ctx context.Context, configFile string) (err error) { if EngineConfig.LogLang == "zh" { fmt.Print("已运行的插件:") } else { - fmt.Print("enabled plugins:") + fmt.Print("enabled plugins:") } for _, plugin := range enabledPlugins { fmt.Print(Colorize(" "+plugin.Name+" ", BlackFg|GreenBg|BoldFm), " ") @@ -168,7 +168,7 @@ func Run(ctx context.Context, configFile string) (err error) { if EngineConfig.LogLang == "zh" { fmt.Print("已禁用的插件:") } else { - fmt.Print("disabled plugins:") + fmt.Print("disabled plugins:") } for _, plugin := range disabledPlugins { fmt.Print(Colorize(" "+plugin.Name+" ", BlackFg|RedBg|CrossedOutFm), " ") @@ -189,7 +189,7 @@ func Run(ctx context.Context, configFile string) (err error) { Arch string `json:"arch"` }{UUID, id, EngineConfig.GetInstanceId(), version, runtime.GOOS, runtime.GOARCH} json.NewEncoder(contentBuf).Encode(&rp) - req.Body = ioutil.NopCloser(contentBuf) + req.Body = io.NopCloser(contentBuf) if EngineConfig.Secret != "" { EngineConfig.OnEvent(ctx) } @@ -218,7 +218,7 @@ func Run(ctx context.Context, configFile string) (err error) { case <-reportTimer.C: contentBuf.Reset() contentBuf.WriteString(fmt.Sprintf(`{"uuid":"`+UUID+`","streams":%d}`, Streams.Len())) - req.Body = ioutil.NopCloser(contentBuf) + req.Body = io.NopCloser(contentBuf) c.Do(req) } } diff --git a/plugin.go b/plugin.go index 6c5e1df..ef69385 100644 --- a/plugin.go +++ b/plugin.go @@ -243,6 +243,17 @@ func (opt *Plugin) Save() error { return nil } +func (opt *Plugin) AssignPubConfig(puber *Publisher) { + if puber.Config == nil { + conf, ok := opt.Config.(config.PublishConfig) + if !ok { + conf = EngineConfig + } + copyConfig := conf.GetPublishConfig() + puber.Config = ©Config + } +} + func (opt *Plugin) Publish(streamPath string, pub IPublisher) error { puber := pub.GetPublisher() if puber == nil { @@ -252,14 +263,7 @@ func (opt *Plugin) Publish(streamPath string, pub IPublisher) error { return errors.New("not publisher") } } - if puber.Config == nil { - conf, ok := opt.Config.(config.PublishConfig) - if !ok { - conf = EngineConfig - } - copyConfig := conf.GetPublishConfig() - puber.Config = ©Config - } + opt.AssignPubConfig(puber) return pub.Publish(streamPath, pub) } @@ -275,17 +279,7 @@ func (opt *Plugin) SubscribeExist(streamPath string, sub ISubscriber) error { } return opt.Subscribe(streamPath, sub) } - -// Subscribe 订阅一个流,如果流不存在则创建一个等待流 -func (opt *Plugin) Subscribe(streamPath string, sub ISubscriber) error { - suber := sub.GetSubscriber() - if suber == nil { - if EngineConfig.LogLang == "zh" { - return errors.New("不是订阅者") - } else { - return errors.New("not subscriber") - } - } +func (opt *Plugin) AssignSubConfig(suber *Subscriber) { if suber.Config == nil { conf, ok := opt.Config.(config.SubscribeConfig) if !ok { @@ -297,6 +291,18 @@ func (opt *Plugin) Subscribe(streamPath string, sub ISubscriber) error { if suber.ID == "" { suber.ID = fmt.Sprintf("%d", uintptr(unsafe.Pointer(suber))) } +} +// Subscribe 订阅一个流,如果流不存在则创建一个等待流 +func (opt *Plugin) Subscribe(streamPath string, sub ISubscriber) error { + suber := sub.GetSubscriber() + if suber == nil { + if EngineConfig.LogLang == "zh" { + return errors.New("不是订阅者") + } else { + return errors.New("not subscriber") + } + } + opt.AssignSubConfig(suber) return sub.Subscribe(streamPath, sub) } @@ -321,56 +327,10 @@ func (opt *Plugin) Pull(streamPath string, url string, puller IPuller, save int) zurl := zap.String("url", url) zpath := zap.String("path", streamPath) opt.Info("pull", zpath, zurl) - defer func() { - if err != nil { - opt.Error("pull failed", zurl, zap.Error(err)) - } - }() puller.init(streamPath, url, pullConf) + opt.AssignPubConfig(puller.GetPublisher()) puller.SetLogger(opt.Logger.With(zpath, zurl)) - badPuller := true - go func() { - Pullers.Store(puller, url) - defer Pullers.Delete(puller) - for opt.Info("start pull", zurl); puller.Reconnect(); opt.Warn("restart pull", zurl) { - if err = puller.Connect(); err != nil { - if err == io.EOF { - puller.GetPublisher().Stream.Close() - opt.Info("pull complete", zurl) - return - } - opt.Error("pull connect", zurl, zap.Error(err)) - if badPuller { - return - } - time.Sleep(time.Second * 5) - } else { - if err = opt.Publish(streamPath, puller); err != nil { - if stream := Streams.Get(streamPath); stream != nil { - if stream.Publisher != puller && stream.Publisher != nil { - io := stream.Publisher.GetPublisher() - opt.Error("puller is not publisher", zap.String("ID", io.ID), zap.String("Type", io.Type), zap.Error(err)) - return - } else { - opt.Warn("pull publish", zurl, zap.Error(err)) - } - } else { - opt.Error("pull publish", zurl, zap.Error(err)) - return - } - } - badPuller = false - if err = puller.Pull(); err != nil && !puller.IsShutdown() { - opt.Error("pull", zurl, zap.Error(err)) - } - } - if puller.IsShutdown() { - opt.Info("stop pull shutdown", zurl) - return - } - } - opt.Warn("stop pull stop reconnect", zurl) - }() + go puller.startPull(puller) } switch save { case 1: @@ -406,43 +366,10 @@ func (opt *Plugin) Push(streamPath string, url string, pusher IPusher, save bool return ErrNoPushConfig } pushConfig := conf.GetPushConfig() - pusher.init(streamPath, url, pushConfig) - badPusher := true - go func() { - Pushers.Store(url, pusher) - defer Pushers.Delete(url) - for opt.Info("start push", zp, zu); pusher.Reconnect(); opt.Warn("restart push", zp, zu) { - if err = opt.Subscribe(streamPath, pusher); err != nil { - opt.Error("push subscribe", zp, zu, zap.Error(err)) - time.Sleep(time.Second * 5) - } else { - stream := pusher.GetSubscriber().Stream - if err = pusher.Connect(); err != nil { - if err == io.EOF { - opt.Info("push complete", zp, zu) - return - } - opt.Error("push connect", zp, zu, zap.Error(err)) - time.Sleep(time.Second * 5) - stream.Receive(pusher) // 通知stream移除订阅者 - if badPusher { - return - } - } else if err = pusher.Push(); err != nil && !stream.IsClosed() { - opt.Error("push", zp, zu, zap.Error(err)) - pusher.Stop() - } - badPusher = false - if stream.IsClosed() { - opt.Info("stop push closed", zp, zu) - return - } - } - } - opt.Warn("stop push stop reconnect", zp, zu) - }() - + pusher.SetLogger(opt.Logger.With(zp, zu)) + opt.AssignSubConfig(pusher.GetSubscriber()) + go pusher.startPush(pusher) if save { pushConfig.AddPush(url, streamPath) if opt.Modified == nil { diff --git a/publisher.go b/publisher.go index 41af08f..1044425 100644 --- a/publisher.go +++ b/publisher.go @@ -135,23 +135,3 @@ func (p *Publisher) WriteAVCCAudio(ts uint32, frame *util.BLL, pool util.BytesPo p.AudioTrack.WriteAVCC(ts, frame) } } - -type IPuller interface { - IPublisher - Connect() error - Pull() error - Reconnect() bool - init(streamPath string, url string, conf *config.Pull) -} - -// 用于远程拉流的发布者 -type Puller struct { - ClientIO[config.Pull] -} - -// 是否需要重连 -func (pub *Puller) Reconnect() (ok bool) { - ok = pub.Config.RePull == -1 || pub.ReConnectCount <= pub.Config.RePull - pub.ReConnectCount++ - return -} diff --git a/puller.go b/puller.go new file mode 100644 index 0000000..1366ad0 --- /dev/null +++ b/puller.go @@ -0,0 +1,85 @@ +package engine + +import ( + "io" + "time" + + "go.uber.org/zap" + "m7s.live/engine/v4/config" +) + +var zshutdown = zap.String("reason", "shutdown") +var znomorereconnect = zap.String("reason", "no more reconnect") + +type IPuller interface { + IPublisher + Connect() error + Disconnect() + Pull() error + Reconnect() bool + init(streamPath string, url string, conf *config.Pull) + startPull(IPuller) +} + +// 用于远程拉流的发布者 +type Puller struct { + ClientIO[config.Pull] +} + +// 是否需要重连 +func (pub *Puller) Reconnect() (ok bool) { + ok = pub.Config.RePull == -1 || pub.ReConnectCount <= pub.Config.RePull + pub.ReConnectCount++ + return +} + +func (pub *Puller) startPull(puller IPuller) { + badPuller := true + var stream *Stream + var err error + Pullers.Store(puller, pub.RemoteURL) + defer func() { + Pullers.Delete(puller) + puller.Disconnect() + if stream != nil { + stream.Close() + } + }() + puber := puller.GetPublisher() + originContext := puber.Context // 保存原始的Context + for puller.Info("start pull"); puller.Reconnect(); puller.Warn("restart pull") { + if err = puller.Connect(); err != nil { + if err == io.EOF { + puller.Info("pull complete") + return + } + puller.Error("pull connect", zap.Error(err)) + if badPuller { + return + } + time.Sleep(time.Second * 5) + } else { + puber.Context = originContext // 每次重连都需要恢复原始的Context + if err = puller.Publish(pub.StreamPath, puller); err != nil { + puller.Error("pull publish", zap.Error(err)) + return + } + s := puber.Stream + if stream != s && stream != nil { // 这段代码说明老流已经中断,创建了新流,需要把track置空,从而避免复用 + puber.AudioTrack = nil + puber.VideoTrack = nil + } + stream = s + badPuller = false + if err = puller.Pull(); err != nil && !puller.IsShutdown() { + puller.Error("pull interrupt", zap.Error(err)) + } + } + if puller.IsShutdown() { + puller.Info("stop pull", zshutdown) + return + } + puller.Disconnect() + } + puller.Warn("stop pull", znomorereconnect) +} diff --git a/pusher.go b/pusher.go new file mode 100644 index 0000000..fd91870 --- /dev/null +++ b/pusher.go @@ -0,0 +1,68 @@ +package engine + +import ( + "io" + "time" + + "go.uber.org/zap" + "m7s.live/engine/v4/config" +) + +type IPusher interface { + ISubscriber + Push() error + Connect() error + Disconnect() + init(string, string, *config.Push) + Reconnect() bool + startPush(IPusher) +} + +type Pusher struct { + ClientIO[config.Push] +} + +// 是否需要重连 +func (pub *Pusher) Reconnect() (result bool) { + result = pub.Config.RePush == -1 || pub.ReConnectCount <= pub.Config.RePush + pub.ReConnectCount++ + return +} + +func (pub *Pusher) startPush(pusher IPusher) { + badPusher := true + var err error + Pushers.Store(pub.RemoteURL, pusher) + defer Pushers.Delete(pub.RemoteURL) + defer pusher.Disconnect() + for pusher.Info("start push"); pusher.Reconnect(); pusher.Warn("restart push") { + if err = pusher.Subscribe(pub.StreamPath, pusher); err != nil { + pusher.Error("push subscribe", zap.Error(err)) + time.Sleep(time.Second * 5) + } else { + stream := pusher.GetSubscriber().Stream + if err = pusher.Connect(); err != nil { + if err == io.EOF { + pusher.Info("push complete") + return + } + pusher.Error("push connect", zap.Error(err)) + time.Sleep(time.Second * 5) + stream.Receive(pusher) // 通知stream移除订阅者 + if badPusher { + return + } + } else if err = pusher.Push(); err != nil && !stream.IsClosed() { + pusher.Error("push", zap.Error(err)) + pusher.Stop() + } + badPusher = false + if stream.IsClosed() { + pusher.Info("stop push closed") + return + } + } + pusher.Disconnect() + } + pusher.Warn("stop push stop reconnect") +} diff --git a/stream.go b/stream.go index 58385a4..b268891 100644 --- a/stream.go +++ b/stream.go @@ -130,8 +130,9 @@ type StreamTimeoutConfig struct { } type Tracks struct { sync.Map - MainVideo *track.Video - SEI *track.Data[[]byte] + MainVideo *track.Video + SEI *track.Data[[]byte] + marshalLock sync.Mutex } func (tracks *Tracks) Range(f func(name string, t Track)) { @@ -191,6 +192,8 @@ func (tracks *Tracks) AddSEI(t byte, data []byte) bool { func (tracks *Tracks) MarshalJSON() ([]byte, error) { var trackList []Track + tracks.marshalLock.Lock() + defer tracks.marshalLock.Unlock() tracks.Range(func(_ string, t Track) { t.SnapForJson() trackList = append(trackList, t) @@ -238,6 +241,10 @@ func (s *Stream) GetStartTime() time.Time { } func (s *Stream) GetPublisherConfig() *config.Publish { + if s.Publisher == nil { + s.Error("GetPublisherConfig: Publisher is nil") + return nil + } return s.Publisher.GetPublisher().Config } @@ -444,7 +451,7 @@ func (s *Stream) run() { } } if !s.NeverTimeout { - hasTrackTimeout := false + lost := false trackCount := 0 timeout := s.PublishTimeout if s.IsPause { @@ -457,11 +464,20 @@ func (s *Stream) run() { // track 超过一定时间没有更新数据了 if lastWriteTime := t.LastWriteTime(); !lastWriteTime.IsZero() && time.Since(lastWriteTime) > timeout { s.Warn("track timeout", zap.String("name", name), zap.Time("last writetime", lastWriteTime), zap.Duration("timeout", timeout)) - hasTrackTimeout = true + lost = true } } }) - if trackCount == 0 || hasTrackTimeout || (s.Publisher != nil && s.Publisher.IsClosed()) { + if !lost { + if trackCount == 0 { + s.Warn("no tracks") + lost = true + } else if s.Publisher != nil && s.Publisher.IsClosed() { + s.Warn("publish is closed") + lost = true + } + } + if lost { s.action(ACTION_PUBLISHLOST) continue } diff --git a/subscriber.go b/subscriber.go index 8792676..e4bf808 100644 --- a/subscriber.go +++ b/subscriber.go @@ -404,27 +404,10 @@ func (s *Subscriber) PlayBlock(subType byte) { func (s *Subscriber) onStop(reason *zapcore.Field) { if !s.Stream.IsClosed() { - s.Info("stop", *reason) + s.Info("play stop", *reason) if !s.Config.Internal { s.Stream.Receive(s.Spesific) } } } -type IPusher interface { - ISubscriber - Push() error - Connect() error - init(string, string, *config.Push) - Reconnect() bool -} -type Pusher struct { - ClientIO[config.Push] -} - -// 是否需要重连 -func (pub *Pusher) Reconnect() (result bool) { - result = pub.Config.RePush == -1 || pub.ReConnectCount <= pub.Config.RePush - pub.ReConnectCount++ - return -} diff --git a/track/h264.go b/track/h264.go index 5f9e3a1..d152721 100644 --- a/track/h264.go +++ b/track/h264.go @@ -159,11 +159,11 @@ func (vt *H264) CompleteRTP(value *AVFrame) { if value.IFrame { out = append(out, [][]byte{vt.SPS}, [][]byte{vt.PPS}) } - startIndex := len(out) vt.Value.AUList.Range(func(au *util.BLL) bool { if au.ByteLength < RTPMTU { out = append(out, au.ToBuffers()) } else { + startIndex := len(out) var naluType codec.H264NALUType r := au.NewReader() b0, _ := r.ReadByte() diff --git a/track/h265.go b/track/h265.go index 7718d0b..fd2f12b 100644 --- a/track/h265.go +++ b/track/h265.go @@ -189,11 +189,11 @@ func (vt *H265) CompleteRTP(value *AVFrame) { if value.IFrame { out = append(out, [][]byte{vt.VPS}, [][]byte{vt.SPS}, [][]byte{vt.PPS}) } - startIndex := len(out) vt.Value.AUList.Range(func(au *util.BLL) bool { if au.ByteLength < RTPMTU { out = append(out, au.ToBuffers()) } else { + startIndex := len(out) var naluType codec.H265NALUType r := au.NewReader() b0, _ := r.ReadByte() diff --git a/util/socket.go b/util/socket.go index 7985b6b..33d2e7c 100644 --- a/util/socket.go +++ b/util/socket.go @@ -17,43 +17,124 @@ func FetchValue[T any](t T) func() T { } } -func ReturnJson[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWriter, r *http.Request) { - if r.Header.Get("Accept") == "text/event-stream" { - sse := NewSSE(rw, r.Context()) - tick := time.NewTicker(tickDur) - defer tick.Stop() - for range tick.C { - if sse.WriteJSON(fetch()) != nil { - return - } +const ( + APIErrorNone = 0 + APIErrorDecode = iota + 4000 + APIErrorQueryParse + APIErrorNoBody +) + +const ( + APIErrorNotFound = iota + 4040 + APIErrorNoStream + APIErrorNoConfig + APIErrorNoPusher + APIErrorNoSubscriber + APIErrorNoSEI +) + +const ( + APIErrorInternal = iota + 5000 + APIErrorJSONEncode + APIErrorPublish + APIErrorSave + APIErrorOpen +) + +type APIError struct { + Code int `json:"code"` + Message string `json:"msg"` +} + +type APIResult struct { + Code int `json:"code"` + Data any `json:"data"` + Message string `json:"msg"` +} + +func ReturnValue(v any, rw http.ResponseWriter, r *http.Request) { + ReturnFetchValue(FetchValue(v), rw, r) +} + +func ReturnOK(rw http.ResponseWriter, r *http.Request) { + ReturnError(0, "ok", rw, r) +} + +func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + isJson := query.Get("format") == "json" + if isJson { + if err := json.NewEncoder(rw).Encode(APIError{code, msg}); err != nil { + json.NewEncoder(rw).Encode(APIError{ + Code: APIErrorJSONEncode, + Message: err.Error(), + }) } } else { - rw.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(rw).Encode(fetch()); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + switch true { + case code == 0: + http.Error(rw, msg, http.StatusOK) + case code/10 == 404: + http.Error(rw, msg, http.StatusNotFound) + case code > 5000: + http.Error(rw, msg, http.StatusInternalServerError) + default: + http.Error(rw, msg, http.StatusBadRequest) } } } -func ReturnYaml[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWriter, r *http.Request) { +func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + isYaml := query.Get("format") == "yaml" + isJson := query.Get("format") == "json" + tickDur, err := time.ParseDuration(query.Get("interval")) + if err != nil { + tickDur = time.Second + } if r.Header.Get("Accept") == "text/event-stream" { sse := NewSSE(rw, r.Context()) tick := time.NewTicker(tickDur) defer tick.Stop() - for range tick.C { - if sse.WriteYAML(fetch()) != nil { - return + if isYaml { + for range tick.C { + if sse.WriteYAML(fetch()) != nil { + return + } + } + } else { + for range tick.C { + if sse.WriteJSON(fetch()) != nil { + return + } } } } else { - rw.Header().Set("Content-Type", "application/yaml") - if err := yaml.NewEncoder(rw).Encode(fetch()); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + data := fetch() + rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json")) + if isYaml { + if err := yaml.NewEncoder(rw).Encode(data); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } + } else if isJson { + if err := json.NewEncoder(rw).Encode(APIResult{ + Code: 0, + Data: data, + Message: "ok", + }); err != nil { + json.NewEncoder(rw).Encode(APIError{ + Code: APIErrorJSONEncode, + Message: err.Error(), + }) + } + } else { + if err := json.NewEncoder(rw).Encode(data); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } } } } - func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) { addr, err := net.ResolveUDPAddr("udp", address) if err != nil {