diff --git a/internal/frame_pool.go b/internal/frame_pool.go index c641f39..c7a45ed 100644 --- a/internal/frame_pool.go +++ b/internal/frame_pool.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/asticode/go-astiav" + "github.com/harshabose/tools/buffer/pkg" ) diff --git a/pkg/encoder.go b/pkg/encoder.go index c228ea1..21c527f 100644 --- a/pkg/encoder.go +++ b/pkg/encoder.go @@ -16,7 +16,7 @@ import ( type GeneralEncoder struct { buffer buffer.BufferWithGenerator[astiav.Packet] - filter CanProduceMediaFrame + producer CanProduceMediaFrame codec *astiav.Codec encoderContext *astiav.CodecContext codecFlags *astiav.Dictionary @@ -30,7 +30,7 @@ type GeneralEncoder struct { func CreateGeneralEncoder(ctx context.Context, codecID astiav.CodecID, canProduceMediaFrame CanProduceMediaFrame, options ...EncoderOption) (*GeneralEncoder, error) { ctx2, cancel := context.WithCancel(ctx) encoder := &GeneralEncoder{ - filter: canProduceMediaFrame, + producer: canProduceMediaFrame, codecFlags: astiav.NewDictionary(), ctx: ctx2, cancel: cancel, @@ -69,7 +69,7 @@ func CreateGeneralEncoder(ctx context.Context, codecID astiav.CodecID, canProduc } if encoder.buffer == nil { - encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool()) + encoder.buffer = buffer.CreateChannelBuffer(ctx2, 256, internal.CreatePacketPool()) } encoder.findParameterSets(encoder.encoderContext.ExtraData()) @@ -109,7 +109,7 @@ loop1: continue } if err := encoder.encoderContext.SendFrame(frame); err != nil { - encoder.filter.PutBack(frame) + encoder.producer.PutBack(frame) if !errors.Is(err, astiav.ErrEagain) { continue loop1 } @@ -127,7 +127,7 @@ loop1: continue loop2 } } - encoder.filter.PutBack(frame) + encoder.producer.PutBack(frame) } } } @@ -136,7 +136,7 @@ func (encoder *GeneralEncoder) getFrame() (*astiav.Frame, error) { ctx, cancel := context.WithTimeout(encoder.ctx, 50*time.Millisecond) defer cancel() - return encoder.filter.GetFrame(ctx) + return encoder.producer.GetFrame(ctx) } func (encoder *GeneralEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) { diff --git a/pkg/encoder_builder.go b/pkg/encoder_builder.go index 9fb21b0..8ccf878 100644 --- a/pkg/encoder_builder.go +++ b/pkg/encoder_builder.go @@ -31,6 +31,11 @@ func (b *GeneralEncoderBuilder) UpdateBitrate(bps int64) error { return s.UpdateBitrate(bps) } +func (b *GeneralEncoderBuilder) BuildWithProducer(ctx context.Context, producer CanProduceMediaFrame) (Encoder, error) { + b.producer = producer + return b.Build(ctx) +} + func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) { codec := astiav.FindEncoder(b.codecID) if codec == nil { @@ -39,7 +44,7 @@ func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) { ctx2, cancel := context.WithCancel(ctx) encoder := &GeneralEncoder{ - filter: b.producer, + producer: b.producer, codec: codec, codecFlags: astiav.NewDictionary(), ctx: ctx2, @@ -51,7 +56,7 @@ func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) { return nil, ErrorAllocateCodecContext } - canDescribeMediaFrame, ok := b.producer.(CanDescribeMediaFrame) + canDescribeMediaFrame, ok := encoder.producer.(CanDescribeMediaFrame) if !ok { return nil, ErrorInterfaceMismatch } diff --git a/pkg/multi_encoder.go b/pkg/multi_encoder.go new file mode 100644 index 0000000..bdbdc48 --- /dev/null +++ b/pkg/multi_encoder.go @@ -0,0 +1,310 @@ +package transcode + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/asticode/go-astiav" + + "github.com/harshabose/simple_webrtc_comm/transcode/internal" + "github.com/harshabose/tools/buffer/pkg" +) + +type MultiConfig struct { + Steps uint8 + UpdateConfig +} + +func (c MultiConfig) validate() error { + if c.Steps == 0 { + return fmt.Errorf("steps need be more than 0") + } + return c.UpdateConfig.validate() +} + +func (c MultiConfig) getBitrates() []int64 { + bitrates := make([]int64, c.Steps) + + if c.Steps == 1 { + bitrates[0] = c.MaxBitrate + } else { + step := float64(c.MaxBitrate-c.MinBitrate) / float64(c.Steps-1) + for i := uint8(0); i < c.Steps; i++ { + bitrates[i] = c.MinBitrate + int64(float64(i)*step) + } + } + + return bitrates +} + +func NewMultiConfig(minBitrate, maxBitrate int64, steps uint8) MultiConfig { + c := MultiConfig{ + UpdateConfig: UpdateConfig{ + MaxBitrate: maxBitrate, + MinBitrate: minBitrate, + }, + Steps: steps, + } + + return c +} + +type dummyMediaFrameProducer struct { + buffer buffer.BufferWithGenerator[astiav.Frame] + CanDescribeMediaFrame +} + +func newDummyMediaFrameProducer(buffer buffer.BufferWithGenerator[astiav.Frame], describer CanDescribeMediaFrame) *dummyMediaFrameProducer { + return &dummyMediaFrameProducer{ + buffer: buffer, + CanDescribeMediaFrame: describer, + } +} + +func (p *dummyMediaFrameProducer) pushFrame(ctx context.Context, frame *astiav.Frame) error { + return p.buffer.Push(ctx, frame) +} + +func (p *dummyMediaFrameProducer) GetFrame(ctx context.Context) (*astiav.Frame, error) { + return p.buffer.Pop(ctx) +} + +func (p *dummyMediaFrameProducer) Generate() *astiav.Frame { + return p.buffer.Generate() +} + +func (p *dummyMediaFrameProducer) PutBack(frame *astiav.Frame) { + p.buffer.PutBack(frame) +} + +type splitEncoder struct { + encoder *GeneralEncoder + producer *dummyMediaFrameProducer +} + +func newSplitEncoder(encoder *GeneralEncoder, producer *dummyMediaFrameProducer) *splitEncoder { + return &splitEncoder{ + encoder: encoder, + producer: producer, + } +} + +type MultiUpdateEncoder struct { + encoders []*splitEncoder + active atomic.Pointer[splitEncoder] + config MultiConfig + bitrates []int64 + producer CanProduceMediaFrame + ctx context.Context + cancel context.CancelFunc + + paused atomic.Bool + resume chan struct{} + pauseMux sync.Mutex +} + +func NewMultiUpdateEncoder(ctx context.Context, config MultiConfig, builder *GeneralEncoderBuilder) (*MultiUpdateEncoder, error) { + if err := config.validate(); err != nil { + return nil, err + } + + ctx2, cancel := context.WithCancel(ctx) + encoder := &MultiUpdateEncoder{ + encoders: make([]*splitEncoder, 0), + config: config, + bitrates: config.getBitrates(), + producer: builder.producer, + ctx: ctx2, + cancel: cancel, + resume: make(chan struct{}), + } + + describer, ok := encoder.producer.(CanDescribeMediaFrame) + if !ok { + return nil, ErrorInterfaceMismatch + } + + for _, bitrate := range encoder.bitrates { + producer := newDummyMediaFrameProducer(buffer.CreateChannelBuffer(ctx2, 90, internal.CreateFramePool()), describer) + + if err := builder.UpdateBitrate(bitrate); err != nil { + return nil, err + } + + e, err := builder.BuildWithProducer(ctx2, producer) + if err != nil { + return nil, err + } + + encoder.encoders = append(encoder.encoders, newSplitEncoder(e.(*GeneralEncoder), producer)) + } + + encoder.switchEncoder(0) + + return encoder, nil +} + +func (u *MultiUpdateEncoder) Ctx() context.Context { + return u.ctx +} + +func (u *MultiUpdateEncoder) Start() { + for _, encoder := range u.encoders { + encoder.encoder.Start() + } + + go u.loop() +} + +func (u *MultiUpdateEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) { + return u.active.Load().encoder.GetPacket(ctx) +} + +func (u *MultiUpdateEncoder) PutBack(packet *astiav.Packet) { + u.active.Load().encoder.PutBack(packet) +} + +func (u *MultiUpdateEncoder) Stop() { + u.cancel() +} + +func (u *MultiUpdateEncoder) UpdateBitrate(bps int64) error { + if err := u.checkPause(bps); err != nil { + return err + } + + bps = u.cutoff(bps) + + bestIndex := u.findBestEncoderIndex(bps) + u.switchEncoder(bestIndex) + + return nil +} + +func (u *MultiUpdateEncoder) findBestEncoderIndex(targetBps int64) int { + bestIndex := 0 + for i, bitrate := range u.bitrates { + if bitrate <= targetBps { + bestIndex = i + } else { + break + } + } + + return bestIndex +} + +func (u *MultiUpdateEncoder) switchEncoder(index int) { + if index < len(u.encoders) { + u.active.Swap(u.encoders[index]) + } +} + +func (u *MultiUpdateEncoder) cutoff(bps int64) int64 { + if bps > u.config.MaxBitrate { + bps = u.config.MaxBitrate + } + + if bps < u.config.MinBitrate { + bps = u.config.MinBitrate + } + + return bps +} + +func (u *MultiUpdateEncoder) shouldPause(bps int64) bool { + return bps <= u.config.MinBitrate && u.config.CutVideoBelowMinBitrate +} + +func (u *MultiUpdateEncoder) checkPause(bps int64) error { + shouldPause := u.shouldPause(bps) + + if shouldPause { + fmt.Println("pausing video...") + return u.PauseEncoding() + } + return u.UnPauseEncoding() +} + +func (u *MultiUpdateEncoder) PauseEncoding() error { + u.paused.Store(true) + return nil +} + +func (u *MultiUpdateEncoder) UnPauseEncoding() error { + u.pauseMux.Lock() + defer u.pauseMux.Unlock() + + if u.paused.Swap(false) { + close(u.resume) + u.resume = make(chan struct{}) + } + return nil +} + +func (u *MultiUpdateEncoder) GetParameterSets() (sps []byte, pps []byte, err error) { + return u.active.Load().encoder.GetParameterSets() +} + +func (u *MultiUpdateEncoder) loop() { + defer u.close() + + for { + select { + case <-u.ctx.Done(): + return + default: + frame, err := u.getFrame() + if err != nil { + continue + } + + for _, encoder := range u.encoders { + if err := u.pushFrame(encoder, frame); err != nil { + continue + } + } + + // NOT PUT BACK AS THEY ARE BEING REF IN THE INDIVIDUAL BUFFERS + // u.producer.PutBack(frame) + } + } +} + +func (u *MultiUpdateEncoder) getFrame() (*astiav.Frame, error) { + ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond) + defer cancel() + + return u.producer.GetFrame(ctx) +} + +func (u *MultiUpdateEncoder) pushFrame(encoder *splitEncoder, frame *astiav.Frame) error { + if frame == nil { + return fmt.Errorf("frame is nil from the producer") + } + + ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond) + defer cancel() + + refFrame := encoder.producer.Generate() + if refFrame == nil { + return fmt.Errorf("failed to generate frame from encoder pool") + } + + // NOTE: THIS IS NEEDED AS Ref NEEDS A NEWLY ALLOCATED OR Unref FRAME + refFrame.Unref() + + if err := refFrame.Ref(frame); err != nil { + return fmt.Errorf("erorr while adding ref to frame; err: %s", err.Error()) + } + + // PUT IN BUFFER + return encoder.producer.pushFrame(ctx, refFrame) +} + +func (u *MultiUpdateEncoder) close() { + +} diff --git a/pkg/transcoder_options.go b/pkg/transcoder_options.go index 3674713..c3850e1 100644 --- a/pkg/transcoder_options.go +++ b/pkg/transcoder_options.go @@ -68,3 +68,16 @@ func WithBitrateControlEncoder(ctx context.Context, codecID astiav.CodecID, bitr return nil } } + +func WithMultiEncoderBitrateControl(ctx context.Context, codecID astiav.CodecID, config MultiConfig, settings codecSettings, bufferSize int) TranscoderOption { + return func(transcoder *Transcoder) error { + builder := NewEncoderBuilder(codecID, settings, bufferSize, transcoder.filter) + multiEncoder, err := NewMultiUpdateEncoder(ctx, config, builder) + if err != nil { + return err + } + + transcoder.encoder = multiEncoder + return nil + } +} diff --git a/pkg/update_encoder_wrapper.go b/pkg/update_encoder_wrapper.go index 268051f..14059c9 100644 --- a/pkg/update_encoder_wrapper.go +++ b/pkg/update_encoder_wrapper.go @@ -19,6 +19,14 @@ type UpdateConfig struct { CutVideoBelowMinBitrate bool } +func (c UpdateConfig) validate() error { + if c.MinBitrate > c.MaxBitrate { + return fmt.Errorf("minimum bitrate is higher than maximum bitrate in the update encoder config") + } + + return nil +} + type UpdateEncoder struct { encoder Encoder config UpdateConfig @@ -41,6 +49,10 @@ func NewUpdateEncoder(ctx context.Context, config UpdateConfig, builder *General ctx: ctx, } + if err := config.validate(); err != nil { + return nil, err + } + encoder, err := builder.Build(ctx) if err != nil { return nil, err @@ -105,7 +117,7 @@ func (u *UpdateEncoder) UpdateBitrate(bps int64) error { return err } - _, change := u.calculateBitrateChange(current, bps) + _, change := calculateBitrateChange(current, bps) if change < 5 { return nil } @@ -200,7 +212,7 @@ func (u *UpdateEncoder) GetParameterSets() (sps []byte, pps []byte, err error) { return p.GetParameterSets() } -func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absoluteChange int64, percentageChange float64) { +func calculateBitrateChange(currentBps, newBps int64) (absoluteChange int64, percentageChange float64) { absoluteChange = newBps - currentBps if absoluteChange < 0 { absoluteChange = -absoluteChange @@ -215,13 +227,12 @@ func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absolu func (u *UpdateEncoder) getPacket() (*astiav.Packet, error) { u.mux.RLock() - encoder := u.encoder // Get reference - u.mux.RUnlock() // Release lock immediately + defer u.mux.RUnlock() - if encoder != nil { + if u.encoder != nil { ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond) defer cancel() - return encoder.GetPacket(ctx) // Don't hold lock during blocking call + return u.encoder.GetPacket(ctx) // Don't hold lock during blocking call } return nil, errors.New("encoder is nil")