added multi_encoder.go

This commit is contained in:
harshabose
2025-05-30 09:34:36 +05:30
parent 6a277a1e78
commit 9c38b34fb4
6 changed files with 354 additions and 14 deletions
+1
View File
@@ -4,6 +4,7 @@ import (
"sync"
"github.com/asticode/go-astiav"
"github.com/harshabose/tools/buffer/pkg"
)
+6 -6
View File
@@ -16,7 +16,7 @@ import (
type GeneralEncoder struct {
buffer buffer.BufferWithGenerator[astiav.Packet]
filter CanProduceMediaFrame
producer CanProduceMediaFrame
codec *astiav.Codec
encoderContext *astiav.CodecContext
codecFlags *astiav.Dictionary
@@ -30,7 +30,7 @@ type GeneralEncoder struct {
func CreateGeneralEncoder(ctx context.Context, codecID astiav.CodecID, canProduceMediaFrame CanProduceMediaFrame, options ...EncoderOption) (*GeneralEncoder, error) {
ctx2, cancel := context.WithCancel(ctx)
encoder := &GeneralEncoder{
filter: canProduceMediaFrame,
producer: canProduceMediaFrame,
codecFlags: astiav.NewDictionary(),
ctx: ctx2,
cancel: cancel,
@@ -69,7 +69,7 @@ func CreateGeneralEncoder(ctx context.Context, codecID astiav.CodecID, canProduc
}
if encoder.buffer == nil {
encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool())
encoder.buffer = buffer.CreateChannelBuffer(ctx2, 256, internal.CreatePacketPool())
}
encoder.findParameterSets(encoder.encoderContext.ExtraData())
@@ -109,7 +109,7 @@ loop1:
continue
}
if err := encoder.encoderContext.SendFrame(frame); err != nil {
encoder.filter.PutBack(frame)
encoder.producer.PutBack(frame)
if !errors.Is(err, astiav.ErrEagain) {
continue loop1
}
@@ -127,7 +127,7 @@ loop1:
continue loop2
}
}
encoder.filter.PutBack(frame)
encoder.producer.PutBack(frame)
}
}
}
@@ -136,7 +136,7 @@ func (encoder *GeneralEncoder) getFrame() (*astiav.Frame, error) {
ctx, cancel := context.WithTimeout(encoder.ctx, 50*time.Millisecond)
defer cancel()
return encoder.filter.GetFrame(ctx)
return encoder.producer.GetFrame(ctx)
}
func (encoder *GeneralEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) {
+7 -2
View File
@@ -31,6 +31,11 @@ func (b *GeneralEncoderBuilder) UpdateBitrate(bps int64) error {
return s.UpdateBitrate(bps)
}
func (b *GeneralEncoderBuilder) BuildWithProducer(ctx context.Context, producer CanProduceMediaFrame) (Encoder, error) {
b.producer = producer
return b.Build(ctx)
}
func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) {
codec := astiav.FindEncoder(b.codecID)
if codec == nil {
@@ -39,7 +44,7 @@ func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) {
ctx2, cancel := context.WithCancel(ctx)
encoder := &GeneralEncoder{
filter: b.producer,
producer: b.producer,
codec: codec,
codecFlags: astiav.NewDictionary(),
ctx: ctx2,
@@ -51,7 +56,7 @@ func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) {
return nil, ErrorAllocateCodecContext
}
canDescribeMediaFrame, ok := b.producer.(CanDescribeMediaFrame)
canDescribeMediaFrame, ok := encoder.producer.(CanDescribeMediaFrame)
if !ok {
return nil, ErrorInterfaceMismatch
}
+310
View File
@@ -0,0 +1,310 @@
package transcode
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/asticode/go-astiav"
"github.com/harshabose/simple_webrtc_comm/transcode/internal"
"github.com/harshabose/tools/buffer/pkg"
)
type MultiConfig struct {
Steps uint8
UpdateConfig
}
func (c MultiConfig) validate() error {
if c.Steps == 0 {
return fmt.Errorf("steps need be more than 0")
}
return c.UpdateConfig.validate()
}
func (c MultiConfig) getBitrates() []int64 {
bitrates := make([]int64, c.Steps)
if c.Steps == 1 {
bitrates[0] = c.MaxBitrate
} else {
step := float64(c.MaxBitrate-c.MinBitrate) / float64(c.Steps-1)
for i := uint8(0); i < c.Steps; i++ {
bitrates[i] = c.MinBitrate + int64(float64(i)*step)
}
}
return bitrates
}
func NewMultiConfig(minBitrate, maxBitrate int64, steps uint8) MultiConfig {
c := MultiConfig{
UpdateConfig: UpdateConfig{
MaxBitrate: maxBitrate,
MinBitrate: minBitrate,
},
Steps: steps,
}
return c
}
type dummyMediaFrameProducer struct {
buffer buffer.BufferWithGenerator[astiav.Frame]
CanDescribeMediaFrame
}
func newDummyMediaFrameProducer(buffer buffer.BufferWithGenerator[astiav.Frame], describer CanDescribeMediaFrame) *dummyMediaFrameProducer {
return &dummyMediaFrameProducer{
buffer: buffer,
CanDescribeMediaFrame: describer,
}
}
func (p *dummyMediaFrameProducer) pushFrame(ctx context.Context, frame *astiav.Frame) error {
return p.buffer.Push(ctx, frame)
}
func (p *dummyMediaFrameProducer) GetFrame(ctx context.Context) (*astiav.Frame, error) {
return p.buffer.Pop(ctx)
}
func (p *dummyMediaFrameProducer) Generate() *astiav.Frame {
return p.buffer.Generate()
}
func (p *dummyMediaFrameProducer) PutBack(frame *astiav.Frame) {
p.buffer.PutBack(frame)
}
type splitEncoder struct {
encoder *GeneralEncoder
producer *dummyMediaFrameProducer
}
func newSplitEncoder(encoder *GeneralEncoder, producer *dummyMediaFrameProducer) *splitEncoder {
return &splitEncoder{
encoder: encoder,
producer: producer,
}
}
type MultiUpdateEncoder struct {
encoders []*splitEncoder
active atomic.Pointer[splitEncoder]
config MultiConfig
bitrates []int64
producer CanProduceMediaFrame
ctx context.Context
cancel context.CancelFunc
paused atomic.Bool
resume chan struct{}
pauseMux sync.Mutex
}
func NewMultiUpdateEncoder(ctx context.Context, config MultiConfig, builder *GeneralEncoderBuilder) (*MultiUpdateEncoder, error) {
if err := config.validate(); err != nil {
return nil, err
}
ctx2, cancel := context.WithCancel(ctx)
encoder := &MultiUpdateEncoder{
encoders: make([]*splitEncoder, 0),
config: config,
bitrates: config.getBitrates(),
producer: builder.producer,
ctx: ctx2,
cancel: cancel,
resume: make(chan struct{}),
}
describer, ok := encoder.producer.(CanDescribeMediaFrame)
if !ok {
return nil, ErrorInterfaceMismatch
}
for _, bitrate := range encoder.bitrates {
producer := newDummyMediaFrameProducer(buffer.CreateChannelBuffer(ctx2, 90, internal.CreateFramePool()), describer)
if err := builder.UpdateBitrate(bitrate); err != nil {
return nil, err
}
e, err := builder.BuildWithProducer(ctx2, producer)
if err != nil {
return nil, err
}
encoder.encoders = append(encoder.encoders, newSplitEncoder(e.(*GeneralEncoder), producer))
}
encoder.switchEncoder(0)
return encoder, nil
}
func (u *MultiUpdateEncoder) Ctx() context.Context {
return u.ctx
}
func (u *MultiUpdateEncoder) Start() {
for _, encoder := range u.encoders {
encoder.encoder.Start()
}
go u.loop()
}
func (u *MultiUpdateEncoder) GetPacket(ctx context.Context) (*astiav.Packet, error) {
return u.active.Load().encoder.GetPacket(ctx)
}
func (u *MultiUpdateEncoder) PutBack(packet *astiav.Packet) {
u.active.Load().encoder.PutBack(packet)
}
func (u *MultiUpdateEncoder) Stop() {
u.cancel()
}
func (u *MultiUpdateEncoder) UpdateBitrate(bps int64) error {
if err := u.checkPause(bps); err != nil {
return err
}
bps = u.cutoff(bps)
bestIndex := u.findBestEncoderIndex(bps)
u.switchEncoder(bestIndex)
return nil
}
func (u *MultiUpdateEncoder) findBestEncoderIndex(targetBps int64) int {
bestIndex := 0
for i, bitrate := range u.bitrates {
if bitrate <= targetBps {
bestIndex = i
} else {
break
}
}
return bestIndex
}
func (u *MultiUpdateEncoder) switchEncoder(index int) {
if index < len(u.encoders) {
u.active.Swap(u.encoders[index])
}
}
func (u *MultiUpdateEncoder) cutoff(bps int64) int64 {
if bps > u.config.MaxBitrate {
bps = u.config.MaxBitrate
}
if bps < u.config.MinBitrate {
bps = u.config.MinBitrate
}
return bps
}
func (u *MultiUpdateEncoder) shouldPause(bps int64) bool {
return bps <= u.config.MinBitrate && u.config.CutVideoBelowMinBitrate
}
func (u *MultiUpdateEncoder) checkPause(bps int64) error {
shouldPause := u.shouldPause(bps)
if shouldPause {
fmt.Println("pausing video...")
return u.PauseEncoding()
}
return u.UnPauseEncoding()
}
func (u *MultiUpdateEncoder) PauseEncoding() error {
u.paused.Store(true)
return nil
}
func (u *MultiUpdateEncoder) UnPauseEncoding() error {
u.pauseMux.Lock()
defer u.pauseMux.Unlock()
if u.paused.Swap(false) {
close(u.resume)
u.resume = make(chan struct{})
}
return nil
}
func (u *MultiUpdateEncoder) GetParameterSets() (sps []byte, pps []byte, err error) {
return u.active.Load().encoder.GetParameterSets()
}
func (u *MultiUpdateEncoder) loop() {
defer u.close()
for {
select {
case <-u.ctx.Done():
return
default:
frame, err := u.getFrame()
if err != nil {
continue
}
for _, encoder := range u.encoders {
if err := u.pushFrame(encoder, frame); err != nil {
continue
}
}
// NOT PUT BACK AS THEY ARE BEING REF IN THE INDIVIDUAL BUFFERS
// u.producer.PutBack(frame)
}
}
}
func (u *MultiUpdateEncoder) getFrame() (*astiav.Frame, error) {
ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond)
defer cancel()
return u.producer.GetFrame(ctx)
}
func (u *MultiUpdateEncoder) pushFrame(encoder *splitEncoder, frame *astiav.Frame) error {
if frame == nil {
return fmt.Errorf("frame is nil from the producer")
}
ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond)
defer cancel()
refFrame := encoder.producer.Generate()
if refFrame == nil {
return fmt.Errorf("failed to generate frame from encoder pool")
}
// NOTE: THIS IS NEEDED AS Ref NEEDS A NEWLY ALLOCATED OR Unref FRAME
refFrame.Unref()
if err := refFrame.Ref(frame); err != nil {
return fmt.Errorf("erorr while adding ref to frame; err: %s", err.Error())
}
// PUT IN BUFFER
return encoder.producer.pushFrame(ctx, refFrame)
}
func (u *MultiUpdateEncoder) close() {
}
+13
View File
@@ -68,3 +68,16 @@ func WithBitrateControlEncoder(ctx context.Context, codecID astiav.CodecID, bitr
return nil
}
}
func WithMultiEncoderBitrateControl(ctx context.Context, codecID astiav.CodecID, config MultiConfig, settings codecSettings, bufferSize int) TranscoderOption {
return func(transcoder *Transcoder) error {
builder := NewEncoderBuilder(codecID, settings, bufferSize, transcoder.filter)
multiEncoder, err := NewMultiUpdateEncoder(ctx, config, builder)
if err != nil {
return err
}
transcoder.encoder = multiEncoder
return nil
}
}
+17 -6
View File
@@ -19,6 +19,14 @@ type UpdateConfig struct {
CutVideoBelowMinBitrate bool
}
func (c UpdateConfig) validate() error {
if c.MinBitrate > c.MaxBitrate {
return fmt.Errorf("minimum bitrate is higher than maximum bitrate in the update encoder config")
}
return nil
}
type UpdateEncoder struct {
encoder Encoder
config UpdateConfig
@@ -41,6 +49,10 @@ func NewUpdateEncoder(ctx context.Context, config UpdateConfig, builder *General
ctx: ctx,
}
if err := config.validate(); err != nil {
return nil, err
}
encoder, err := builder.Build(ctx)
if err != nil {
return nil, err
@@ -105,7 +117,7 @@ func (u *UpdateEncoder) UpdateBitrate(bps int64) error {
return err
}
_, change := u.calculateBitrateChange(current, bps)
_, change := calculateBitrateChange(current, bps)
if change < 5 {
return nil
}
@@ -200,7 +212,7 @@ func (u *UpdateEncoder) GetParameterSets() (sps []byte, pps []byte, err error) {
return p.GetParameterSets()
}
func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absoluteChange int64, percentageChange float64) {
func calculateBitrateChange(currentBps, newBps int64) (absoluteChange int64, percentageChange float64) {
absoluteChange = newBps - currentBps
if absoluteChange < 0 {
absoluteChange = -absoluteChange
@@ -215,13 +227,12 @@ func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absolu
func (u *UpdateEncoder) getPacket() (*astiav.Packet, error) {
u.mux.RLock()
encoder := u.encoder // Get reference
u.mux.RUnlock() // Release lock immediately
defer u.mux.RUnlock()
if encoder != nil {
if u.encoder != nil {
ctx, cancel := context.WithTimeout(u.ctx, 50*time.Millisecond)
defer cancel()
return encoder.GetPacket(ctx) // Don't hold lock during blocking call
return u.encoder.GetPacket(ctx) // Don't hold lock during blocking call
}
return nil, errors.New("encoder is nil")