diff --git a/pkg/decoder.go b/pkg/decoder.go index 106b110..2f135c7 100644 --- a/pkg/decoder.go +++ b/pkg/decoder.go @@ -79,12 +79,6 @@ func (decoder *GeneralDecoder) Stop() { } func (decoder *GeneralDecoder) loop() { - var ( - packet *astiav.Packet - frame *astiav.Frame - err error - ) - defer decoder.close() loop1: @@ -92,7 +86,12 @@ loop1: select { case <-decoder.ctx.Done(): return - case packet = <-decoder.demuxer.WaitForPacket(): + default: + packet, err := decoder.getPacket() + if err != nil { + // fmt.Println("unable to get packet from demuxer; err:", err.Error()) + continue + } if err := decoder.decoderContext.SendPacket(packet); err != nil { decoder.demuxer.PutBack(packet) if !errors.Is(err, astiav.ErrEagain) { @@ -101,7 +100,7 @@ loop1: } loop2: for { - frame = decoder.buffer.Generate() + frame := decoder.buffer.Generate() if err := decoder.decoderContext.ReceiveFrame(frame); err != nil { decoder.buffer.PutBack(frame) break loop2 @@ -109,7 +108,7 @@ loop1: frame.SetPictureType(astiav.PictureTypeNone) - if err = decoder.pushFrame(frame); err != nil { + if err := decoder.pushFrame(frame); err != nil { decoder.buffer.PutBack(frame) continue loop2 } @@ -120,14 +119,21 @@ loop1: } func (decoder *GeneralDecoder) pushFrame(frame *astiav.Frame) error { - ctx, cancel := context.WithTimeout(decoder.ctx, time.Second) + ctx, cancel := context.WithTimeout(decoder.ctx, 50*time.Millisecond) defer cancel() return decoder.buffer.Push(ctx, frame) } -func (decoder *GeneralDecoder) WaitForFrame() chan *astiav.Frame { - return decoder.buffer.GetChannel() +func (decoder *GeneralDecoder) getPacket() (*astiav.Packet, error) { + ctx, cancel := context.WithTimeout(decoder.ctx, 50*time.Millisecond) + defer cancel() + + return decoder.demuxer.GetPacket(ctx) +} + +func (decoder *GeneralDecoder) GetFrame(ctx context.Context) (*astiav.Frame, error) { + return decoder.buffer.Pop(ctx) } func (decoder *GeneralDecoder) PutBack(frame *astiav.Frame) { diff --git a/pkg/demuxer.go b/pkg/demuxer.go index 5ad2485..1efb1a0 100644 --- a/pkg/demuxer.go +++ b/pkg/demuxer.go @@ -116,20 +116,13 @@ loop1: } func (demuxer *GeneralDemuxer) pushPacket(packet *astiav.Packet) error { - ctx, cancel := context.WithTimeout(demuxer.ctx, time.Second) // TODO: NEEDS TO BE BASED ON FPS ON INPUT_FORMAT + ctx, cancel := context.WithTimeout(demuxer.ctx, 50*time.Millisecond) // TODO: NEEDS TO BE BASED ON FPS ON INPUT_FORMAT defer cancel() return demuxer.buffer.Push(ctx, packet) } -func (demuxer *GeneralDemuxer) WaitForPacket() chan *astiav.Packet { - return demuxer.buffer.GetChannel() -} - -func (demuxer *GeneralDemuxer) GetPacket() (*astiav.Packet, error) { - ctx, cancel := context.WithTimeout(demuxer.ctx, time.Second) - defer cancel() - +func (demuxer *GeneralDemuxer) GetPacket(ctx context.Context) (*astiav.Packet, error) { return demuxer.buffer.Pop(ctx) } diff --git a/pkg/encoder.go b/pkg/encoder.go index bfd8d51..c228ea1 100644 --- a/pkg/encoder.go +++ b/pkg/encoder.go @@ -95,11 +95,6 @@ func (encoder *GeneralEncoder) TimeBase() astiav.Rational { } func (encoder *GeneralEncoder) loop() { - var ( - frame *astiav.Frame - packet *astiav.Packet - err error - ) defer encoder.close() loop1: @@ -107,8 +102,13 @@ loop1: select { case <-encoder.ctx.Done(): return - case frame = <-encoder.filter.WaitForFrame(): - if err = encoder.encoderContext.SendFrame(frame); err != nil { + default: + frame, err := encoder.getFrame() + if err != nil { + // fmt.Println("unable to get packet from encoder; err:", err.Error()) + continue + } + if err := encoder.encoderContext.SendFrame(frame); err != nil { encoder.filter.PutBack(frame) if !errors.Is(err, astiav.ErrEagain) { continue loop1 @@ -116,13 +116,13 @@ loop1: } loop2: for { - packet = encoder.buffer.Generate() + packet := encoder.buffer.Generate() if err = encoder.encoderContext.ReceivePacket(packet); err != nil { encoder.buffer.PutBack(packet) break loop2 } - if err = encoder.pushPacket(packet); err != nil { + if err := encoder.pushPacket(packet); err != nil { encoder.buffer.PutBack(packet) continue loop2 } @@ -132,12 +132,19 @@ loop1: } } -func (encoder *GeneralEncoder) WaitForPacket() chan *astiav.Packet { - return encoder.buffer.GetChannel() +func (encoder *GeneralEncoder) getFrame() (*astiav.Frame, error) { + ctx, cancel := context.WithTimeout(encoder.ctx, 50*time.Millisecond) + defer cancel() + + return encoder.filter.GetFrame(ctx) +} + +func (encoder *GeneralEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) { + return encoder.buffer.Pop(ctx) } func (encoder *GeneralEncoder) pushPacket(packet *astiav.Packet) error { - ctx, cancel := context.WithTimeout(encoder.ctx, time.Second) + ctx, cancel := context.WithTimeout(encoder.ctx, 50*time.Millisecond) defer cancel() return encoder.buffer.Push(ctx, packet) diff --git a/pkg/filter.go b/pkg/filter.go index 9e9b25e..e4e86e5 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -3,7 +3,6 @@ package transcode import ( "context" "fmt" - "sync" "time" "github.com/asticode/go-astiav" @@ -23,7 +22,6 @@ type GeneralFilter struct { srcContext *astiav.BuffersrcFilterContext sinkContext *astiav.BuffersinkFilterContext srcContextParams *astiav.BuffersrcFilterContextParameters // NOTE: THIS BECOMES NIL AFTER INITIALISATION - mux sync.RWMutex ctx context.Context cancel context.CancelFunc } @@ -136,11 +134,6 @@ func (filter *GeneralFilter) Stop() { } func (filter *GeneralFilter) loop() { - var ( - err error = nil - srcFrame *astiav.Frame - sinkFrame *astiav.Frame - ) defer filter.close() loop1: @@ -148,44 +141,54 @@ loop1: select { case <-filter.ctx.Done(): return - case srcFrame = <-filter.decoder.WaitForFrame(): - filter.mux.Lock() - if err = filter.srcContext.AddFrame(srcFrame, astiav.NewBuffersrcFlags(astiav.BuffersrcFlagKeepRef)); err != nil { + default: + srcFrame, err := filter.getFrame() + if err != nil { + // fmt.Println("unable to get frame from decoder; err:", err.Error()) + continue + } + if err := filter.srcContext.AddFrame(srcFrame, astiav.NewBuffersrcFlags(astiav.BuffersrcFlagKeepRef)); err != nil { filter.buffer.PutBack(srcFrame) continue loop1 } loop2: for { - sinkFrame = filter.buffer.Generate() + sinkFrame := filter.buffer.Generate() if err = filter.sinkContext.GetFrame(sinkFrame, astiav.NewBuffersinkFlags()); err != nil { filter.buffer.PutBack(sinkFrame) break loop2 } - if err = filter.pushFrame(sinkFrame); err != nil { + if err := filter.pushFrame(sinkFrame); err != nil { filter.buffer.PutBack(sinkFrame) continue loop2 } } filter.decoder.PutBack(srcFrame) - filter.mux.Unlock() } } } func (filter *GeneralFilter) pushFrame(frame *astiav.Frame) error { - ctx, cancel := context.WithTimeout(filter.ctx, time.Second) + ctx, cancel := context.WithTimeout(filter.ctx, 50*time.Millisecond) defer cancel() return filter.buffer.Push(ctx, frame) } +func (filter *GeneralFilter) getFrame() (*astiav.Frame, error) { + ctx, cancel := context.WithTimeout(filter.ctx, 50*time.Millisecond) + defer cancel() + + return filter.decoder.GetFrame(ctx) +} + func (filter *GeneralFilter) PutBack(frame *astiav.Frame) { filter.buffer.PutBack(frame) } -func (filter *GeneralFilter) WaitForFrame() chan *astiav.Frame { - return filter.buffer.GetChannel() +func (filter *GeneralFilter) GetFrame(ctx context.Context) (*astiav.Frame, error) { + return filter.buffer.Pop(ctx) } func (filter *GeneralFilter) close() { diff --git a/pkg/interfaces.go b/pkg/interfaces.go index 2688b4e..99d2187 100644 --- a/pkg/interfaces.go +++ b/pkg/interfaces.go @@ -45,12 +45,12 @@ type CanDescribeMediaPacket interface { } type CanProduceMediaPacket interface { - WaitForPacket() chan *astiav.Packet + GetPacket(ctx context.Context) (*astiav.Packet, error) PutBack(*astiav.Packet) } type CanProduceMediaFrame interface { - WaitForFrame() chan *astiav.Frame + GetFrame(ctx context.Context) (*astiav.Frame, error) PutBack(*astiav.Frame) } diff --git a/pkg/transcoder.go b/pkg/transcoder.go index a6951d5..afd9edd 100644 --- a/pkg/transcoder.go +++ b/pkg/transcoder.go @@ -1,6 +1,8 @@ package transcode import ( + "context" + "github.com/asticode/go-astiav" ) @@ -45,8 +47,8 @@ func (t *Transcoder) Stop() { t.demuxer.Stop() } -func (t *Transcoder) WaitForPacket() chan *astiav.Packet { - return t.encoder.WaitForPacket() +func (t *Transcoder) GetPacket(ctx context.Context) (*astiav.Packet, error) { + return t.encoder.GetPacket(ctx) } func (t *Transcoder) PutBack(packet *astiav.Packet) { diff --git a/pkg/update_encoder_wrapper.go b/pkg/update_encoder_wrapper.go index f3287e4..268051f 100644 --- a/pkg/update_encoder_wrapper.go +++ b/pkg/update_encoder_wrapper.go @@ -2,11 +2,16 @@ package transcode import ( "context" + "errors" "fmt" "sync" "sync/atomic" + "time" "github.com/asticode/go-astiav" + + "github.com/harshabose/simple_webrtc_comm/transcode/internal" + "github.com/harshabose/tools/buffer/pkg" ) type UpdateConfig struct { @@ -18,6 +23,7 @@ type UpdateEncoder struct { encoder Encoder config UpdateConfig builder *GeneralEncoderBuilder + buffer buffer.BufferWithGenerator[astiav.Packet] mux sync.RWMutex ctx context.Context @@ -31,6 +37,7 @@ func NewUpdateEncoder(ctx context.Context, config UpdateConfig, builder *General config: config, builder: builder, resume: make(chan struct{}), + buffer: buffer.CreateChannelBuffer(ctx, 30, internal.CreatePacketPool()), ctx: ctx, } @@ -41,6 +48,8 @@ func NewUpdateEncoder(ctx context.Context, config UpdateConfig, builder *General updater.encoder = encoder + go updater.loop() + return updater, nil } @@ -58,12 +67,8 @@ func (u *UpdateEncoder) Start() { u.encoder.Start() } -func (u *UpdateEncoder) WaitForPacket() chan *astiav.Packet { - if u.paused.Load() { - <-u.resume - } - - return u.encoder.WaitForPacket() +func (u *UpdateEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) { + return u.buffer.Pop(ctx) } func (u *UpdateEncoder) PutBack(packet *astiav.Packet) { @@ -83,6 +88,7 @@ func (u *UpdateEncoder) Stop() { // UpdateBitrate modifies the encoder's target bitrate to the specified value in bits per second. // Returns an error if the update fails. func (u *UpdateEncoder) UpdateBitrate(bps int64) error { + // return nil if err := u.checkPause(bps); err != nil { return err } @@ -98,13 +104,14 @@ func (u *UpdateEncoder) UpdateBitrate(bps int64) error { if err != nil { return err } - fmt.Printf("got bitrate update request (%d -> %d)\n", current, bps) _, change := u.calculateBitrateChange(current, bps) if change < 5 { return nil } + fmt.Printf("got bitrate update request (%d -> %d)\n", current, bps) + start := time.Now() if err := u.builder.UpdateBitrate(bps); err != nil { return err } @@ -118,27 +125,27 @@ func (u *UpdateEncoder) UpdateBitrate(bps int64) error { // Wait for the first packet from the new encoder // firstPacket := <-newEncoder.WaitForPacket() + // newEncoder.PutBack(firstPacket) u.mux.Lock() oldEncoder := u.encoder u.encoder = newEncoder u.mux.Unlock() - // Put the first packet back for next WaitForPacket() - // newEncoder.PutBack(firstPacket) - - if oldEncoder != nil { - oldEncoder.Stop() - } - // Print encoder update notification fmt.Println() fmt.Println("╔═══════════════════════════════════════╗") fmt.Println("║ 🎥 ENCODER UPDATED 🎥 ║") - fmt.Printf("║ New Bitrate: %6d kbps ║\n", bps/1000) + fmt.Printf("║ New Bitrate: %6d kbps ║\n", bps/1000) + fmt.Printf("║ Change: %6.2f ║\n", change) + fmt.Printf("║ Update time: %6d ms ║\n", time.Since(start).Milliseconds()) fmt.Println("╚═══════════════════════════════════════╝") fmt.Println() + if oldEncoder != nil { + oldEncoder.Stop() + } + return nil } @@ -206,6 +213,44 @@ func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absolu return absoluteChange, percentageChange } -func (u *UpdateEncoder) swapSoon() { +func (u *UpdateEncoder) getPacket() (*astiav.Packet, error) { + u.mux.RLock() + encoder := u.encoder // Get reference + u.mux.RUnlock() // Release lock immediately + if encoder != nil { + ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond) + defer cancel() + return encoder.GetPacket(ctx) // Don't hold lock during blocking call + } + + return nil, errors.New("encoder is nil") +} + +func (u *UpdateEncoder) pushPacket(p *astiav.Packet) error { + if p == nil { + return nil + } + ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond) + defer cancel() + return u.buffer.Push(ctx, p) +} + +func (u *UpdateEncoder) loop() { + for { + select { + case <-u.ctx.Done(): + return + default: + p, err := u.getPacket() + if err != nil { + // fmt.Println("error getting packet from encoder; err:", err.Error()) + } + + if err := u.pushPacket(p); err != nil { + fmt.Println(err.Error()) + } + time.Sleep(10 * time.Millisecond) + } + } } diff --git a/pkg/x264options.go b/pkg/x264options.go index af9dc15..5b8402e 100644 --- a/pkg/x264options.go +++ b/pkg/x264options.go @@ -119,9 +119,9 @@ var LowLatencyBitrateControlled = &X264Options{ Tune: "zerolatency", X264AdvancedOptions: &X264AdvancedOptions{ - Bitrate: "800", // 800kbps - VBVMaxBitrate: "800", - VBVBuffer: "400", + Bitrate: "500", // 800kbps + VBVMaxBitrate: "500", + VBVBuffer: "250", RateTolerance: "1", // 1% rate tolerance MaxGOP: "25", MinGOP: "13",