feat: add safe get in manager

This commit is contained in:
langhuihui
2025-04-11 13:46:22 +08:00
parent da338c05c1
commit 74dd4d7235
3 changed files with 66 additions and 62 deletions
+1
View File
@@ -54,6 +54,7 @@ func (t *AsyncTickTask) GetSignal() any {
}
func (t *AsyncTickTask) Go() error {
t.Tick(nil)
for {
select {
case c := <-t.Ticker.C:
+25
View File
@@ -36,3 +36,28 @@ func (m *Manager[K, T]) Add(ctx T, opt ...any) *Task {
})
return m.AddTask(ctx, opt...)
}
// SafeGet 用于不同协程获取元素,防止并发请求
func (m *Manager[K, T]) SafeGet(key K) (item T, ok bool) {
if m.L == nil {
m.Call(func() error {
item, ok = m.Collection.Get(key)
return nil
})
} else {
item, ok = m.Collection.Get(key)
}
return
}
// SafeRange 用于不同协程获取元素,防止并发请求
func (m *Manager[K, T]) SafeRange(f func(T) bool) {
if m.L == nil {
m.Call(func() error {
m.Collection.Range(f)
return nil
})
} else {
m.Collection.Range(f)
}
}
+40 -62
View File
@@ -90,11 +90,6 @@ func (d *PullProxy) Start() (err error) {
d.Handler = pullTask
}
if t, ok := pullTask.(task.ITask); ok {
if ticker, ok := t.(task.IChannelTask); ok {
t.OnStart(func() {
ticker.Tick(nil)
})
}
d.AddTask(t)
} else {
d.ChangeStatus(PullProxyStatusOnline)
@@ -237,29 +232,26 @@ func (p *Publisher) processPullProxyOnDispose() {
func (s *Server) GetPullProxyList(ctx context.Context, req *emptypb.Empty) (res *pb.PullProxyListResponse, err error) {
res = &pb.PullProxyListResponse{}
s.PullProxies.Call(func() error {
for device := range s.PullProxies.Range {
res.Data = append(res.Data, &pb.PullProxyInfo{
Name: device.Name,
CreateTime: timestamppb.New(device.CreatedAt),
UpdateTime: timestamppb.New(device.UpdatedAt),
Type: device.Type,
PullURL: device.URL,
ParentID: uint32(device.ParentID),
Status: uint32(device.Status),
ID: uint32(device.ID),
PullOnStart: device.PullOnStart,
StopOnIdle: device.StopOnIdle,
Audio: device.Audio,
RecordPath: device.Record.FilePath,
RecordFragment: durationpb.New(device.Record.Fragment),
Description: device.Description,
Rtt: uint32(device.RTT.Milliseconds()),
StreamPath: device.GetStreamPath(),
})
}
return nil
})
for device := range s.PullProxies.SafeRange {
res.Data = append(res.Data, &pb.PullProxyInfo{
Name: device.Name,
CreateTime: timestamppb.New(device.CreatedAt),
UpdateTime: timestamppb.New(device.UpdatedAt),
Type: device.Type,
PullURL: device.URL,
ParentID: uint32(device.ParentID),
Status: uint32(device.Status),
ID: uint32(device.ID),
PullOnStart: device.PullOnStart,
StopOnIdle: device.StopOnIdle,
Audio: device.Audio,
RecordPath: device.Record.FilePath,
RecordFragment: durationpb.New(device.Record.Fragment),
Description: device.Description,
Rtt: uint32(device.RTT.Milliseconds()),
StreamPath: device.GetStreamPath(),
})
}
return
}
@@ -363,31 +355,23 @@ func (s *Server) UpdatePullProxy(ctx context.Context, req *pb.PullProxyInfo) (re
target.StreamPath = req.StreamPath
s.DB.Save(target)
var needStopOld *PullProxy
s.PullProxies.Call(func() error {
if device, ok := s.PullProxies.Get(uint(req.ID)); ok {
if target.URL != device.URL || device.Audio != target.Audio || device.StreamPath != target.StreamPath || device.Record.FilePath != target.Record.FilePath || device.Record.Fragment != target.Record.Fragment {
device.Stop(task.ErrStopByUser)
needStopOld = device
return nil
}
if device, ok := s.PullProxies.SafeGet(uint(req.ID)); ok {
if target.URL != device.URL || device.Audio != target.Audio || device.StreamPath != target.StreamPath || device.Record.FilePath != target.Record.FilePath || device.Record.Fragment != target.Record.Fragment {
device.Stop(task.ErrStopByUser)
needStopOld = device
} else {
device.Name = target.Name
device.PullOnStart = target.PullOnStart
device.StopOnIdle = target.StopOnIdle
device.Description = target.Description
}
return nil
})
}
if needStopOld != nil {
if pullJob, ok := s.Pulls.SafeGet(req.StreamPath); ok {
pullJob.Stop(task.ErrStopByUser)
pullJob.WaitStopped()
}
s.PullProxies.Add(target).WaitStarted()
s.Pulls.Call(func() error {
if pullJob, ok := s.Pulls.Get(req.StreamPath); ok {
pullJob.Stop(task.ErrStopByUser)
target.Handler.Pull()
} else if target.PullOnStart {
target.Handler.Pull()
}
return nil
})
}
res = &pb.SuccessResponse{}
return
@@ -404,15 +388,12 @@ func (s *Server) RemovePullProxy(ctx context.Context, req *pb.RequestWithId) (re
ID: uint(req.Id),
})
err = tx.Error
s.PullProxies.Call(func() error {
if device, ok := s.PullProxies.Get(uint(req.Id)); ok {
device.Stop(task.ErrStopByUser)
if pull, ok := device.server.Pulls.Get(device.StreamPath); ok {
pull.Stop(task.ErrStopByUser)
}
if device, ok := s.PullProxies.SafeGet(uint(req.Id)); ok {
device.Stop(task.ErrStopByUser)
if pull, ok := device.server.Pulls.SafeGet(device.StreamPath); ok {
pull.Stop(task.ErrStopByUser)
}
return nil
})
}
return
} else if req.StreamPath != "" {
var deviceList []*PullProxy
@@ -421,15 +402,12 @@ func (s *Server) RemovePullProxy(ctx context.Context, req *pb.RequestWithId) (re
for _, device := range deviceList {
tx := s.DB.Delete(&PullProxy{}, device.ID)
err = tx.Error
s.PullProxies.Call(func() error {
if device, ok := s.PullProxies.Get(uint(device.ID)); ok {
device.Stop(task.ErrStopByUser)
if pull, ok := device.server.Pulls.Get(device.StreamPath); ok {
pull.Stop(task.ErrStopByUser)
}
if device, ok := s.PullProxies.SafeGet(uint(device.ID)); ok {
device.Stop(task.ErrStopByUser)
if pull, ok := device.server.Pulls.SafeGet(device.StreamPath); ok {
pull.Stop(task.ErrStopByUser)
}
return nil
})
}
}
}
return