From 523baa0f92ea67bc27ed23ec31388d4be1b053de Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Fri, 7 Jun 2024 18:16:28 +0800 Subject: [PATCH] fix: memory leadk --- api.go | 12 ++-------- pkg/av-reader.go | 2 +- pkg/error.go | 1 + pkg/util/buffers.go | 7 ++++++ pkg/util/mem.go | 40 ++++++++++++++++++------------- pkg/util/promise.go | 22 +++++++++++------ plugin/rtmp/index.go | 3 +-- plugin/rtmp/pkg/audio.go | 2 +- plugin/rtmp/pkg/handshake.go | 14 +++++------ plugin/rtmp/pkg/net-connection.go | 32 +++++++++++++++---------- plugin/rtmp/pkg/video.go | 32 ++++++++----------------- publisher.go | 33 ++++++++++++++----------- 12 files changed, 108 insertions(+), 92 deletions(-) diff --git a/api.go b/api.go index 0c748eb..9599546 100644 --- a/api.go +++ b/api.go @@ -182,12 +182,8 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) } func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { s.Call(func() { - if pub, ok := s.Streams.Get(req.StreamPath); ok { + if pub, ok := s.Streams.Get(req.StreamPath); ok && !pub.AudioTrack.IsEmpty() { res = &pb.TrackSnapShotResponse{} - _, err = pub.AudioTrack.Ready.Await() - if err != nil { - return - } for _, memlist := range pub.AudioTrack.Allocator.GetChildren() { var list []*pb.MemoryBlock for _, block := range memlist.GetBlocks() { @@ -265,12 +261,8 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) { } func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { s.Call(func() { - if pub, ok := s.Streams.Get(req.StreamPath); ok { + if pub, ok := s.Streams.Get(req.StreamPath); ok && !pub.VideoTrack.IsEmpty() { res = &pb.TrackSnapShotResponse{} - _, err = pub.VideoTrack.Ready.Await() - if err != nil { - return - } for _, memlist := range pub.VideoTrack.Allocator.GetChildren() { var list []*pb.MemoryBlock for _, block := range memlist.GetBlocks() { diff --git a/pkg/av-reader.go b/pkg/av-reader.go index 5a5ddef..2d1b2a6 100644 --- a/pkg/av-reader.go +++ b/pkg/av-reader.go @@ -40,7 +40,7 @@ func (r *AVRingReader) DecConfChanged() bool { func NewAVRingReader(t *AVTrack) *AVRingReader { t.Debug("create reader") - <-t.Ready.Done() + t.Ready.Await() t.Info("reader +1", "count", t.ReaderCount.Add(1)) return &AVRingReader{ Track: t, diff --git a/pkg/error.go b/pkg/error.go index 30b2650..4f02cf6 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -15,4 +15,5 @@ var ( ErrRestart = errors.New("restart") ErrInterrupt = errors.New("interrupt") ErrUnsupportCodec = errors.New("unsupport codec") + ErrMuted = errors.New("muted") ) diff --git a/pkg/util/buffers.go b/pkg/util/buffers.go index d7f2d65..32d32bf 100644 --- a/pkg/util/buffers.go +++ b/pkg/util/buffers.go @@ -43,6 +43,13 @@ func (buffers *Memory) UpdateBuffer(index int, buf []byte) { buffers.Buffers[index] = buf } +func (buffers *Memory) CopyFrom(b Memory) { + buf := make([]byte, b.Size) + bufs := slices.Clone(b.Buffers) + bufs.Read(buf) + buffers.ReadFromBytes(buf) +} + func (buffers *Memory) ReadFromBytes(b ...[]byte) { buffers.Buffers = append(buffers.Buffers, b...) for _, level0 := range b { diff --git a/pkg/util/mem.go b/pkg/util/mem.go index 653d5b8..85da980 100644 --- a/pkg/util/mem.go +++ b/pkg/util/mem.go @@ -80,12 +80,7 @@ type ScalableMemoryAllocator struct { } func NewScalableMemoryAllocator(size int) (ret *ScalableMemoryAllocator) { - if value, ok := pools.Load(size); ok { - ret = value.(*sync.Pool).Get().(*ScalableMemoryAllocator) - } else { - ret = &ScalableMemoryAllocator{children: []*MemoryAllocator{NewMemoryAllocator(size)}, size: size} - } - return + return &ScalableMemoryAllocator{children: []*MemoryAllocator{GetMemoryAllocator(size)}, size: size} } func (sma *ScalableMemoryAllocator) checkSize() { @@ -138,7 +133,7 @@ func (sma *ScalableMemoryAllocator) Malloc(size int) (memory []byte) { return } } - child = NewMemoryAllocator(max(min(MaxBlockSize, child.Size*2), size)) + child = GetMemoryAllocator(max(min(MaxBlockSize, child.Size*2), size)) sma.size += child.Size memory = child.Malloc(size) sma.children = append(sma.children, child) @@ -175,20 +170,24 @@ func (sma *ScalableMemoryAllocator) Free(mem []byte) bool { type RecyclableMemory struct { *ScalableMemoryAllocator Memory - mallocIndexes []int + RecycleIndexes []int } func (r *RecyclableMemory) NextN(size int) (memory []byte) { memory = r.ScalableMemoryAllocator.Malloc(size) - r.mallocIndexes = append(r.mallocIndexes, len(r.Buffers)) + if r.RecycleIndexes != nil { + r.RecycleIndexes = append(r.RecycleIndexes, len(r.Buffers)) + } r.ReadFromBytes(memory) return } func (r *RecyclableMemory) AddRecycleBytes(b ...[]byte) { - start := len(r.Buffers) - for i := range b { - r.mallocIndexes = append(r.mallocIndexes, start+i) + if r.RecycleIndexes != nil { + start := len(r.Buffers) + for i := range b { + r.RecycleIndexes = append(r.RecycleIndexes, start+i) + } } r.ReadFromBytes(b...) } @@ -198,15 +197,24 @@ func (r *RecyclableMemory) RemoveRecycleBytes(index int) (buf []byte) { index = len(r.Buffers) + index } buf = r.Buffers[index] - i := slices.Index(r.mallocIndexes, index) - r.mallocIndexes = slices.Delete(r.mallocIndexes, i, i+1) + if r.RecycleIndexes != nil { + i := slices.Index(r.RecycleIndexes, index) + r.RecycleIndexes = slices.Delete(r.RecycleIndexes, i, i+1) + } r.Buffers = slices.Delete(r.Buffers, index, index+1) r.Size -= len(buf) return } func (r *RecyclableMemory) Recycle() { - for _, index := range r.mallocIndexes { - r.Free(r.Buffers[index]) + if r.RecycleIndexes != nil { + for _, index := range r.RecycleIndexes { + r.Free(r.Buffers[index]) + } + r.RecycleIndexes = r.RecycleIndexes[:0] + } else { + for _, buf := range r.Buffers { + r.Free(buf) + } } } diff --git a/pkg/util/promise.go b/pkg/util/promise.go index 8e770c9..4e2308f 100644 --- a/pkg/util/promise.go +++ b/pkg/util/promise.go @@ -3,21 +3,27 @@ package util import ( "context" "errors" + "time" ) type Promise[T any] struct { context.Context context.CancelCauseFunc Value T - // timer *time.Timer + timer *time.Timer } func NewPromise[T any](v T) *Promise[T] { p := &Promise[T]{Value: v} p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background()) - // p.timer = time.AfterFunc(time.Second, func() { - // p.CancelCauseFunc(ErrTimeout) - // }) + return p +} +func NewPromiseWithTimeout[T any](v T, timeout time.Duration) *Promise[T] { + p := &Promise[T]{Value: v} + p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background()) + p.timer = time.AfterFunc(timeout, func() { + p.CancelCauseFunc(ErrTimeout) + }) return p } @@ -32,17 +38,19 @@ func (p *Promise[T]) Resolve(v T) { func (p *Promise[T]) Await() (T, error) { <-p.Done() err := context.Cause(p.Context) - if err == ErrResolve { + if errors.Is(err, ErrResolve) { err = nil } return p.Value, err } func (p *Promise[T]) Fulfill(err error) { - // p.timer.Stop() + if p.timer != nil { + p.timer.Stop() + } p.CancelCauseFunc(Conditoinal(err == nil, ErrResolve, err)) } -func (p *Promise[T]) Pendding() bool { +func (p *Promise[T]) IsPending() bool { return context.Cause(p.Context) == nil } diff --git a/plugin/rtmp/index.go b/plugin/rtmp/index.go index 09def95..fb7df33 100644 --- a/plugin/rtmp/index.go +++ b/plugin/rtmp/index.go @@ -44,8 +44,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { var err error logger.Info("conn") nc := NewNetConnection(conn, logger) - defer nc.BufReader.Recycle() - defer conn.Close() + defer nc.Destroy() ctx, cancel := context.WithCancelCause(p) defer func() { logger.Info("conn close") diff --git a/plugin/rtmp/pkg/audio.go b/plugin/rtmp/pkg/audio.go index 1b4f968..66b38c7 100644 --- a/plugin/rtmp/pkg/audio.go +++ b/plugin/rtmp/pkg/audio.go @@ -50,7 +50,7 @@ func (avcc *RTMPAudio) Parse(t *AVTrack) (isIDR, isSeq bool, raw any, err error) return } var cloneFrame RTMPAudio - cloneFrame.ReadFromBytes(avcc.ToBytes()) + cloneFrame.CopyFrom(avcc.Memory) ctx.AudioObjectType = b0 >> 3 ctx.SamplingFrequencyIndex = (b0 & 0x07 << 1) | (b1 >> 7) ctx.ChannelConfiguration = (b1 >> 3) & 0x0F diff --git a/plugin/rtmp/pkg/handshake.go b/plugin/rtmp/pkg/handshake.go index 700e29c..996850e 100644 --- a/plugin/rtmp/pkg/handshake.go +++ b/plugin/rtmp/pkg/handshake.go @@ -67,8 +67,8 @@ var ( // C2 S2 : 参考C1 S1 func (nc *NetConnection) Handshake(checkC2 bool) (err error) { - C0C1 := nc.writePool.NextN(C1S1_SIZE + 1) - defer nc.writePool.Recycle() + C0C1 := nc.mediaDataPool.NextN(C1S1_SIZE + 1) + defer nc.mediaDataPool.Recycle() if _, err = io.ReadFull(nc.Conn, C0C1); err != nil { return err } @@ -90,8 +90,8 @@ func (nc *NetConnection) Handshake(checkC2 bool) (err error) { } func (client *NetConnection) ClientHandshake() (err error) { - C0C1 := client.writePool.NextN(C1S1_SIZE + 1) - defer client.writePool.Recycle() + C0C1 := client.mediaDataPool.NextN(C1S1_SIZE + 1) + defer client.mediaDataPool.Recycle() C0C1[0] = RTMP_HANDSHAKE_VERSION if _, err = client.Write(C0C1); err == nil { // read S0 S1 @@ -108,15 +108,15 @@ func (client *NetConnection) ClientHandshake() (err error) { } func (nc *NetConnection) simple_handshake(C1 []byte, checkC2 bool) error { - S0S1 := nc.writePool.NextN(C1S1_SIZE + 1) - defer nc.writePool.Recycle() + S0S1 := nc.mediaDataPool.NextN(C1S1_SIZE + 1) + defer nc.mediaDataPool.Recycle() S0S1[0] = RTMP_HANDSHAKE_VERSION util.PutBE(S0S1[1:5], time.Now().Unix()&0xFFFFFFFF) copy(S0S1[5:], "Monibuca") nc.Write(S0S1) nc.Write(C1) // S2 C2, err := nc.ReadBytes(C1S1_SIZE) - defer C2.Recycle() + C2.Recycle() if err != nil { return err } diff --git a/plugin/rtmp/pkg/net-connection.go b/plugin/rtmp/pkg/net-connection.go index 789a626..df2a588 100644 --- a/plugin/rtmp/pkg/net-connection.go +++ b/plugin/rtmp/pkg/net-connection.go @@ -58,7 +58,7 @@ type NetConnection struct { AppName string tmpBuf util.Buffer //用来接收/发送小数据,复用内存 chunkHeaderBuf util.Buffer - writePool util.RecyclableMemory + mediaDataPool util.RecyclableMemory writing atomic.Bool // false 可写,true 不可写 } @@ -74,10 +74,14 @@ func NewNetConnection(conn net.Conn, logger *slog.Logger) (ret *NetConnection) { tmpBuf: make(util.Buffer, 4), chunkHeaderBuf: make(util.Buffer, 0, 20), } - ret.writePool.ScalableMemoryAllocator = util.NewScalableMemoryAllocator(1024) + ret.mediaDataPool.ScalableMemoryAllocator = util.NewScalableMemoryAllocator(1024) return } - +func (conn *NetConnection) Destroy() { + conn.Conn.Close() + conn.BufReader.Recycle() + conn.mediaDataPool.Recycle() +} func (conn *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) { return conn.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID}) } @@ -143,26 +147,30 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) { } else { mem, err = conn.ReadBytes(conn.readChunkSize) } + mem.Recycle() if err != nil { - mem.Recycle() return nil, err } conn.readSeqNum += uint32(mem.Size) if chunk.bufLen == 0 { - chunk.AVData.RecyclableMemory = mem - } else { - chunk.AVData.AddRecycleBytes(mem.Buffers...) + chunk.AVData.RecyclableMemory = util.RecyclableMemory{ + ScalableMemoryAllocator: conn.mediaDataPool.ScalableMemoryAllocator, + } + chunk.AVData.NextN(msgLen) } - - chunk.bufLen += mem.Size - if chunk.AVData.Size == msgLen { + buffer := chunk.AVData.Buffers[0] + for _, b := range mem.Buffers { + copy(buffer[chunk.bufLen:], b) + chunk.bufLen += len(b) + } + if chunk.bufLen == msgLen { msg = chunk switch chunk.MessageTypeID { case RTMP_MSG_AUDIO, RTMP_MSG_VIDEO: msg.AVData.Timestamp = chunk.ChunkHeader.ExtendTimestamp default: - msg.AVData.Recycle() - err = GetRtmpMessage(msg, msg.AVData.ToBytes()) + chunk.AVData.Recycle() + err = GetRtmpMessage(msg, buffer) } msg.bufLen = 0 } diff --git a/plugin/rtmp/pkg/video.go b/plugin/rtmp/pkg/video.go index 9e32b2f..4c72d7f 100644 --- a/plugin/rtmp/pkg/video.go +++ b/plugin/rtmp/pkg/video.go @@ -31,7 +31,7 @@ func (avcc *RTMPVideo) Parse(t *AVTrack) (isIDR, isSeq bool, raw any, err error) isSeq = true isIDR = false var cloneFrame RTMPVideo - cloneFrame.ReadFromBytes(avcc.ToBytes()) + cloneFrame.CopyFrom(avcc.Memory) switch fourCC { case codec.FourCC_H264: var ctx H264Ctx @@ -128,10 +128,10 @@ func (avcc *RTMPVideo) DecodeConfig(t *AVTrack, from ICodecCtx) (err error) { seqFrame.ReadFromBytes(b) t.SequenceFrame = seqFrame.WrapVideo() if t.Enabled(context.TODO(), TraceLevel) { - codec := t.FourCC().String() + c := t.FourCC().String() size := seqFrame.GetSize() data := seqFrame.String() - t.Trace("decConfig", "codec", codec, "size", size, "data", data) + t.Trace("decConfig", "codec", c, "size", size, "data", data) } } @@ -227,14 +227,14 @@ func (avcc *RTMPVideo) ToRaw(codecCtx ICodecCtx) (any, error) { } return nil, nil } - -func (h264 *H264Ctx) CreateFrame(from *AVFrame) (frame IAVFrame, err error) { +func createH26xFrame(from *AVFrame, codecID VideoCodecID) (frame IAVFrame, err error) { var rtmpVideo RTMPVideo rtmpVideo.Timestamp = uint32(from.Timestamp / time.Millisecond) rtmpVideo.ScalableMemoryAllocator = from.Wraps[0].GetScalableMemoryAllocator() nalus := from.Raw.(Nalus) + rtmpVideo.RecycleIndexes = make([]int, len(nalus.Nalus)) // Recycle partial data head := rtmpVideo.NextN(5) - head[0] = util.Conditoinal[byte](from.IDR, 0x10, 0x20) | byte(ParseVideoCodec(h264.FourCC())) + head[0] = util.Conditoinal[byte](from.IDR, 0x10, 0x20) | byte(codecID) head[1] = 1 util.PutBE(head[2:5], (nalus.PTS-nalus.DTS)/90) // cts for _, nalu := range nalus.Nalus { @@ -246,24 +246,12 @@ func (h264 *H264Ctx) CreateFrame(from *AVFrame) (frame IAVFrame, err error) { frame = &rtmpVideo return } +func (h264 *H264Ctx) CreateFrame(from *AVFrame) (frame IAVFrame, err error) { + return createH26xFrame(from, ParseVideoCodec(h264.FourCC())) +} func (h265 *H265Ctx) CreateFrame(from *AVFrame) (frame IAVFrame, err error) { - var rtmpVideo RTMPVideo - rtmpVideo.Timestamp = uint32(from.Timestamp / time.Millisecond) - rtmpVideo.ScalableMemoryAllocator = from.Wraps[0].GetScalableMemoryAllocator() - nalus := from.Raw.(Nalus) - head := rtmpVideo.NextN(5) - head[0] = util.Conditoinal[byte](from.IDR, 0x10, 0x20) | byte(ParseVideoCodec(h265.FourCC())) - head[1] = 1 - util.PutBE(head[2:5], (nalus.PTS-nalus.DTS)/90) // cts - for _, nalu := range nalus.Nalus { - naluLenM := rtmpVideo.NextN(4) - naluLen := uint32(util.LenOfBuffers(nalu)) - binary.BigEndian.PutUint32(naluLenM, naluLen) - rtmpVideo.ReadFromBytes(nalu...) - } - frame = &rtmpVideo - return + return createH26xFrame(from, ParseVideoCodec(h265.FourCC())) } func (av1 *AV1Ctx) CreateFrame(*AVFrame) (frame IAVFrame, err error) { diff --git a/publisher.go b/publisher.go index 310f066..e6591f3 100644 --- a/publisher.go +++ b/publisher.go @@ -169,13 +169,17 @@ func (p *Publisher) writeAV(t *AVTrack, data IAVFrame) { } func (p *Publisher) WriteVideo(data IAVFrame) (err error) { + defer func() { + if err != nil { + data.Recycle() + } + }() if !p.PubVideo || p.IsStopped() { - data.Recycle() - return + return ErrMuted } t := p.VideoTrack.AVTrack if t == nil { - t = NewAVTrack(data, p.Logger.With("track", "video"), 256) + t = NewAVTrack(data, p.Logger.With("track", "video"), 50) p.Lock() p.VideoTrack.AVTrack = t p.VideoTrack.Add(t) @@ -187,19 +191,17 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) { p.Unlock() } oldCodecCtx := t.ICodecCtx - isIDR, _, raw, err := data.Parse(t) + t.Value.IDR, _, t.Value.Raw, err = data.Parse(t) codecCtxChanged := oldCodecCtx != t.ICodecCtx if err != nil { p.Error("parse", "err", err) return err } if t.ICodecCtx == nil { - return + return ErrUnsupportCodec } - t.Value.Raw = raw - t.Value.IDR = isIDR idr, hidr := t.IDRing.Load(), t.HistoryRing.Load() - if isIDR { + if t.Value.IDR { if idr != nil { p.GOP = int(t.Value.Sequence - idr.Value.Sequence) if hidr == nil { @@ -230,7 +232,7 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) { } } p.writeAV(t, data) - if p.VideoTrack.Length > 1 && !p.VideoTrack.AVTrack.Ready.Pendding() { + if p.VideoTrack.Length > 1 && !p.VideoTrack.AVTrack.Ready.IsPending() { if t.Value.Raw == nil { t.Value.Raw, err = t.Value.Wraps[0].ToRaw(t.ICodecCtx) if err != nil { @@ -278,9 +280,13 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) { } func (p *Publisher) WriteAudio(data IAVFrame) (err error) { + defer func() { + if err != nil { + data.Recycle() + } + }() if !p.PubAudio || p.IsStopped() { - data.Recycle() - return + return ErrMuted } t := p.AudioTrack.AVTrack if t == nil { @@ -297,11 +303,10 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) { } _, _, _, err = data.Parse(t) if t.ICodecCtx == nil { - return + return ErrUnsupportCodec } - if t.Ready.Pendding() { + if t.Ready.IsPending() { t.Ready.Fulfill(err) - return } p.writeAV(t, data) t.Step()