diff --git a/pkg/task/channel.go b/pkg/task/channel.go index a91910f..6df5650 100644 --- a/pkg/task/channel.go +++ b/pkg/task/channel.go @@ -54,6 +54,7 @@ func (t *AsyncTickTask) GetSignal() any { } func (t *AsyncTickTask) Go() error { + t.Tick(nil) for { select { case c := <-t.Ticker.C: diff --git a/pkg/task/manager.go b/pkg/task/manager.go index 2905d9d..f941fc7 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -36,3 +36,28 @@ func (m *Manager[K, T]) Add(ctx T, opt ...any) *Task { }) return m.AddTask(ctx, opt...) } + +// SafeGet 用于不同协程获取元素,防止并发请求 +func (m *Manager[K, T]) SafeGet(key K) (item T, ok bool) { + if m.L == nil { + m.Call(func() error { + item, ok = m.Collection.Get(key) + return nil + }) + } else { + item, ok = m.Collection.Get(key) + } + return +} + +// SafeRange 用于不同协程获取元素,防止并发请求 +func (m *Manager[K, T]) SafeRange(f func(T) bool) { + if m.L == nil { + m.Call(func() error { + m.Collection.Range(f) + return nil + }) + } else { + m.Collection.Range(f) + } +} diff --git a/pull-proxy.go b/pull-proxy.go index 77389cb..8e7a5f8 100644 --- a/pull-proxy.go +++ b/pull-proxy.go @@ -90,11 +90,6 @@ func (d *PullProxy) Start() (err error) { d.Handler = pullTask } if t, ok := pullTask.(task.ITask); ok { - if ticker, ok := t.(task.IChannelTask); ok { - t.OnStart(func() { - ticker.Tick(nil) - }) - } d.AddTask(t) } else { d.ChangeStatus(PullProxyStatusOnline) @@ -237,29 +232,26 @@ func (p *Publisher) processPullProxyOnDispose() { func (s *Server) GetPullProxyList(ctx context.Context, req *emptypb.Empty) (res *pb.PullProxyListResponse, err error) { res = &pb.PullProxyListResponse{} - s.PullProxies.Call(func() error { - for device := range s.PullProxies.Range { - res.Data = append(res.Data, &pb.PullProxyInfo{ - Name: device.Name, - CreateTime: timestamppb.New(device.CreatedAt), - UpdateTime: timestamppb.New(device.UpdatedAt), - Type: device.Type, - PullURL: device.URL, - ParentID: uint32(device.ParentID), - Status: uint32(device.Status), - ID: uint32(device.ID), - PullOnStart: device.PullOnStart, - StopOnIdle: device.StopOnIdle, - Audio: device.Audio, - RecordPath: device.Record.FilePath, - RecordFragment: durationpb.New(device.Record.Fragment), - Description: device.Description, - Rtt: uint32(device.RTT.Milliseconds()), - StreamPath: device.GetStreamPath(), - }) - } - return nil - }) + for device := range s.PullProxies.SafeRange { + res.Data = append(res.Data, &pb.PullProxyInfo{ + Name: device.Name, + CreateTime: timestamppb.New(device.CreatedAt), + UpdateTime: timestamppb.New(device.UpdatedAt), + Type: device.Type, + PullURL: device.URL, + ParentID: uint32(device.ParentID), + Status: uint32(device.Status), + ID: uint32(device.ID), + PullOnStart: device.PullOnStart, + StopOnIdle: device.StopOnIdle, + Audio: device.Audio, + RecordPath: device.Record.FilePath, + RecordFragment: durationpb.New(device.Record.Fragment), + Description: device.Description, + Rtt: uint32(device.RTT.Milliseconds()), + StreamPath: device.GetStreamPath(), + }) + } return } @@ -363,31 +355,23 @@ func (s *Server) UpdatePullProxy(ctx context.Context, req *pb.PullProxyInfo) (re target.StreamPath = req.StreamPath s.DB.Save(target) var needStopOld *PullProxy - s.PullProxies.Call(func() error { - if device, ok := s.PullProxies.Get(uint(req.ID)); ok { - if target.URL != device.URL || device.Audio != target.Audio || device.StreamPath != target.StreamPath || device.Record.FilePath != target.Record.FilePath || device.Record.Fragment != target.Record.Fragment { - device.Stop(task.ErrStopByUser) - needStopOld = device - return nil - } + if device, ok := s.PullProxies.SafeGet(uint(req.ID)); ok { + if target.URL != device.URL || device.Audio != target.Audio || device.StreamPath != target.StreamPath || device.Record.FilePath != target.Record.FilePath || device.Record.Fragment != target.Record.Fragment { + device.Stop(task.ErrStopByUser) + needStopOld = device + } else { device.Name = target.Name device.PullOnStart = target.PullOnStart device.StopOnIdle = target.StopOnIdle device.Description = target.Description } - return nil - }) + } if needStopOld != nil { + if pullJob, ok := s.Pulls.SafeGet(req.StreamPath); ok { + pullJob.Stop(task.ErrStopByUser) + pullJob.WaitStopped() + } s.PullProxies.Add(target).WaitStarted() - s.Pulls.Call(func() error { - if pullJob, ok := s.Pulls.Get(req.StreamPath); ok { - pullJob.Stop(task.ErrStopByUser) - target.Handler.Pull() - } else if target.PullOnStart { - target.Handler.Pull() - } - return nil - }) } res = &pb.SuccessResponse{} return @@ -404,15 +388,12 @@ func (s *Server) RemovePullProxy(ctx context.Context, req *pb.RequestWithId) (re ID: uint(req.Id), }) err = tx.Error - s.PullProxies.Call(func() error { - if device, ok := s.PullProxies.Get(uint(req.Id)); ok { - device.Stop(task.ErrStopByUser) - if pull, ok := device.server.Pulls.Get(device.StreamPath); ok { - pull.Stop(task.ErrStopByUser) - } + if device, ok := s.PullProxies.SafeGet(uint(req.Id)); ok { + device.Stop(task.ErrStopByUser) + if pull, ok := device.server.Pulls.SafeGet(device.StreamPath); ok { + pull.Stop(task.ErrStopByUser) } - return nil - }) + } return } else if req.StreamPath != "" { var deviceList []*PullProxy @@ -421,15 +402,12 @@ func (s *Server) RemovePullProxy(ctx context.Context, req *pb.RequestWithId) (re for _, device := range deviceList { tx := s.DB.Delete(&PullProxy{}, device.ID) err = tx.Error - s.PullProxies.Call(func() error { - if device, ok := s.PullProxies.Get(uint(device.ID)); ok { - device.Stop(task.ErrStopByUser) - if pull, ok := device.server.Pulls.Get(device.StreamPath); ok { - pull.Stop(task.ErrStopByUser) - } + if device, ok := s.PullProxies.SafeGet(uint(device.ID)); ok { + device.Stop(task.ErrStopByUser) + if pull, ok := device.server.Pulls.SafeGet(device.StreamPath); ok { + pull.Stop(task.ErrStopByUser) } - return nil - }) + } } } return