From 5cc3fdecf975afd463fdb63f7f247c53376ad8c8 Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Sat, 17 Aug 2024 20:38:10 +0800 Subject: [PATCH] fix: cascade plugin --- example/multiple/config1.yaml | 10 ++- example/multiple/config2.yaml | 25 ++++---- example/multiple/main.go | 2 +- pkg/config/db.go | 2 +- pkg/db/sqlite.go | 7 ++- plugin.go | 115 +++++++++++++++++----------------- plugin/cascade/client.go | 63 +++++++++---------- plugin/cascade/server.go | 14 ++--- server.go | 15 +++-- test/server_test.go | 2 +- 10 files changed, 129 insertions(+), 126 deletions(-) diff --git a/example/multiple/config1.yaml b/example/multiple/config1.yaml index 5f43fac..6a2a203 100644 --- a/example/multiple/config1.yaml +++ b/example/multiple/config1.yaml @@ -5,5 +5,11 @@ global: listenaddrtls: :8555 tcp: listenaddr: :50052 -console: - secret: de2c0bb9fd47684adc07a426e139239b +cascadeclient: + server: localhost:44944 + pull: + enableregexp: true + pullonsub: + .*: m7s://$0 +#console: +# secret: de2c0bb9fd47684adc07a426e139239b diff --git a/example/multiple/config2.yaml b/example/multiple/config2.yaml index dd18ae9..2a0c53b 100644 --- a/example/multiple/config2.yaml +++ b/example/multiple/config2.yaml @@ -1,16 +1,13 @@ global: loglevel: debug - tcp: - listenaddr: :50051 -console: - secret: 00aea3af031f134d6307618b05ec4899 -rtmp: - enable: false -rtsp: - enable: false -webrtc: - enable: false -flv: - pull: - pullonstart: - live/test: /Users/dexter/Movies/jb-demo.flv \ No newline at end of file + disableall: true +#console: +# secret: 00aea3af031f134d6307618b05ec4899 +cascadeserver: + enable: true + quic: + listenaddr: :44944 +#flv: +# pull: +# pullonstart: +# live/test: /Users/dexter/Movies/jb-demo.flv \ No newline at end of file diff --git a/example/multiple/main.go b/example/multiple/main.go index 9fa8c57..0cfebdf 100644 --- a/example/multiple/main.go +++ b/example/multiple/main.go @@ -3,7 +3,7 @@ package main import ( "context" "m7s.live/m7s/v5" - _ "m7s.live/m7s/v5/plugin/console" + _ "m7s.live/m7s/v5/plugin/cascade" _ "m7s.live/m7s/v5/plugin/debug" _ "m7s.live/m7s/v5/plugin/flv" _ "m7s.live/m7s/v5/plugin/logrotate" diff --git a/pkg/config/db.go b/pkg/config/db.go index cadbcd3..5ff288a 100644 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -2,5 +2,5 @@ package config type DB struct { DBType string `default:"sqlite" desc:"数据库类型"` - DSN string `default:"cascade.db" desc:"数据库文件路径"` + DSN string `default:"m7s.db" desc:"数据库文件路径"` } diff --git a/pkg/db/sqlite.go b/pkg/db/sqlite.go index a5a29bd..51cf669 100644 --- a/pkg/db/sqlite.go +++ b/pkg/db/sqlite.go @@ -2,10 +2,13 @@ package db -import "github.com/glebarez/sqlite" +import ( + "github.com/glebarez/sqlite" + "gorm.io/gorm" +) func init() { Factory["sqlite"] = func(dsn string) gorm.Dialector { - return gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + return sqlite.Open(dsn) } } diff --git a/plugin.go b/plugin.go index 83c5b18..043bcab 100644 --- a/plugin.go +++ b/plugin.go @@ -23,26 +23,60 @@ import ( "m7s.live/m7s/v5/pkg/util" ) -type DefaultYaml string +type ( + DefaultYaml string + OnExitHandler func() + AuthPublisher = func(*Publisher) *util.Promise + AuthSubscriber = func(*Subscriber) *util.Promise -type OnExitHandler func() -type AuthPublisher = func(*Publisher) *util.Promise -type AuthSubscriber = func(*Subscriber) *util.Promise + PluginMeta struct { + Name string + Version string //插件版本 + Type reflect.Type + defaultYaml DefaultYaml //默认配置 + ServiceDesc *grpc.ServiceDesc + RegisterGRPCHandler func(context.Context, *gatewayRuntime.ServeMux, *grpc.ClientConn) error + Puller Puller + Pusher Pusher + Recorder Recorder + OnExit OnExitHandler + OnAuthPub AuthPublisher + OnAuthSub AuthSubscriber + } -type PluginMeta struct { - Name string - Version string //插件版本 - Type reflect.Type - defaultYaml DefaultYaml //默认配置 - ServiceDesc *grpc.ServiceDesc - RegisterGRPCHandler func(context.Context, *gatewayRuntime.ServeMux, *grpc.ClientConn) error - Puller Puller - Pusher Pusher - Recorder Recorder - OnExit OnExitHandler - OnAuthPub AuthPublisher - OnAuthSub AuthSubscriber -} + iPlugin interface { + nothing() + } + + IPlugin interface { + util.ITask + OnInit() error + OnStop() + Pull(path string, url string) + } + + IRegisterHandler interface { + RegisterHandler() map[string]http.HandlerFunc + } + + IPullerPlugin interface { + GetPullableList() []string + } + + ITCPPlugin interface { + OnTCPConnect(*net.TCPConn) + } + + IUDPPlugin interface { + OnUDPConnect(*net.UDPConn) + } + + IQUICPlugin interface { + OnQUICConnect(quic.Connection) + } +) + +var plugins []PluginMeta func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) { instance, ok := reflect.New(plugin.Type).Interface().(IPlugin) @@ -111,39 +145,6 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) return } -type iPlugin interface { - nothing() -} - -type IPlugin interface { - util.ITask - OnInit() error - OnStop() - Pull(path string, url string) -} - -type IRegisterHandler interface { - RegisterHandler() map[string]http.HandlerFunc -} - -type IPullerPlugin interface { - GetPullableList() []string -} - -type ITCPPlugin interface { - OnTCPConnect(*net.TCPConn) -} - -type IUDPPlugin interface { - OnUDPConnect(*net.UDPConn) -} - -type IQUICPlugin interface { - OnQUICConnect(quic.Connection) -} - -var plugins []PluginMeta - // InstallPlugin 安装插件 func InstallPlugin[C iPlugin](options ...any) error { var c *C @@ -347,11 +348,13 @@ func (p *Plugin) listen() (err error) { quicConf := &p.config.Quic if quicConf.ListenAddr != "" && quicConf.AutoListen { p.Info("listen quic", "addr", quicConf.ListenAddr) - err = quicConf.ListenQuic(p, quicHandler.OnQUICConnect) - if err != nil { - p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err) - return - } + go func() { + p.Stop(quicConf.ListenQuic(p, quicHandler.OnQUICConnect)) + }() + //if err != nil { + // p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err) + // return + //} } } return diff --git a/plugin/cascade/client.go b/plugin/cascade/client.go index bc6971c..cc7a4db 100644 --- a/plugin/cascade/client.go +++ b/plugin/cascade/client.go @@ -11,7 +11,7 @@ import ( "github.com/quic-go/quic-go" ) -type CascadeClientConfig struct { +type CascadeClientPlugin struct { m7s.Plugin RelayAPI cascade.RelayAPIConfig `desc:"访问控制"` AutoPush bool `desc:"自动推流到上级"` //自动推流到上级 @@ -20,16 +20,11 @@ type CascadeClientConfig struct { conn quic.Connection } -var _ = m7s.InstallPlugin[CascadeClientConfig](m7s.DefaultYaml(` -cascadeclient: - relayapi: - allow: - - / -`), cascade.NewCascadePuller) +var _ = m7s.InstallPlugin[CascadeClientPlugin](cascade.NewCascadePuller) type ConnectServerTask struct { util.Task - cfg *CascadeClientConfig + cfg *CascadeClientPlugin quic.Connection } @@ -43,23 +38,25 @@ func (task *ConnectServerTask) Start() (err error) { KeepAlivePeriod: time.Second * 10, EnableDatagrams: true, }) - if stream := quic.Stream(nil); err == nil { - if stream, err = task.OpenStreamSync(task.cfg); err == nil { - res := []byte{0} - fmt.Fprintf(stream, "%s", task.cfg.Secret) - stream.Write([]byte{0}) - _, err = stream.Read(res) - if err == nil && res[0] == 0 { - task.Info("connected to cascade server", "server", task.cfg.Server) - stream.Close() - } else { - var zapErr any = err - if err == nil { - zapErr = res[0] - } - task.Error("connect to cascade server", "server", task.cfg.Server, "err", zapErr) - return nil + if err != nil { + return + } + var stream quic.Stream + if stream, err = task.OpenStreamSync(task.cfg); err == nil { + res := []byte{0} + fmt.Fprintf(stream, "%s", task.cfg.Secret) + stream.Write([]byte{0}) + _, err = stream.Read(res) + if err == nil && res[0] == 0 { + task.Info("connected to cascade server", "server", task.cfg.Server) + stream.Close() + } else { + var zapErr any = err + if err == nil { + zapErr = res[0] } + task.Error("connect to cascade server", "server", task.cfg.Server, "err", zapErr) + return nil } } return @@ -80,8 +77,8 @@ func (task *ConnectServerTask) Run() (err error) { return } -func (c *CascadeClientConfig) OnInit() (err error) { - if c.Secret == "" || c.Server == "" { +func (c *CascadeClientPlugin) OnInit() (err error) { + if c.Secret == "" && c.Server == "" { return nil } connectTask := ConnectServerTask{ @@ -92,14 +89,14 @@ func (c *CascadeClientConfig) OnInit() (err error) { return } -func (c *CascadeClientConfig) Pull(streamPath, url string) { - puller := cascade.NewCascadePuller().(*cascade.Puller) - puller.Connection = c.conn - puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url) - c.Plugin.Server.AddPullTask(puller) +func (c *CascadeClientPlugin) Pull(streamPath, url string) { + puller := &cascade.Puller{ + Connection: c.conn, + } + c.Plugin.Server.AddPullTask(puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url)) } -//func (c *CascadeClientConfig) Start() { +//func (c *CascadeClientPlugin) Start() { // retryDelay := [...]int{2, 3, 5, 8, 13} // for i := 0; c.Err() == nil; i++ { // connected, err := c.Remote() @@ -117,7 +114,7 @@ func (c *CascadeClientConfig) Pull(streamPath, url string) { // } //} -//func (c *CascadeClientConfig) Remote() (wasConnected bool, err error) { +//func (c *CascadeClientPlugin) Remote() (wasConnected bool, err error) { // tlsConf := &tls.Config{ // InsecureSkipVerify: true, // NextProtos: []string{"monibuca"}, diff --git a/plugin/cascade/server.go b/plugin/cascade/server.go index c01402d..677a217 100644 --- a/plugin/cascade/server.go +++ b/plugin/cascade/server.go @@ -11,19 +11,18 @@ import ( "m7s.live/m7s/v5/plugin/cascade/pkg" ) -type CascadeServerConfig struct { +type CascadeServerPlugin struct { m7s.Plugin AutoRegister bool `default:"true" desc:"下级自动注册"` RelayAPI cascade.RelayAPIConfig `desc:"访问控制"` } -var _ = m7s.InstallPlugin[CascadeServerConfig]() +var _ = m7s.InstallPlugin[CascadeServerPlugin]() -func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) { +func (c *CascadeServerPlugin) OnQUICConnect(conn quic.Connection) { remoteAddr := conn.RemoteAddr().String() c.Info("client connected:", "remoteAddr", remoteAddr) - var stream quic.Stream - stream, err = conn.AcceptStream(c) + stream, err := conn.AcceptStream(c) if err != nil { c.Error("AcceptStream", "err", err) return @@ -76,11 +75,10 @@ func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) { c.AddTask(&receiveRequestTask) } } - return } // API_relay_ 用于转发请求, api/relay/:instanceId/* -func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request) { +func (c *CascadeServerPlugin) API_relay_(w http.ResponseWriter, r *http.Request) { paths := strings.Split(r.URL.Path, "/") instanceId, err := strconv.ParseUint(paths[3], 10, 32) instance, ok := cascade.SubordinateMap.Get(uint(instanceId)) @@ -105,6 +103,6 @@ func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request) } // API_list 用于获取所有下级, api/list -func (c *CascadeServerConfig) API_list(w http.ResponseWriter, r *http.Request) { +func (c *CascadeServerPlugin) API_list(w http.ResponseWriter, r *http.Request) { //util.ReturnFetchList(SubordinateMap.ToList, w, r) } diff --git a/server.go b/server.go index ade3a54..4d4cab7 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package m7s import ( "context" + "errors" "fmt" "log/slog" "net" @@ -90,12 +91,10 @@ func NewServer(conf any) (s *Server) { return } -func Run(ctx context.Context, conf any) error { - for { - if err := util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped(); err != ErrRestart { - return err - } +func Run(ctx context.Context, conf any) (err error) { + for err = ErrRestart; errors.Is(err, ErrRestart); err = util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped() { } + return } func AddRootTask[T util.ITask](task T) T { @@ -108,7 +107,7 @@ func AddRootTaskWithContext[T util.ITask](ctx context.Context, task T) T { return task } -type rawconfig = map[string]map[string]any +type RawConfig = map[string]map[string]any func init() { signalChan := make(chan os.Signal, 1) @@ -151,7 +150,7 @@ func (s *Server) Start() (err error) { httpConf.GetHttpMux().ServeHTTP(w, r) })) httpConf.SetMux(mux) - var cg rawconfig + var cg RawConfig var configYaml []byte switch v := s.conf.(type) { case string: @@ -163,7 +162,7 @@ func (s *Server) Start() (err error) { } case []byte: configYaml = v - case rawconfig: + case RawConfig: cg = v } if configYaml != nil { diff --git a/test/server_test.go b/test/server_test.go index 0515527..7cda79e 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -8,7 +8,7 @@ import ( ) func TestRestart(b *testing.T) { - conf := map[string]map[string]any{"global": {"loglevel": "debug"}} + conf := m7s.RawConfig{"global": {"loglevel": "debug"}} var server *m7s.Server go func() { time.Sleep(time.Second * 2)