From bcbabc719ff942efa3a6f2d645f1bda7d63f0e81 Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Thu, 6 Jun 2024 19:48:13 +0800 Subject: [PATCH] fix: pull count --- api.go | 58 ++++++++++++++++++++++++++----- pkg/avframe.go | 4 ++- pkg/util/allocator.go | 17 ++------- pkg/util/buf-reader.go | 11 +++--- pkg/util/buffers.go | 15 +++++--- pkg/util/collection.go | 2 +- pkg/util/mem.go | 41 ++++++++++++++-------- plugin/rtmp/index.go | 2 ++ plugin/rtmp/pkg/net-connection.go | 2 +- plugin/rtmp/pkg/video.go | 2 +- plugin/webrtc/index.go | 2 ++ publisher.go | 12 +++++++ 12 files changed, 118 insertions(+), 50 deletions(-) diff --git a/api.go b/api.go index 77dc125..0c748eb 100644 --- a/api.go +++ b/api.go @@ -181,13 +181,56 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) return } func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { - // s.Call(func() { - // if pub, ok := s.Streams.Get(req.StreamPath); ok { - // res = pub.AudioSnapShot() - // } else { - // err = pkg.ErrNotFound - // } - // }) + s.Call(func() { + if pub, ok := s.Streams.Get(req.StreamPath); ok { + res = &pb.TrackSnapShotResponse{} + _, err = pub.AudioTrack.Ready.Await() + if err != nil { + return + } + for _, memlist := range pub.AudioTrack.Allocator.GetChildren() { + var list []*pb.MemoryBlock + for _, block := range memlist.GetBlocks() { + list = append(list, &pb.MemoryBlock{ + S: uint32(block.Start), + E: uint32(block.End), + }) + } + res.Memory = append(res.Memory, &pb.MemoryBlockGroup{List: list, Size: uint32(memlist.Size)}) + } + res.Reader = make(map[uint32]uint32) + for sub := range pub.Subscribers { + if sub.AudioReader == nil { + continue + } + res.Reader[uint32(sub.ID)] = sub.AudioReader.Value.Sequence + } + pub.AudioTrack.Ring.Do(func(v *pkg.AVFrame) { + if v.TryRLock() { + if len(v.Wraps) > 0 { + var snap pb.TrackSnapShot + snap.Sequence = v.Sequence + snap.Timestamp = uint32(v.Timestamp / time.Millisecond) + snap.WriteTime = timestamppb.New(v.WriteTime) + snap.Wrap = make([]*pb.Wrap, len(v.Wraps)) + snap.KeyFrame = v.IDR + res.RingDataSize += uint32(v.Wraps[0].GetSize()) + for i, wrap := range v.Wraps { + snap.Wrap[i] = &pb.Wrap{ + Timestamp: uint32(wrap.GetTimestamp() / time.Millisecond), + Size: uint32(wrap.GetSize()), + Data: wrap.String(), + } + } + res.Ring = append(res.Ring, &snap) + } + v.RUnlock() + } + }) + } else { + err = pkg.ErrNotFound + } + }) return } func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) { @@ -228,7 +271,6 @@ func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) if err != nil { return } - // vcc := pub.VideoTrack.AVTrack.ICodecCtx.(pkg.IVideoCodecCtx) for _, memlist := range pub.VideoTrack.Allocator.GetChildren() { var list []*pb.MemoryBlock for _, block := range memlist.GetBlocks() { diff --git a/pkg/avframe.go b/pkg/avframe.go index 0892685..6516f3e 100644 --- a/pkg/avframe.go +++ b/pkg/avframe.go @@ -68,7 +68,9 @@ var _ IAVFrame = (*AnnexB)(nil) func (frame *AVFrame) Reset() { frame.Timestamp = 0 if len(frame.Wraps) > 0 { - frame.Wraps[0].Recycle() + for _, wrap := range frame.Wraps { + wrap.Recycle() + } frame.Wraps = frame.Wraps[:0] } } diff --git a/pkg/util/allocator.go b/pkg/util/allocator.go index f3e25e0..8107801 100644 --- a/pkg/util/allocator.go +++ b/pkg/util/allocator.go @@ -8,7 +8,6 @@ type ( Block struct { Start, End int trees [2]Tree - allocator *Allocator } // History struct { // Malloc bool @@ -39,7 +38,6 @@ func NewAllocator(size int) (result *Allocator) { offsetTree: root, Size: size, } - root.allocator = result return } @@ -57,10 +55,6 @@ func compareByOffset(a, b *Block) bool { var compares = [...]func(a, b *Block) bool{compareBySize, compareByOffset} var emptyTrees = [2]Tree{} -func (b *Block) recycle() { - b.allocator.putBlock(b) -} - func (b *Block) insert(block *Block, treeIndex int) *Block { if b == nil { return block @@ -198,22 +192,17 @@ func (a *Allocator) findAvailableBlock(block *Block, size int) *Block { func (a *Allocator) getBlock(start, end int) *Block { if l := len(a.pool); l == 0 { - return &Block{Start: start, End: end, allocator: a} + return &Block{Start: start, End: end} } else { block := a.pool[l-1] a.pool = a.pool[:l-1] block.Start, block.End = start, end - block.allocator = a return block } } func (a *Allocator) putBlock(b *Block) { - if b.allocator == nil { - return - } b.trees = emptyTrees - b.allocator = nil a.pool = append(a.pool, b) } @@ -255,7 +244,7 @@ func (a *Allocator) mergeAdjacentBlocks(block *Block) { a.deleteSizeTree(block) block.Start = leftAdjacent.Start a.sizeTree = a.sizeTree.insert(block, 0) - leftAdjacent.recycle() + a.putBlock(leftAdjacent) } if rightAdjacent := a.offsetTree.findRightAdjacentBlock(block.End); rightAdjacent != nil { a.deleteSizeTree(rightAdjacent) @@ -263,7 +252,7 @@ func (a *Allocator) mergeAdjacentBlocks(block *Block) { a.deleteSizeTree(block) block.End = rightAdjacent.End a.sizeTree = a.sizeTree.insert(block, 0) - rightAdjacent.recycle() + a.putBlock(rightAdjacent) } } diff --git a/pkg/util/buf-reader.go b/pkg/util/buf-reader.go index 572934f..6519118 100644 --- a/pkg/util/buf-reader.go +++ b/pkg/util/buf-reader.go @@ -40,10 +40,11 @@ func (r *BufReader) eat() error { if n, err := r.reader.Read(buf); err != nil { r.allocator.Free(buf) return err - } else if n < r.BufLen { - r.buf.ReadFromBytes(buf[:n]) - r.allocator.Free(buf[n:]) - } else if n == r.BufLen { + } else { + if n < r.BufLen { + r.allocator.Free(buf[n:]) + buf = buf[:n] + } r.buf.ReadFromBytes(buf) } return nil @@ -101,7 +102,7 @@ func (r *BufReader) ReadBytes(n int) (mem RecyclableMemory, err error) { return } n -= r.buf.Length - mem.AddRecycleBytes(r.buf.Memory.Buffers...) + mem.AddRecycleBytes(r.buf.Buffers...) r.buf = MemoryReader{} } } diff --git a/pkg/util/buffers.go b/pkg/util/buffers.go index 6d80c64..d7f2d65 100644 --- a/pkg/util/buffers.go +++ b/pkg/util/buffers.go @@ -266,13 +266,11 @@ func (reader *MemoryReader) ClipFront() (r net.Buffers) { } buffers := &reader.Memory if reader.Length == 0 { - r = buffers.Buffers + r = slices.Clone(buffers.Buffers) buffers.Buffers = buffers.Buffers[:0] } else { - for i := range reader.offset0 { - r = append(r, buffers.Buffers[i]) - } - if reader.getCurrentBufLen() > 0 { + r = slices.Clone(buffers.Buffers[:reader.offset0]) + if reader.offset1 > 0 { r = append(r, buffers.Buffers[reader.offset0][:reader.offset1]) buffers.Buffers[reader.offset0] = reader.GetCurrent() } @@ -280,6 +278,13 @@ func (reader *MemoryReader) ClipFront() (r net.Buffers) { buffers.Buffers = slices.Delete(buffers.Buffers, 0, reader.offset0) } } + // bs := 0 + // for _, b := range r { + // bs += len(b) + // } + // if bs != offset { + // panic("ClipFront error") + // } reader.Size -= offset reader.offset0 = 0 reader.offset1 = 0 diff --git a/pkg/util/collection.go b/pkg/util/collection.go index cdcdc84..25cb396 100644 --- a/pkg/util/collection.go +++ b/pkg/util/collection.go @@ -63,10 +63,10 @@ func (c *Collection[K, T]) RemoveByKey(key K) { for i := range c.Length { if c.Items[i].GetKey() == key { c.Items = slices.Delete(c.Items, i, i+1) + c.Length-- break } } - c.Length-- } func (c *Collection[K, T]) Get(key K) (item T, ok bool) { diff --git a/pkg/util/mem.go b/pkg/util/mem.go index 30eb3ca..653d5b8 100644 --- a/pkg/util/mem.go +++ b/pkg/util/mem.go @@ -8,6 +8,9 @@ import ( const MaxBlockSize = 4 * 1024 * 1024 +var pools sync.Map +var EnableCheckSize bool = false + type MemoryAllocator struct { allocator *Allocator start int64 @@ -15,6 +18,15 @@ type MemoryAllocator struct { Size int } +func GetMemoryAllocator(size int) (ret *MemoryAllocator) { + if value, ok := pools.Load(size); ok { + ret = value.(*sync.Pool).Get().(*MemoryAllocator) + } else { + ret = NewMemoryAllocator(size) + } + return +} + func NewMemoryAllocator(size int) (ret *MemoryAllocator) { ret = &MemoryAllocator{ Size: size, @@ -25,8 +37,15 @@ func NewMemoryAllocator(size int) (ret *MemoryAllocator) { return } -func (ma *MemoryAllocator) Reset() { +func (ma *MemoryAllocator) Recycle() { ma.allocator = NewAllocator(ma.Size) + size := ma.Size + pool, _ := pools.LoadOrStore(size, &sync.Pool{ + New: func() any { + return NewMemoryAllocator(size) + }, + }) + pool.(*sync.Pool).Put(ma) } func (ma *MemoryAllocator) Malloc(size int) (memory []byte) { @@ -53,10 +72,6 @@ func (ma *MemoryAllocator) GetBlocks() (blocks []*Block) { return ma.allocator.GetBlocks() } -var EnableCheckSize bool = false - -var pools sync.Map - type ScalableMemoryAllocator struct { children []*MemoryAllocator totalMalloc int64 @@ -105,15 +120,8 @@ func (sma *ScalableMemoryAllocator) GetChildren() []*MemoryAllocator { func (sma *ScalableMemoryAllocator) Recycle() { for _, child := range sma.children { - child.Reset() + child.Recycle() } - size := sma.children[0].Size - pool, _ := pools.LoadOrStore(size, &sync.Pool{ - New: func() interface{} { - return &ScalableMemoryAllocator{children: []*MemoryAllocator{NewMemoryAllocator(size)}, size: size} - }, - }) - pool.(*sync.Pool).Put(sma) } func (sma *ScalableMemoryAllocator) Malloc(size int) (memory []byte) { @@ -150,9 +158,14 @@ func (sma *ScalableMemoryAllocator) Free(mem []byte) bool { } ptr := int64(uintptr(unsafe.Pointer(&mem[0]))) size := len(mem) - for _, child := range sma.children { + for i, child := range sma.children { if start := int(ptr - child.start); start >= 0 && start < child.Size && child.free(start, size) { sma.addFreeCount(size) + if len(sma.children) > 1 && child.allocator.sizeTree.End-child.allocator.sizeTree.Start == child.Size { + child.Recycle() + sma.children = slices.Delete(sma.children, i, i+1) + sma.size -= child.Size + } return true } } diff --git a/plugin/rtmp/index.go b/plugin/rtmp/index.go index 616bb43..09def95 100644 --- a/plugin/rtmp/index.go +++ b/plugin/rtmp/index.go @@ -199,12 +199,14 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { if r, ok := receivers[msg.MessageStreamID]; ok { r.WriteAudio(msg.AVData.WrapAudio()) } else { + msg.AVData.Recycle() logger.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID) } case RTMP_MSG_VIDEO: if r, ok := receivers[msg.MessageStreamID]; ok { r.WriteVideo(msg.AVData.WrapVideo()) } else { + msg.AVData.Recycle() logger.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID) } } diff --git a/plugin/rtmp/pkg/net-connection.go b/plugin/rtmp/pkg/net-connection.go index 2d81732..789a626 100644 --- a/plugin/rtmp/pkg/net-connection.go +++ b/plugin/rtmp/pkg/net-connection.go @@ -151,7 +151,7 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) { if chunk.bufLen == 0 { chunk.AVData.RecyclableMemory = mem } else { - chunk.AVData.ReadFromBytes(mem.Buffers...) + chunk.AVData.AddRecycleBytes(mem.Buffers...) } chunk.bufLen += mem.Size diff --git a/plugin/rtmp/pkg/video.go b/plugin/rtmp/pkg/video.go index c9528bd..9e32b2f 100644 --- a/plugin/rtmp/pkg/video.go +++ b/plugin/rtmp/pkg/video.go @@ -125,7 +125,7 @@ func (avcc *RTMPVideo) DecodeConfig(t *AVTrack, from ICodecCtx) (err error) { b.Write(h264ctx.PPS[0]) t.ICodecCtx = &ctx var seqFrame RTMPData - seqFrame.Memory.ReadFromBytes(b) + seqFrame.ReadFromBytes(b) t.SequenceFrame = seqFrame.WrapVideo() if t.Enabled(context.TODO(), TraceLevel) { codec := t.FourCC().String() diff --git a/plugin/webrtc/index.go b/plugin/webrtc/index.go index be5eb0f..3599de6 100644 --- a/plugin/webrtc/index.go +++ b/plugin/webrtc/index.go @@ -183,6 +183,7 @@ func (conf *WebRTCPlugin) Push_(w http.ResponseWriter, r *http.Request) { return } mem := util.NewScalableMemoryAllocator(1460 * 100) + defer mem.Recycle() frame := &mrtp.RTPAudio{} frame.RTPCodecParameters = &codecP frame.ScalableMemoryAllocator = mem @@ -222,6 +223,7 @@ func (conf *WebRTCPlugin) Push_(w http.ResponseWriter, r *http.Request) { } var lastPLISent time.Time mem := util.NewScalableMemoryAllocator(1460 * 100) + defer mem.Recycle() frame := &mrtp.RTPVideo{} frame.RTPCodecParameters = &codecP frame.ScalableMemoryAllocator = mem diff --git a/publisher.go b/publisher.go index 3f28250..310f066 100644 --- a/publisher.go +++ b/publisher.go @@ -310,6 +310,18 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) { } func (p *Publisher) WriteData(data IDataFrame) (err error) { + if p.DataTrack == nil { + p.DataTrack = &DataTrack{} + p.DataTrack.Logger = p.Logger.With("track", "data") + p.Lock() + if len(p.Subscribers) > 0 { + p.State = PublisherStateSubscribed + } else { + p.State = PublisherStateTrackAdded + } + p.Unlock() + } + // TODO: Implement this function return }