diff --git a/internal/packet_pool.go b/internal/packet_pool.go index bc47053..b9f6640 100644 --- a/internal/packet_pool.go +++ b/internal/packet_pool.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/asticode/go-astiav" + "github.com/harshabose/tools/buffer/pkg" ) @@ -48,7 +49,7 @@ func (pool *packetPool) Release() { if !ok { continue } - + // fmt.Printf("๐Ÿ—‘๏ธ Releasing packet: ptr=%p\n", packet) packet.Free() } } diff --git a/pkg/VP8Encoder.go b/pkg/VP8Encoder.go deleted file mode 100644 index e460f22..0000000 --- a/pkg/VP8Encoder.go +++ /dev/null @@ -1,352 +0,0 @@ -package transcode - -// -// import ( -// "context" -// "errors" -// "fmt" -// "math" -// "sync" -// "time" -// -// "github.com/asticode/go-astiav" -// -// "github.com/harshabose/simple_webrtc_comm/transcode/internal" -// "github.com/harshabose/tools/buffer/pkg" -// ) -// -// type VP8Encoder struct { -// buffer buffer.BufferWithGenerator[astiav.Packet] -// filter *Filter -// codec *astiav.Codec -// codecFlags *astiav.Dictionary -// copyCodecFlags *astiav.Dictionary -// codecSettings codecSettings -// bandwidthChan chan int64 -// options []EncoderOption -// -// encoderContext *astiav.CodecContext -// fallbackEncoderContext *astiav.CodecContext -// -// ctx context.Context -// mux sync.Mutex -// } -// -// func NewVP8Encoder(ctx context.Context, filter *Filter, options ...EncoderOption) (*VP8Encoder, error) { -// encoder := &VP8Encoder{ -// filter: filter, -// codecFlags: astiav.NewDictionary(), -// ctx: ctx, -// } -// -// if encoder.codec = astiav.FindEncoder(astiav.CodecIDVp8); encoder.codec == nil { -// return nil, errors.New("VP8 encoder not found") -// } -// -// encoderContext, err := createNewVP8Encoder(encoder.codec, filter) -// if err != nil { -// return nil, err -// } -// encoder.encoderContext = encoderContext -// -// for _, option := range options { -// if err := option(encoder); err != nil { -// return nil, err -// } -// } -// -// if encoder.codecSettings == nil { -// fmt.Println("warn: no VP8 encoder settings were provided") -// } -// -// copyDict, err := copyDictionary(encoder.codecFlags) -// if err != nil { -// return nil, err -// } -// encoder.copyCodecFlags = copyDict -// -// if err := openVP8Encoder(encoder.encoderContext, encoder.codec, encoder.codecFlags); err != nil { -// return nil, err -// } -// -// if encoder.buffer == nil { -// encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool()) -// } -// -// return encoder, nil -// } -// -// func (e *VP8Encoder) Start() { -// go e.loop() -// } -// -// func (e *VP8Encoder) GetPacket() (*astiav.Packet, error) { -// ctx, cancel := context.WithTimeout(e.ctx, time.Second) -// defer cancel() -// return e.buffer.Pop(ctx) -// } -// -// func (e *VP8Encoder) WaitForPacket() chan *astiav.Packet { -// return e.buffer.GetChannel() -// } -// -// func (e *VP8Encoder) PutBack(packet *astiav.Packet) { -// e.buffer.PutBack(packet) -// } -// -// func (e *VP8Encoder) GetTimeBase() astiav.Rational { -// e.mux.Lock() -// defer e.mux.Unlock() -// -// if e.encoderContext != nil { -// return e.encoderContext.TimeBase() -// } -// if e.fallbackEncoderContext != nil { -// return e.fallbackEncoderContext.TimeBase() -// } -// return astiav.Rational{} -// } -// -// func (e *VP8Encoder) GetDuration() time.Duration { -// e.mux.Lock() -// defer e.mux.Unlock() -// -// if e.encoderContext != nil { -// return time.Duration(float64(time.Second) / e.encoderContext.Framerate().Float64()) -// } -// if e.fallbackEncoderContext != nil { -// return time.Duration(float64(time.Second) / e.fallbackEncoderContext.Framerate().Float64()) -// } -// return time.Second / 30 -// } -// -// func (e *VP8Encoder) SetBitrateChannel(channel chan int64) { -// e.mux.Lock() -// defer e.mux.Unlock() -// e.bandwidthChan = channel -// } -// -// // Get current VP8 bitrate from encoder context -// func (e *VP8Encoder) getCurrentBitrate() (int64, error) { -// e.mux.Lock() -// defer e.mux.Unlock() -// -// if e.encoderContext != nil { -// return e.encoderContext.BitRate() / 1000, nil // Convert to kbps -// } -// return 0, errors.New("no encoder context available") -// } -// -// // Update VP8 bitrate (simpler than x264) -// func (e *VP8Encoder) updateBitrate(bitrate int64) error { -// start := time.Now() -// -// e.mux.Lock() -// current, err := e.getCurrentBitrate() -// if err != nil { -// e.mux.Unlock() -// fmt.Printf("error getting current bitrate; err: %s\n", err.Error()) -// return err -// } -// -// // Same change logic as your x264 version -// change := math.Abs(float64(current)-float64(bitrate)) / math.Abs(float64(current)) -// -// if change < 0.1 || change > 2.0 { -// e.mux.Unlock() -// fmt.Printf("change not appropriate; current: %d; new: %d; change:%f\n", current, bitrate, change) -// return nil -// } -// -// fmt.Printf("VP8 bitrate change approved; change: %f\n", change) -// -// // Set VP8 bitrate parameters -// if err := e.updateVP8Options(bitrate); err != nil { -// e.mux.Unlock() -// fmt.Printf("error while updating VP8 options; err: %s\n", err.Error()) -// return err -// } -// -// e.mux.Unlock() -// if err := e.createNewEncoderContext(); err != nil { -// return err -// } -// -// duration := time.Since(start) -// fmt.Printf("๐Ÿ”„ VP8 Bitrate updated: %d โ†’ %d (%.1f%%) in %v\n", -// current, bitrate, change*100, duration) -// -// return nil -// } -// -// // Update VP8-specific options -// func (e *VP8Encoder) updateVP8Options(bitrate int64) error { -// // VP8 uses simpler parameter names -// paramsToUpdate := map[string]string{ -// "deadline": "1", // Real-time encoding -// "b:v": fmt.Sprintf("%dk", bitrate), // Target bitrate -// "minrate": fmt.Sprintf("%dk", bitrate*80/100), // Min bitrate (80% of target) -// "maxrate": fmt.Sprintf("%dk", bitrate*120/100), // Max bitrate (120% of target) -// "bufsize": fmt.Sprintf("%dk", bitrate/5), // Buffer size -// "crf": "10", // Good quality balance -// "cpu-used": "8", // Fastest preset for real-time -// } -// -// for param, value := range paramsToUpdate { -// if err := e.copyCodecFlags.Set(param, value, 0); err != nil { -// return err -// } -// } -// -// return nil -// } -// -// // Rest of your encoder methods... -// func (e *VP8Encoder) createNewEncoderContext() error { -// e.mux.Lock() -// e.fallbackEncoderContext = e.encoderContext -// e.encoderContext = nil -// -// copyDict, err := copyDictionary(e.copyCodecFlags) -// if err != nil { -// e.mux.Unlock() -// return err -// } -// -// e.codecFlags.Free() -// e.codecFlags = copyDict -// e.mux.Unlock() -// -// encoderContext, err := createNewOpenVP8Encoder(e.codec, e.filter, e.codecFlags) -// if err != nil { -// e.mux.Lock() -// e.encoderContext = e.fallbackEncoderContext -// e.fallbackEncoderContext = nil -// e.mux.Unlock() -// fmt.Printf("New VP8 encoder creation failed, reverted: %v\n", err) -// return err -// } -// -// e.mux.Lock() -// oldFallback := e.fallbackEncoderContext -// e.encoderContext = encoderContext -// e.fallbackEncoderContext = nil -// e.mux.Unlock() -// -// if oldFallback != nil { -// oldFallback.Free() -// fmt.Printf("๐Ÿงน Cleaned up fallback VP8 encoder context\n") -// } -// -// return nil -// } -// -// func (e *VP8Encoder) pickContextAndProcess(frame *astiav.Frame) error { -// e.mux.Lock() -// defer e.mux.Unlock() -// -// if e.encoderContext != nil { -// return e.sendFrameAndPutPackets(e.encoderContext, frame) -// } -// if e.fallbackEncoderContext != nil { -// return e.sendFrameAndPutPackets(e.fallbackEncoderContext, frame) -// } -// return errors.New("invalid VP8 encoder context state") -// } -// -// func (e *VP8Encoder) sendFrameAndPutPackets(encoderContext *astiav.CodecContext, frame *astiav.Frame) error { -// defer e.filter.PutBack(frame) -// -// if err := encoderContext.SendFrame(frame); err != nil { -// return err -// } -// -// for { -// packet := e.buffer.Generate() -// if err := encoderContext.ReceivePacket(packet); err != nil { -// e.buffer.PutBack(packet) -// break -// } -// if err := e.pushPacket(packet); err != nil { -// e.buffer.PutBack(packet) -// continue -// } -// } -// return nil -// } -// -// func (e *VP8Encoder) pushPacket(packet *astiav.Packet) error { -// ctx, cancel := context.WithTimeout(e.ctx, time.Second) -// defer cancel() -// return e.buffer.Push(ctx, packet) -// } -// -// func (e *VP8Encoder) loop() { -// e.encoderContext.SetBitRate(2_000_000) -// fmt.Println("VP8 loop started") -// defer e.Close() -// -// for { -// select { -// case <-e.ctx.Done(): -// return -// case bitrate := <-e.bandwidthChan: -// fmt.Println("bitrate recommended:", bitrate) -// // if err := e.updateBitrate(bitrate); err != nil { -// // fmt.Printf("error while updating VP8 bitrate; err: %s\n", err.Error()) -// // } -// case frame := <-e.filter.WaitForFrame(): -// if err := e.pickContextAndProcess(frame); err != nil { -// if !errors.Is(err, astiav.ErrEagain) { -// continue -// } -// } -// } -// } -// } -// -// func (e *VP8Encoder) Close() { -// e.mux.Lock() -// defer e.mux.Unlock() -// -// if e.encoderContext != nil { -// e.encoderContext.Free() -// e.encoderContext = nil -// } -// if e.fallbackEncoderContext != nil { -// e.fallbackEncoderContext.Free() -// e.fallbackEncoderContext = nil -// } -// } -// -// // Helper functions for VP8 encoder creation -// func createNewVP8Encoder(codec *astiav.Codec, filter *Filter) (*astiav.CodecContext, error) { -// encoderContext := astiav.AllocCodecContext(codec) -// if encoderContext == nil { -// return nil, ErrorAllocateCodecContext -// } -// -// // Set VP8-specific context parameters -// withVideoSetEncoderContextParameter(filter, encoderContext) -// -// return encoderContext, nil -// } -// -// func createNewOpenVP8Encoder(codec *astiav.Codec, filter *Filter, settings *astiav.Dictionary) (*astiav.CodecContext, error) { -// encoderContext, err := createNewVP8Encoder(codec, filter) -// if err != nil { -// return nil, err -// } -// -// if err := openVP8Encoder(encoderContext, codec, settings); err != nil { -// encoderContext.Free() -// return nil, err -// } -// -// return encoderContext, nil -// } -// -// func openVP8Encoder(encoderContext *astiav.CodecContext, codec *astiav.Codec, settings *astiav.Dictionary) error { -// encoderContext.SetFlags(astiav.NewCodecContextFlags(astiav.CodecContextFlagGlobalHeader)) -// return encoderContext.Open(codec, settings) -// } diff --git a/pkg/decoder.go b/pkg/decoder.go index 527bf56..106b110 100644 --- a/pkg/decoder.go +++ b/pkg/decoder.go @@ -3,7 +3,6 @@ package transcode import ( "context" "errors" - "fmt" "time" "github.com/asticode/go-astiav" @@ -13,31 +12,39 @@ import ( "github.com/harshabose/simple_webrtc_comm/transcode/internal" ) -type Decoder struct { - demuxer *Demuxer +type GeneralDecoder struct { + demuxer CanProduceMediaPacket decoderContext *astiav.CodecContext codec *astiav.Codec buffer buffer.BufferWithGenerator[astiav.Frame] ctx context.Context + cancel context.CancelFunc } -func CreateDecoder(ctx context.Context, demuxer *Demuxer, options ...DecoderOption) (*Decoder, error) { +func CreateGeneralDecoder(ctx context.Context, canProduceMediaType CanProduceMediaPacket, options ...DecoderOption) (*GeneralDecoder, error) { var ( err error contextOption DecoderOption - decoder *Decoder + decoder *GeneralDecoder ) - decoder = &Decoder{ - demuxer: demuxer, - ctx: ctx, + ctx2, cancel := context.WithCancel(ctx) + decoder = &GeneralDecoder{ + demuxer: canProduceMediaType, + ctx: ctx2, + cancel: cancel, } - if demuxer.stream.CodecParameters().MediaType() == astiav.MediaTypeVideo { - contextOption = withVideoSetDecoderContext(demuxer) + canDescribeMediaPacket, ok := canProduceMediaType.(CanDescribeMediaPacket) + if !ok { + return nil, ErrorInterfaceMismatch } - if demuxer.stream.CodecParameters().MediaType() == astiav.MediaTypeAudio { - contextOption = withAudioSetDecoderContext(demuxer) + + if canDescribeMediaPacket.MediaType() == astiav.MediaTypeVideo { + contextOption = withVideoSetDecoderContext(canDescribeMediaPacket) + } + if canDescribeMediaPacket.MediaType() == astiav.MediaTypeAudio { + contextOption = withAudioSetDecoderContext(canDescribeMediaPacket) } options = append([]DecoderOption{contextOption}, options...) @@ -59,11 +66,19 @@ func CreateDecoder(ctx context.Context, demuxer *Demuxer, options ...DecoderOpti return decoder, nil } -func (decoder *Decoder) Start() { +func (decoder *GeneralDecoder) Ctx() context.Context { + return decoder.ctx +} + +func (decoder *GeneralDecoder) Start() { go decoder.loop() } -func (decoder *Decoder) loop() { +func (decoder *GeneralDecoder) Stop() { + decoder.cancel() +} + +func (decoder *GeneralDecoder) loop() { var ( packet *astiav.Packet frame *astiav.Frame @@ -95,7 +110,6 @@ loop1: frame.SetPictureType(astiav.PictureTypeNone) if err = decoder.pushFrame(frame); err != nil { - fmt.Println("warning: frame dropped!") decoder.buffer.PutBack(frame) continue loop2 } @@ -105,30 +119,105 @@ loop1: } } -func (decoder *Decoder) pushFrame(frame *astiav.Frame) error { +func (decoder *GeneralDecoder) pushFrame(frame *astiav.Frame) error { ctx, cancel := context.WithTimeout(decoder.ctx, time.Second) defer cancel() return decoder.buffer.Push(ctx, frame) } -func (decoder *Decoder) GetFrame() (*astiav.Frame, error) { - ctx, cancel := context.WithTimeout(decoder.ctx, time.Second) - defer cancel() - - return decoder.buffer.Pop(ctx) -} - -func (decoder *Decoder) WaitForFrame() chan *astiav.Frame { +func (decoder *GeneralDecoder) WaitForFrame() chan *astiav.Frame { return decoder.buffer.GetChannel() } -func (decoder *Decoder) PutBack(frame *astiav.Frame) { +func (decoder *GeneralDecoder) PutBack(frame *astiav.Frame) { decoder.buffer.PutBack(frame) } -func (decoder *Decoder) close() { +func (decoder *GeneralDecoder) close() { if decoder.decoderContext != nil { decoder.decoderContext.Free() } } + +func (decoder *GeneralDecoder) SetBuffer(buffer buffer.BufferWithGenerator[astiav.Frame]) { + decoder.buffer = buffer +} + +func (decoder *GeneralDecoder) SetCodec(producer CanDescribeMediaPacket) error { + if decoder.codec = astiav.FindDecoder(producer.CodecID()); decoder.codec == nil { + return ErrorNoCodecFound + } + decoder.decoderContext = astiav.AllocCodecContext(decoder.codec) + if decoder.decoderContext == nil { + return ErrorAllocateCodecContext + } + + return nil +} + +func (decoder *GeneralDecoder) FillContextContent(producer CanDescribeMediaPacket) error { + return producer.GetCodecParameters().ToCodecContext(decoder.decoderContext) +} + +func (decoder *GeneralDecoder) SetFrameRate(producer CanDescribeFrameRate) { + decoder.decoderContext.SetFramerate(producer.FrameRate()) +} + +func (decoder *GeneralDecoder) SetTimeBase(producer CanDescribeTimeBase) { + decoder.decoderContext.SetTimeBase(producer.TimeBase()) +} + +// ### IMPLEMENTS CanDescribeMediaVideoFrame + +func (decoder *GeneralDecoder) FrameRate() astiav.Rational { + return decoder.decoderContext.Framerate() +} + +func (decoder *GeneralDecoder) TimeBase() astiav.Rational { + return decoder.decoderContext.TimeBase() +} + +func (decoder *GeneralDecoder) Height() int { + return decoder.decoderContext.Height() +} + +func (decoder *GeneralDecoder) Width() int { + return decoder.decoderContext.Width() +} + +func (decoder *GeneralDecoder) PixelFormat() astiav.PixelFormat { + return decoder.decoderContext.PixelFormat() +} + +func (decoder *GeneralDecoder) SampleAspectRatio() astiav.Rational { + return decoder.decoderContext.SampleAspectRatio() +} + +func (decoder *GeneralDecoder) ColorSpace() astiav.ColorSpace { + return decoder.decoderContext.ColorSpace() +} + +func (decoder *GeneralDecoder) ColorRange() astiav.ColorRange { + return decoder.decoderContext.ColorRange() +} + +// ## CanDescribeMediaAudioFrame + +func (decoder *GeneralDecoder) SampleRate() int { + return decoder.decoderContext.SampleRate() +} + +func (decoder *GeneralDecoder) SampleFormat() astiav.SampleFormat { + return decoder.decoderContext.SampleFormat() +} + +func (decoder *GeneralDecoder) ChannelLayout() astiav.ChannelLayout { + return decoder.decoderContext.ChannelLayout() +} + +// ## CanDescribeMediaFrame + +func (decoder *GeneralDecoder) MediaType() astiav.MediaType { + return decoder.decoderContext.MediaType() +} diff --git a/pkg/decoder_options.go b/pkg/decoder_options.go index 6a99d46..e1d1a6e 100644 --- a/pkg/decoder_options.go +++ b/pkg/decoder_options.go @@ -2,63 +2,62 @@ package transcode import ( "github.com/asticode/go-astiav" - buffer "github.com/harshabose/tools/buffer/pkg" + + "github.com/harshabose/tools/buffer/pkg" "github.com/harshabose/simple_webrtc_comm/transcode/internal" ) -type DecoderOption = func(*Decoder) error +type DecoderOption = func(decoder Decoder) error -func withVideoSetDecoderContext(demuxer *Demuxer) func(*Decoder) error { - return func(decoder *Decoder) error { - var ( - err error - ) - - if decoder.codec = astiav.FindDecoder(demuxer.codecParameters.CodecID()); decoder.codec == nil { - return ErrorNoCodecFound +func withVideoSetDecoderContext(demuxer CanDescribeMediaPacket) DecoderOption { + return func(decoder Decoder) error { + consumer, ok := decoder.(CanSetMediaPacket) + if !ok { + return ErrorInterfaceMismatch } - if decoder.decoderContext = astiav.AllocCodecContext(decoder.codec); decoder.decoderContext == nil { - return ErrorAllocateCodecContext + if err := consumer.SetCodec(demuxer); err != nil { + return err } - if err = demuxer.stream.CodecParameters().ToCodecContext(decoder.decoderContext); err != nil { - return ErrorFillCodecContext + if err := consumer.FillContextContent(demuxer); err != nil { + return err } - decoder.decoderContext.SetFramerate(demuxer.formatContext.GuessFrameRate(demuxer.stream, nil)) - decoder.decoderContext.SetTimeBase(demuxer.stream.TimeBase()) + consumer.SetFrameRate(demuxer) + consumer.SetTimeBase(demuxer) return nil } } -func withAudioSetDecoderContext(demuxer *Demuxer) func(*Decoder) error { - return func(decoder *Decoder) error { - var ( - err error - ) - - if decoder.codec = astiav.FindDecoder(demuxer.codecParameters.CodecID()); decoder.codec == nil { - return ErrorNoCodecFound +func withAudioSetDecoderContext(demuxer CanDescribeMediaPacket) DecoderOption { + return func(decoder Decoder) error { + consumer, ok := decoder.(CanSetMediaPacket) + if !ok { + return ErrorInterfaceMismatch } - if decoder.decoderContext = astiav.AllocCodecContext(decoder.codec); decoder.decoderContext == nil { - return ErrorAllocateCodecContext + if err := consumer.SetCodec(demuxer); err != nil { + return err } - if err = demuxer.stream.CodecParameters().ToCodecContext(decoder.decoderContext); err != nil { - return ErrorFillCodecContext + if err := consumer.FillContextContent(demuxer); err != nil { + return err } - decoder.decoderContext.SetTimeBase(demuxer.stream.TimeBase()) + consumer.SetTimeBase(demuxer) return nil } } func WithDecoderBufferSize(size int) DecoderOption { - return func(decoder *Decoder) error { - decoder.buffer = buffer.CreateChannelBuffer(decoder.ctx, size, internal.CreateFramePool()) + return func(decoder Decoder) error { + s, ok := decoder.(CanSetBuffer[astiav.Frame]) + if !ok { + return ErrorInterfaceMismatch + } + s.SetBuffer(buffer.CreateChannelBuffer(decoder.Ctx(), size, internal.CreateFramePool())) return nil } } diff --git a/pkg/demuxer.go b/pkg/demuxer.go index 22e2b5a..5ad2485 100644 --- a/pkg/demuxer.go +++ b/pkg/demuxer.go @@ -2,17 +2,15 @@ package transcode import ( "context" - "fmt" "time" "github.com/asticode/go-astiav" - "github.com/harshabose/tools/buffer/pkg" - "github.com/harshabose/simple_webrtc_comm/transcode/internal" + "github.com/harshabose/tools/buffer/pkg" ) -type Demuxer struct { +type GeneralDemuxer struct { formatContext *astiav.FormatContext inputOptions *astiav.Dictionary inputFormat *astiav.InputFormat @@ -20,20 +18,27 @@ type Demuxer struct { codecParameters *astiav.CodecParameters buffer buffer.BufferWithGenerator[astiav.Packet] ctx context.Context + cancel context.CancelFunc } -func CreateDemuxer(ctx context.Context, containerAddress string, options ...DemuxerOption) (*Demuxer, error) { +func CreateGeneralDemuxer(ctx context.Context, containerAddress string, options ...DemuxerOption) (*GeneralDemuxer, error) { + ctx2, cancel := context.WithCancel(ctx) astiav.RegisterAllDevices() - demuxer := &Demuxer{ + demuxer := &GeneralDemuxer{ formatContext: astiav.AllocFormatContext(), inputOptions: astiav.NewDictionary(), - ctx: ctx, + ctx: ctx2, + cancel: cancel, } if demuxer.formatContext == nil { return nil, ErrorAllocateFormatContext } + if demuxer.inputOptions == nil { + return nil, ErrorGeneralAllocate + } + for _, option := range options { if err := option(demuxer); err != nil { return nil, err @@ -65,22 +70,27 @@ func CreateDemuxer(ctx context.Context, containerAddress string, options ...Demu return demuxer, nil } -func (demuxer *Demuxer) Start() { +func (demuxer *GeneralDemuxer) Ctx() context.Context { + return demuxer.ctx +} + +func (demuxer *GeneralDemuxer) Start() { go demuxer.loop() } -func (demuxer *Demuxer) loop() { - defer demuxer.close() +func (demuxer *GeneralDemuxer) Stop() { + demuxer.cancel() +} - ticker := time.NewTicker(time.Millisecond) - defer ticker.Stop() +func (demuxer *GeneralDemuxer) loop() { + defer demuxer.close() loop1: for { select { case <-demuxer.ctx.Done(): return - case <-ticker.C: + default: loop2: for { packet := demuxer.buffer.Generate() @@ -105,32 +115,63 @@ loop1: } } -func (demuxer *Demuxer) pushPacket(packet *astiav.Packet) error { +func (demuxer *GeneralDemuxer) pushPacket(packet *astiav.Packet) error { ctx, cancel := context.WithTimeout(demuxer.ctx, time.Second) // TODO: NEEDS TO BE BASED ON FPS ON INPUT_FORMAT defer cancel() return demuxer.buffer.Push(ctx, packet) } -func (demuxer *Demuxer) WaitForPacket() chan *astiav.Packet { +func (demuxer *GeneralDemuxer) WaitForPacket() chan *astiav.Packet { return demuxer.buffer.GetChannel() } -func (demuxer *Demuxer) GetPacket() (*astiav.Packet, error) { +func (demuxer *GeneralDemuxer) GetPacket() (*astiav.Packet, error) { ctx, cancel := context.WithTimeout(demuxer.ctx, time.Second) defer cancel() return demuxer.buffer.Pop(ctx) } -func (demuxer *Demuxer) PutBack(packet *astiav.Packet) { +func (demuxer *GeneralDemuxer) PutBack(packet *astiav.Packet) { demuxer.buffer.PutBack(packet) } -func (demuxer *Demuxer) close() { +func (demuxer *GeneralDemuxer) close() { if demuxer.formatContext != nil { demuxer.formatContext.CloseInput() - fmt.Println("closed container") demuxer.formatContext.Free() } } + +func (demuxer *GeneralDemuxer) SetInputOption(key, value string, flags astiav.DictionaryFlags) error { + return demuxer.inputOptions.Set(key, value, flags) +} + +func (demuxer *GeneralDemuxer) SetInputFormat(format *astiav.InputFormat) { + demuxer.inputFormat = format +} + +func (demuxer *GeneralDemuxer) SetBuffer(buffer buffer.BufferWithGenerator[astiav.Packet]) { + demuxer.buffer = buffer +} + +func (demuxer *GeneralDemuxer) GetCodecParameters() *astiav.CodecParameters { + return demuxer.codecParameters +} + +func (demuxer *GeneralDemuxer) MediaType() astiav.MediaType { + return demuxer.codecParameters.MediaType() +} + +func (demuxer *GeneralDemuxer) CodecID() astiav.CodecID { + return demuxer.codecParameters.CodecID() +} + +func (demuxer *GeneralDemuxer) FrameRate() astiav.Rational { + return demuxer.formatContext.GuessFrameRate(demuxer.stream, nil) +} + +func (demuxer *GeneralDemuxer) TimeBase() astiav.Rational { + return demuxer.stream.TimeBase() +} diff --git a/pkg/demuxer_options.go b/pkg/demuxer_options.go index 8ec268d..c9206cd 100644 --- a/pkg/demuxer_options.go +++ b/pkg/demuxer_options.go @@ -8,59 +8,78 @@ import ( "github.com/harshabose/simple_webrtc_comm/transcode/internal" ) -type DemuxerOption = func(*Demuxer) error +type DemuxerOption = func(demuxer Demuxer) error -func WithRTSPInputOption(demuxer *Demuxer) error { - var err error = nil - - if err = demuxer.inputOptions.Set("rtsp_transport", "tcp", 0); err != nil { +func WithRTSPInputOption(demuxer Demuxer) error { + s, ok := demuxer.(CanSetDemuxerInputOption) + if !ok { + return ErrorInterfaceMismatch + } + if err := s.SetInputOption("rtsp_transport", "tcp", 0); err != nil { return err } - if err = demuxer.inputOptions.Set("stimeout", "5000000", 0); err != nil { + if err := s.SetInputOption("stimeout", "5000000", 0); err != nil { return err } - if err = demuxer.inputOptions.Set("fflags", "nobuffer", 0); err != nil { + if err := s.SetInputOption("fflags", "nobuffer", 0); err != nil { return err } - if err = demuxer.inputOptions.Set("flags", "low_delay", 0); err != nil { + if err := s.SetInputOption("flags", "low_delay", 0); err != nil { return err } - if err = demuxer.inputOptions.Set("reorder_queue_size", "0", 0); err != nil { + if err := s.SetInputOption("reorder_queue_size", "0", 0); err != nil { return err } return nil } -func WithFileInputOption(demuxer *Demuxer) error { - if err := demuxer.inputOptions.Set("re", "", 0); err != nil { +func WithFileInputOption(demuxer Demuxer) error { + s, ok := demuxer.(CanSetDemuxerInputOption) + if !ok { + return ErrorInterfaceMismatch + } + if err := s.SetInputOption("re", "", 0); err != nil { return err } // // Additional options for smooth playback - // if err := demuxer.inputOptions.Set("fflags", "+genpts", 0); err != nil { + // if err := demuxer.inputOptions.SetInputOption("fflags", "+genpts", 0); err != nil { // return err // } return nil } -func WithAlsaInputFormatOption(demuxer *Demuxer) error { - demuxer.inputFormat = astiav.FindInputFormat("alsa") +func WithAlsaInputFormatOption(demuxer Demuxer) error { + s, ok := demuxer.(CanSetDemuxerInputFormat) + if !ok { + return ErrorInterfaceMismatch + } + s.SetInputFormat(astiav.FindInputFormat("alsa")) return nil } -func WithAvFoundationInputFormatOption(demuxer *Demuxer) error { - demuxer.inputFormat = astiav.FindInputFormat("avfoundation") +func WithAvFoundationInputFormatOption(demuxer Demuxer) error { + setInputFormat, ok := demuxer.(CanSetDemuxerInputFormat) + if !ok { + return ErrorInterfaceMismatch + } + setInputFormat.SetInputFormat(astiav.FindInputFormat("avfoundation")) - if err := demuxer.inputOptions.Set("video_size", "1280x720", 0); err != nil { + setInputOption, ok := demuxer.(CanSetDemuxerInputOption) + if !ok { + return ErrorInterfaceMismatch + } + + if err := setInputOption.SetInputOption("video_size", "1280x720", 0); err != nil { return err } - if err := demuxer.inputOptions.Set("framerate", "30", 0); err != nil { + if err := setInputOption.SetInputOption("framerate", "30", 0); err != nil { return err } - if err := demuxer.inputOptions.Set("pixel_format", "uyvy422", 0); err != nil { + if err := setInputOption.SetInputOption("pixel_format", "uyvy422", 0); err != nil { return err } @@ -68,8 +87,12 @@ func WithAvFoundationInputFormatOption(demuxer *Demuxer) error { } func WithDemuxerBufferSize(size int) DemuxerOption { - return func(demuxer *Demuxer) error { - demuxer.buffer = buffer.CreateChannelBuffer(demuxer.ctx, size, internal.CreatePacketPool()) + return func(demuxer Demuxer) error { + s, ok := demuxer.(CanSetBuffer[astiav.Packet]) + if !ok { + return ErrorInterfaceMismatch + } + s.SetBuffer(buffer.CreateChannelBuffer(demuxer.Ctx(), size, internal.CreatePacketPool())) return nil } } diff --git a/pkg/encoder.go b/pkg/encoder.go index fab998b..bfd8d51 100644 --- a/pkg/encoder.go +++ b/pkg/encoder.go @@ -1,279 +1,220 @@ package transcode -// -// import ( -// "context" -// "errors" -// "fmt" -// "math" -// "time" -// -// "github.com/asticode/go-astiav" -// -// "github.com/harshabose/tools/buffer/pkg" -// -// "github.com/harshabose/simple_webrtc_comm/transcode/internal" -// ) -// -// type Encoder struct { -// buffer buffer.BufferWithGenerator[astiav.Packet] -// filter *Filter -// ctx context.Context -// codec *astiav.Codec -// encoderContext *astiav.CodecContext -// codecFlags *astiav.Dictionary -// encoderSettings codecSettings -// bandwidthChan chan int64 -// previousBitrate int64 -// timer *time.Timer -// testMode bool // Add this flag -// sps []byte -// pps []byte -// } -// -// func CreateEncoder(ctx context.Context, codecID astiav.CodecID, filter *Filter, options ...EncoderOption) (*Encoder, error) { -// encoder := &Encoder{ -// filter: filter, -// codecFlags: astiav.NewDictionary(), -// ctx: ctx, -// } -// -// encoder.codec = astiav.FindEncoder(codecID) -// if encoder.encoderContext = astiav.AllocCodecContext(encoder.codec); encoder.encoderContext == nil { -// return nil, ErrorAllocateCodecContext -// } -// -// var contextOption EncoderOption -// if filter.sinkContext.MediaType() == astiav.MediaTypeAudio { -// contextOption = withAudioSetEncoderParameters(filter) -// } -// if filter.sinkContext.MediaType() == astiav.MediaTypeVideo { -// contextOption = withVideoSetEncoderParameters(filter) -// } -// -// options = append([]EncoderOption{contextOption}, options...) -// -// for _, option := range options { -// if err := option(encoder); err != nil { -// return nil, err -// } -// } -// -// if encoder.encoderSettings == nil { -// fmt.Println("warn: no encoder settings are provided") -// } -// -// encoder.encoderContext.SetFlags(astiav.NewCodecContextFlags(astiav.CodecContextFlagGlobalHeader)) -// -// if err := encoder.encoderContext.Open(encoder.codec, encoder.codecFlags); err != nil { -// return nil, err -// } -// -// if encoder.buffer == nil { -// encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool()) -// } -// -// encoder.findParameterSets(encoder.encoderContext.ExtraData()) -// -// return encoder, nil -// } -// -// func (encoder *Encoder) Start() { -// encoder.timer = time.NewTimer(10 * time.Second) -// go encoder.loop() -// } -// -// func (encoder *Encoder) GetParameterSets() ([]byte, []byte) { -// return encoder.sps, encoder.pps -// } -// -// func (encoder *Encoder) GetDuration() time.Duration { -// if encoder.encoderContext.MediaType() == astiav.MediaTypeAudio { -// return time.Duration(float64(time.Second) * float64(encoder.encoderContext.FrameSize()) / float64(encoder.encoderContext.SampleRate())) -// } -// return time.Duration(float64(time.Second) / encoder.encoderContext.Framerate().Float64()) -// } -// -// func (encoder *Encoder) GetTimeBase() astiav.Rational { -// return encoder.encoderContext.TimeBase() -// } -// -// func (encoder *Encoder) loop() { -// var ( -// frame *astiav.Frame -// packet *astiav.Packet -// err error -// ) -// defer encoder.close() -// -// loop1: -// for { -// select { -// case <-encoder.ctx.Done(): -// return -// case bitrate := <-encoder.bandwidthChan: // TODO: MIGHT NEED A MUTEX FOR THIS ONE CASE -// encoder.UpdateBitrate(bitrate) -// case frame = <-encoder.filter.WaitForFrame(): -// if err = encoder.encoderContext.SendFrame(frame); err != nil { -// encoder.filter.PutBack(frame) -// if !errors.Is(err, astiav.ErrEagain) { -// continue loop1 -// } -// } -// loop2: -// for { -// packet = encoder.buffer.Generate() -// if err = encoder.encoderContext.ReceivePacket(packet); err != nil { -// encoder.buffer.PutBack(packet) -// break loop2 -// } -// -// if err = encoder.pushPacket(packet); err != nil { -// encoder.buffer.PutBack(packet) -// continue loop2 -// } -// } -// encoder.filter.PutBack(frame) -// } -// } -// } -// -// func (encoder *Encoder) WaitForPacket() chan *astiav.Packet { -// return encoder.buffer.GetChannel() -// } -// -// func (encoder *Encoder) pushPacket(packet *astiav.Packet) error { -// ctx, cancel := context.WithTimeout(encoder.ctx, time.Second) -// defer cancel() -// -// return encoder.buffer.Push(ctx, packet) -// } -// -// func (encoder *Encoder) GetPacket() (*astiav.Packet, error) { -// ctx, cancel := context.WithTimeout(encoder.ctx, time.Second) -// defer cancel() -// -// return encoder.buffer.Pop(ctx) -// } -// -// func (encoder *Encoder) PutBack(packet *astiav.Packet) { -// encoder.buffer.PutBack(packet) -// } -// -// func (encoder *Encoder) SetBitrateChannel(channel chan int64) { -// encoder.bandwidthChan = channel -// } -// -// func (encoder *Encoder) close() { -// if encoder.encoderContext != nil { -// encoder.encoderContext.Free() -// } -// } -// -// func (encoder *Encoder) findParameterSets(extraData []byte) { -// if len(extraData) > 0 { -// // Find first start code (0x00000001) -// for i := 0; i < len(extraData)-4; i++ { -// if extraData[i] == 0 && extraData[i+1] == 0 && extraData[i+2] == 0 && extraData[i+3] == 1 { -// // Skip start code to get NAL type -// nalType := extraData[i+4] & 0x1F -// -// // Find next start code or end -// nextStart := len(extraData) -// for j := i + 4; j < len(extraData)-4; j++ { -// if extraData[j] == 0 && extraData[j+1] == 0 && extraData[j+2] == 0 && extraData[j+3] == 1 { -// nextStart = j -// break -// } -// } -// -// if nalType == 7 { // SPS -// encoder.sps = make([]byte, nextStart-i) -// copy(encoder.sps, extraData[i:nextStart]) -// } else if nalType == 8 { // PPS -// encoder.pps = make([]byte, len(extraData)-i) -// copy(encoder.pps, extraData[i:]) -// } -// -// i = nextStart - 1 -// } -// } -// fmt.Println("SPS for current encoder: ", encoder.sps) -// fmt.Println("PPS for current encoder: ", encoder.pps) -// } -// } -// -// func (encoder *Encoder) UpdateBitrate(bitrate int64) { -// // Show current encoder state -// currentEncoderBitrate := encoder.encoderContext.BitRate() -// fmt.Printf("recommended bitrate update to: %d (previous: %d, encoder actual: %d)\n", -// bitrate, encoder.previousBitrate, currentEncoderBitrate) -// -// if encoder.previousBitrate == 0 { -// encoder.SetBitrate(bitrate) -// newEncoderBitrate := encoder.encoderContext.BitRate() -// encoder.previousBitrate = bitrate -// fmt.Printf("initial bitrate set to: %d (encoder confirms: %d)\n", bitrate, newEncoderBitrate) -// return -// } -// -// change := math.Abs(float64(encoder.previousBitrate - bitrate)) -// changePercent := change / float64(encoder.previousBitrate) * 100 -// -// fmt.Printf("bitrate change: %.1f%% (%.0f -> %.0f)\n", -// changePercent, float64(encoder.previousBitrate), float64(bitrate)) -// -// shouldUpdate := false -// -// // Much more lenient thresholds for BWE: -// if changePercent >= 2.0 { // Was 5.0 - now accepts smaller changes -// if changePercent <= 300.0 { // Was 90.0 - now allows recovery -// shouldUpdate = true -// } else if encoder.previousBitrate <= 200 && bitrate > encoder.previousBitrate { -// // Special recovery case: allow any increase from very low bitrates -// shouldUpdate = true -// fmt.Printf("๐Ÿ”„ recovery mode: very low bitrate, allowing large increase\n") -// } -// } -// -// if shouldUpdate { -// oldEncoderBitrate := encoder.encoderContext.BitRate() -// encoder.SetBitrate(bitrate) -// newEncoderBitrate := encoder.encoderContext.BitRate() -// encoder.previousBitrate = bitrate -// fmt.Printf("โœ“ updated encoder bitrate: %d โ†’ %d โ†’ %d (target โ†’ old โ†’ new)\n", -// bitrate, oldEncoderBitrate, newEncoderBitrate) -// } else { -// fmt.Printf("โœ— bitrate change ignored (%.1f%% change), encoder remains at: %d\n", -// changePercent, currentEncoderBitrate) -// } -// } -// -// func (encoder *Encoder) SetBitrate(bitrate int64) { -// encoder.encoderContext.SetBitRate(bitrate) -// // if err := encoder.codecFlags.Set("bitrate", strconv.Itoa(int(bitrate/1000)), 0); err != nil { -// // fmt.Println("error while setting bitrate; err:", err.Error()) -// // } -// } -// -// // func (encoder *Encoder) UpdateBitrate(b int64) { -// // // Check if timer fired (only happens once) -// // select { -// // case <-encoder.timer.C: -// // fmt.Println("timer hit!!!!!!!!!!!") -// // encoder.testMode = true // Set flag -// // encoder.SetBitrate(250000000) -// // return -// // default: -// // // Continue to check flag -// // } -// // -// // // After timer fires, always use 10000 -// // if encoder.testMode { -// // fmt.Println("test mode active - forcing 10000") -// // encoder.SetBitrate(250000000) -// // } -// // -// // // Before timer fires, show current bitrate -// // fmt.Println("current bitrate:", encoder.encoderContext.BitRate()) -// // } +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "time" + + "github.com/asticode/go-astiav" + + "github.com/harshabose/tools/buffer/pkg" + + "github.com/harshabose/simple_webrtc_comm/transcode/internal" +) + +type GeneralEncoder struct { + buffer buffer.BufferWithGenerator[astiav.Packet] + filter CanProduceMediaFrame + codec *astiav.Codec + encoderContext *astiav.CodecContext + codecFlags *astiav.Dictionary + encoderSettings codecSettings + sps []byte + pps []byte + ctx context.Context + cancel context.CancelFunc +} + +func CreateGeneralEncoder(ctx context.Context, codecID astiav.CodecID, canProduceMediaFrame CanProduceMediaFrame, options ...EncoderOption) (*GeneralEncoder, error) { + ctx2, cancel := context.WithCancel(ctx) + encoder := &GeneralEncoder{ + filter: canProduceMediaFrame, + codecFlags: astiav.NewDictionary(), + ctx: ctx2, + cancel: cancel, + } + + encoder.codec = astiav.FindEncoder(codecID) + if encoder.encoderContext = astiav.AllocCodecContext(encoder.codec); encoder.encoderContext == nil { + return nil, ErrorAllocateCodecContext + } + + canDescribeMediaFrame, ok := canProduceMediaFrame.(CanDescribeMediaFrame) + if !ok { + return nil, ErrorInterfaceMismatch + } + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeAudio { + withAudioSetEncoderContextParameters(canDescribeMediaFrame, encoder.encoderContext) + } + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeVideo { + withVideoSetEncoderContextParameter(canDescribeMediaFrame, encoder.encoderContext) + } + + for _, option := range options { + if err := option(encoder); err != nil { + return nil, err + } + } + + if encoder.encoderSettings == nil { + fmt.Println("warn: no encoder settings are provided") + } + + encoder.encoderContext.SetFlags(astiav.NewCodecContextFlags(astiav.CodecContextFlagGlobalHeader)) + + if err := encoder.encoderContext.Open(encoder.codec, encoder.codecFlags); err != nil { + return nil, err + } + + if encoder.buffer == nil { + encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool()) + } + + encoder.findParameterSets(encoder.encoderContext.ExtraData()) + + return encoder, nil +} + +func (encoder *GeneralEncoder) Ctx() context.Context { + return encoder.ctx +} + +func (encoder *GeneralEncoder) Start() { + go encoder.loop() +} + +func (encoder *GeneralEncoder) GetParameterSets() ([]byte, []byte, error) { + encoder.findParameterSets(encoder.encoderContext.ExtraData()) + return encoder.sps, encoder.pps, nil +} + +func (encoder *GeneralEncoder) TimeBase() astiav.Rational { + return encoder.encoderContext.TimeBase() +} + +func (encoder *GeneralEncoder) loop() { + var ( + frame *astiav.Frame + packet *astiav.Packet + err error + ) + defer encoder.close() + +loop1: + for { + select { + case <-encoder.ctx.Done(): + return + case frame = <-encoder.filter.WaitForFrame(): + if err = encoder.encoderContext.SendFrame(frame); err != nil { + encoder.filter.PutBack(frame) + if !errors.Is(err, astiav.ErrEagain) { + continue loop1 + } + } + loop2: + for { + packet = encoder.buffer.Generate() + if err = encoder.encoderContext.ReceivePacket(packet); err != nil { + encoder.buffer.PutBack(packet) + break loop2 + } + + if err = encoder.pushPacket(packet); err != nil { + encoder.buffer.PutBack(packet) + continue loop2 + } + } + encoder.filter.PutBack(frame) + } + } +} + +func (encoder *GeneralEncoder) WaitForPacket() chan *astiav.Packet { + return encoder.buffer.GetChannel() +} + +func (encoder *GeneralEncoder) pushPacket(packet *astiav.Packet) error { + ctx, cancel := context.WithTimeout(encoder.ctx, time.Second) + defer cancel() + + return encoder.buffer.Push(ctx, packet) +} + +func (encoder *GeneralEncoder) PutBack(packet *astiav.Packet) { + encoder.buffer.PutBack(packet) +} + +func (encoder *GeneralEncoder) Stop() { + encoder.cancel() +} + +func (encoder *GeneralEncoder) close() { + if encoder.encoderContext != nil { + encoder.encoderContext.Free() + } + + if encoder.codecFlags != nil { + encoder.codecFlags.Free() + } +} + +func (encoder *GeneralEncoder) findParameterSets(extraData []byte) { + if len(extraData) > 0 { + // Find the first start code (0x00000001) + for i := 0; i < len(extraData)-4; i++ { + if extraData[i] == 0 && extraData[i+1] == 0 && extraData[i+2] == 0 && extraData[i+3] == 1 { + // Skip start code to get the NAL type + nalType := extraData[i+4] & 0x1F + + // Find the next start code or end + nextStart := len(extraData) + for j := i + 4; j < len(extraData)-4; j++ { + if extraData[j] == 0 && extraData[j+1] == 0 && extraData[j+2] == 0 && extraData[j+3] == 1 { + nextStart = j + break + } + } + + if nalType == 7 { // SPS + encoder.sps = make([]byte, nextStart-i) + copy(encoder.sps, extraData[i:nextStart]) + } else if nalType == 8 { // PPS + encoder.pps = make([]byte, len(extraData)-i) + copy(encoder.pps, extraData[i:]) + } + + i = nextStart - 1 + } + } + fmt.Println("SPS for current encoder: ", encoder.sps) + fmt.Println("\tSPS for current encoder in Base64:", base64.StdEncoding.EncodeToString(encoder.sps)) + fmt.Println("PPS for current encoder: ", encoder.pps) + fmt.Println("\tPPS for current encoder in Base64:", base64.StdEncoding.EncodeToString(encoder.pps)) + } +} + +func (encoder *GeneralEncoder) SetBuffer(buffer buffer.BufferWithGenerator[astiav.Packet]) { + encoder.buffer = buffer +} + +func (encoder *GeneralEncoder) SetEncoderCodecSettings(settings codecSettings) error { + encoder.encoderSettings = settings + return encoder.encoderSettings.ForEach(func(key string, value string) error { + if value == "" { + return nil + } + return encoder.codecFlags.Set(key, value, 0) + }) +} + +func (encoder *GeneralEncoder) GetCurrentBitrate() (int64, error) { + g, ok := encoder.encoderSettings.(CanGetCurrentBitrate) + if !ok { + return 0, ErrorInterfaceMismatch + } + + return g.GetCurrentBitrate() +} diff --git a/pkg/encoder_builder.go b/pkg/encoder_builder.go new file mode 100644 index 0000000..9fb21b0 --- /dev/null +++ b/pkg/encoder_builder.go @@ -0,0 +1,81 @@ +package transcode + +import ( + "context" + + "github.com/asticode/go-astiav" +) + +type GeneralEncoderBuilder struct { + codecID astiav.CodecID + bufferSize int + settings codecSettings + producer CanProduceMediaFrame +} + +func NewEncoderBuilder(codecID astiav.CodecID, settings codecSettings, bufferSize int, producer CanProduceMediaFrame) *GeneralEncoderBuilder { + return &GeneralEncoderBuilder{ + bufferSize: bufferSize, + codecID: codecID, + settings: settings, + producer: producer, + } +} + +func (b *GeneralEncoderBuilder) UpdateBitrate(bps int64) error { + s, ok := b.settings.(CanUpdateBitrate) + if !ok { + return ErrorInterfaceMismatch + } + + return s.UpdateBitrate(bps) +} + +func (b *GeneralEncoderBuilder) Build(ctx context.Context) (Encoder, error) { + codec := astiav.FindEncoder(b.codecID) + if codec == nil { + return nil, ErrorNoCodecFound + } + + ctx2, cancel := context.WithCancel(ctx) + encoder := &GeneralEncoder{ + filter: b.producer, + codec: codec, + codecFlags: astiav.NewDictionary(), + ctx: ctx2, + cancel: cancel, + } + + encoder.encoderContext = astiav.AllocCodecContext(codec) + if encoder.encoderContext == nil { + return nil, ErrorAllocateCodecContext + } + + canDescribeMediaFrame, ok := b.producer.(CanDescribeMediaFrame) + if !ok { + return nil, ErrorInterfaceMismatch + } + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeAudio { + withAudioSetEncoderContextParameters(canDescribeMediaFrame, encoder.encoderContext) + } + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeVideo { + withVideoSetEncoderContextParameter(canDescribeMediaFrame, encoder.encoderContext) + } + + if err := encoder.SetEncoderCodecSettings(b.settings); err != nil { + return nil, err + } + + if err := WithEncoderBufferSize(b.bufferSize)(encoder); err != nil { + return nil, err + } + encoder.encoderContext.SetFlags(astiav.NewCodecContextFlags(astiav.CodecContextFlagGlobalHeader)) + + if err := encoder.encoderContext.Open(encoder.codec, encoder.codecFlags); err != nil { + return nil, err + } + + encoder.findParameterSets(encoder.encoderContext.ExtraData()) + + return encoder, nil +} diff --git a/pkg/encoder_options.go b/pkg/encoder_options.go index 9274c60..c0c3b72 100644 --- a/pkg/encoder_options.go +++ b/pkg/encoder_options.go @@ -13,7 +13,7 @@ import ( ) type ( - EncoderOption = func(encoder *Encoder) error + EncoderOption = func(encoder Encoder) error ) type codecSettings interface { @@ -21,7 +21,6 @@ type codecSettings interface { } type X264Opts struct { - // RateControl string `x264-opts:"rate-control"` Bitrate string `x264-opts:"bitrate"` VBVMaxBitrate string `x264-opts:"vbv-maxrate"` VBVBuffer string `x264-opts:"vbv-bufsize"` @@ -30,9 +29,9 @@ type X264Opts struct { AnnexB string `x264-opts:"annexb"` } -func (x264 X264Opts) ForEach(fn func(string, string) error) error { - t := reflect.TypeOf(x264) - v := reflect.ValueOf(x264) +func (x264 *X264Opts) ForEach(fn func(string, string) error) error { + t := reflect.TypeOf(*x264) + v := reflect.ValueOf(*x264) // Build a single x264opts string var optParts []string @@ -67,9 +66,17 @@ func (x264 X264Opts) ForEach(fn func(string, string) error) error { return nil } +func (x264 *X264Opts) UpdateBitrate(bps int64) error { + x264.Bitrate = fmt.Sprintf("%d", bps/1000) + x264.VBVMaxBitrate = fmt.Sprintf("%d", (bps/1000)+200) + x264.VBVBuffer = fmt.Sprintf("%d", bps/2000) + + return nil +} + type X264OpenSettings struct { - X264Opts - RateControl string `x264:"rc"` + *X264Opts + // RateControl string `x264:"rc"` // not sure; fuck Preset string `x264:"preset"` // exists Tune string `x264:"tune"` // exists Refs string `x264:"refs"` // exists @@ -82,7 +89,7 @@ type X264OpenSettings struct { NGOP string `x264:"g"` // exists NGOPMin string `x264:"keyint_min"` // exists Scenecut string `x264:"sc_threshold"` // exists - InfraRefresh string `x264:"intra-refresh"` // exists + IntraRefresh string `x264:"intra-refresh"` // exists LookAhead string `x264:"rc-lookahead"` // exists SlicedThreads string `x264:"slice"` // exists ForceIDR string `x264:"force-idr"` // exists @@ -93,9 +100,9 @@ type X264OpenSettings struct { Aud string `x264:"aud"` // exists } -func (s X264OpenSettings) ForEach(fn func(key, value string) error) error { - t := reflect.TypeOf(s) - v := reflect.ValueOf(s) +func (s *X264OpenSettings) ForEach(fn func(key, value string) error) error { + t := reflect.TypeOf(*s) + v := reflect.ValueOf(*s) for i := 0; i < t.NumField(); i++ { field := t.Field(i) @@ -110,8 +117,12 @@ func (s X264OpenSettings) ForEach(fn func(key, value string) error) error { return s.X264Opts.ForEach(fn) } +func (s *X264OpenSettings) UpdateBitrate(bps int64) error { + return s.X264Opts.UpdateBitrate(bps) +} + var DefaultX264Settings = X264OpenSettings{ - X264Opts: X264Opts{ + X264Opts: &X264Opts{ // RateControl: "abr", Bitrate: "4000", VBVMaxBitrate: "5000", @@ -132,7 +143,7 @@ var DefaultX264Settings = X264OpenSettings{ NGOP: "250", NGOPMin: "25", Scenecut: "40", - InfraRefresh: "0", + IntraRefresh: "0", LookAhead: "40", SlicedThreads: "0", ForceIDR: "0", @@ -144,7 +155,7 @@ var DefaultX264Settings = X264OpenSettings{ } var LowBandwidthX264Settings = X264OpenSettings{ - X264Opts: X264Opts{ + X264Opts: &X264Opts{ // RateControl: "abr", Bitrate: "1500", VBVMaxBitrate: "1800", @@ -165,7 +176,7 @@ var LowBandwidthX264Settings = X264OpenSettings{ NGOP: "60", NGOPMin: "30", Scenecut: "30", - InfraRefresh: "0", + IntraRefresh: "0", LookAhead: "20", SlicedThreads: "1", ForceIDR: "0", @@ -177,7 +188,7 @@ var LowBandwidthX264Settings = X264OpenSettings{ } var LowLatencyX264Settings = X264OpenSettings{ - X264Opts: X264Opts{ + X264Opts: &X264Opts{ // RateControl: "abr", Bitrate: "2500", VBVMaxBitrate: "12000", @@ -198,7 +209,7 @@ var LowLatencyX264Settings = X264OpenSettings{ NGOP: "30", NGOPMin: "15", Scenecut: "0", - InfraRefresh: "1", + IntraRefresh: "1", LookAhead: "10", SlicedThreads: "1", ForceIDR: "1", @@ -211,7 +222,7 @@ var LowLatencyX264Settings = X264OpenSettings{ } var HighQualityX264Settings = X264OpenSettings{ - X264Opts: X264Opts{ + X264Opts: &X264Opts{ // RateControl: "abr", Bitrate: "15000", VBVMaxBitrate: "20000", @@ -232,7 +243,7 @@ var HighQualityX264Settings = X264OpenSettings{ NGOP: "250", NGOPMin: "30", Scenecut: "80", - InfraRefresh: "0", + IntraRefresh: "0", LookAhead: "60", SlicedThreads: "0", ForceIDR: "0", @@ -245,163 +256,123 @@ var HighQualityX264Settings = X264OpenSettings{ } var WebRTCOptimisedX264Settings = X264OpenSettings{ - X264Opts: X264Opts{ - // RateControl: "cbr", + X264Opts: &X264Opts{ Bitrate: "800", // Keep your current target VBVMaxBitrate: "900", // Same as target! VBVBuffer: "300", // 2500/30fps โ‰ˆ 83 kbits (single frame) RateTol: "0.1", // More tolerance - SyncLookAhead: "0", // Already correct AnnexB: "1", // Already correct }, - LookAhead: "0", // Critical fix! Qmin: "26", // Wider range Qmax: "42", // Much wider range Level: "3.1", // Better compatibility Preset: "ultrafast", Tune: "zerolatency", - Refs: "1", Profile: "baseline", - BFrames: "0", - BAdapt: "0", NGOP: "50", NGOPMin: "25", - Scenecut: "0", - InfraRefresh: "1", - SlicedThreads: "1", - ForceIDR: "1", - AQMode: "1", - AQStrength: "0.5", - MBTree: "0", + IntraRefresh: "1", + SlicedThreads: "1", // TODO: CHECK THIS + // ForceIDR: "1", // TODO: CHECK THIS; MIGHT BE IN CONFLICT WITH IntraRefresh + AQMode: "1", // RE-ENABLED AS zerolatency disables this + AQStrength: "0.5", Threads: "0", Aud: "1", } -func WithX264DefaultOptions(encoder *Encoder) error { - encoder.codecSettings = DefaultX264Settings - - return encoder.codecSettings.ForEach(func(key, value string) error { - return encoder.codecFlags.Set(key, value, 0) - }) +func WithX264DefaultOptions(encoder Encoder) error { + return WithCodecSettings(&DefaultX264Settings)(encoder) } -func WithX264HighQualityOptions(encoder *Encoder) error { - encoder.codecSettings = HighQualityX264Settings - - return encoder.codecSettings.ForEach(func(key, value string) error { - return encoder.codecFlags.Set(key, value, 0) - }) +func WithX264HighQualityOptions(encoder Encoder) error { + return WithCodecSettings(&HighQualityX264Settings)(encoder) } -func WithX264LowLatencyOptions(encoder *Encoder) error { - encoder.codecSettings = LowLatencyX264Settings - - return encoder.codecSettings.ForEach(func(key, value string) error { - return encoder.codecFlags.Set(key, value, 0) - }) +func WithX264LowLatencyOptions(encoder Encoder) error { + return WithCodecSettings(&LowLatencyX264Settings)(encoder) } -func WithWebRTCOptimisedOptions(encoder *Encoder) error { - encoder.codecSettings = WebRTCOptimisedX264Settings - - return encoder.codecSettings.ForEach(func(key, value string) error { - fmt.Printf("setting key (%s): value(%s)\n", key, value) - return encoder.codecFlags.Set(key, value, 0) - }) +func WithWebRTCOptimisedOptions(encoder Encoder) error { + return WithCodecSettings(&WebRTCOptimisedX264Settings)(encoder) } -func WithX264LowBandwidthOptions(encoder *Encoder) error { - encoder.codecSettings = LowBandwidthX264Settings +func WithCodecSettings(settings codecSettings) EncoderOption { + return func(encoder Encoder) error { + s, ok := encoder.(CanSetEncoderCodecSettings) + if !ok { + return ErrorInterfaceMismatch + } - return encoder.codecSettings.ForEach(func(key, value string) error { - return encoder.codecFlags.Set(key, value, 0) - }) + return s.SetEncoderCodecSettings(settings) + } } -// -// -// func WithDefaultVP8Options(encoder *VP8Encoder) error { -// encoder.codecSettings = DefaultVP8Settings -// -// return encoder.codecSettings.ForEach(func(key, value string) error { -// if value == "" { -// return nil -// } -// return encoder.codecFlags.Set(key, value, 0) -// }) -// } -// -// func withVideoSetEncoderParameters(filter *Filter) EncoderOption { -// return func(encoder *VP8Encoder) error { -// withVideoSetEncoderContextParameter(filter, encoder.encoderContext) -// return nil -// } -// } -// -// func withAudioSetEncoderParameters(filter *Filter) EncoderOption { -// return func(encoder *VP8Encoder) error { -// withAudioSetEncoderContextParameters(filter, encoder.encoderContext) -// return nil -// } -// } +func WithX264LowBandwidthOptions(encoder Encoder) error { + return WithCodecSettings(&LowBandwidthX264Settings)(encoder) +} -func withAudioSetEncoderContextParameters(filter *Filter, eCtx *astiav.CodecContext) { - eCtx.SetTimeBase(filter.sinkContext.TimeBase()) - eCtx.SetSampleRate(filter.sinkContext.SampleRate()) - eCtx.SetSampleFormat(filter.sinkContext.SampleFormat()) - eCtx.SetChannelLayout(filter.sinkContext.ChannelLayout()) +func withAudioSetEncoderContextParameters(filter CanDescribeMediaAudioFrame, eCtx *astiav.CodecContext) { + eCtx.SetTimeBase(filter.TimeBase()) + eCtx.SetSampleRate(filter.SampleRate()) + eCtx.SetSampleFormat(filter.SampleFormat()) + eCtx.SetChannelLayout(filter.ChannelLayout()) eCtx.SetStrictStdCompliance(-2) } -func withVideoSetEncoderContextParameter(filter *Filter, eCtx *astiav.CodecContext) { - eCtx.SetHeight(filter.sinkContext.Height()) - eCtx.SetWidth(filter.sinkContext.Width()) - eCtx.SetTimeBase(filter.sinkContext.TimeBase()) - eCtx.SetPixelFormat(filter.sinkContext.PixelFormat()) - eCtx.SetFramerate(filter.sinkContext.FrameRate()) +func withVideoSetEncoderContextParameter(filter CanDescribeMediaVideoFrame, eCtx *astiav.CodecContext) { + eCtx.SetHeight(filter.Height()) + eCtx.SetWidth(filter.Width()) + eCtx.SetTimeBase(filter.TimeBase()) + eCtx.SetPixelFormat(filter.PixelFormat()) + eCtx.SetFramerate(filter.FrameRate()) } func WithEncoderBufferSize(size int) EncoderOption { - return func(encoder *Encoder) error { - encoder.buffer = buffer.CreateChannelBuffer(encoder.ctx, size, internal.CreatePacketPool()) + return func(encoder Encoder) error { + s, ok := encoder.(CanSetBuffer[astiav.Packet]) + if !ok { + return ErrorInterfaceMismatch + } + s.SetBuffer(buffer.CreateChannelBuffer(encoder.Ctx(), size, internal.CreatePacketPool())) return nil } } -type VP8Settings struct { - Deadline string `vp8:"deadline"` // Real-time encoding - Bitrate string `vp8:"b"` // Target bitrate - MinRate string `vp8:"minrate"` // Minimum bitrate - MaxRate string `vp8:"maxrate"` // Maximum bitrate - BufSize string `vp8:"bufsize"` // Buffer size - CRF string `vp8:"crf"` // Quality setting - CPUUsed string `vp8:"cpu-used"` // Speed preset -} - -var DefaultVP8Settings = VP8Settings{ - Deadline: "1", // Real-time - Bitrate: "2500k", // 2.5 Mbps - MinRate: "2000k", // Min 2 Mbps - MaxRate: "3000k", // Max 3 Mbps - BufSize: "500k", // 500kb buffer - CRF: "10", // Good quality - CPUUsed: "8", // Fastest -} - -func (s VP8Settings) ForEach(fn func(key, value string) error) error { - t := reflect.TypeOf(s) - v := reflect.ValueOf(s) - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - tag := field.Tag.Get("vp8") - if tag != "" { - if err := fn(tag, v.Field(i).String()); err != nil { - return err - } - } - } - - return nil -} +// +// type VP8Settings struct { +// Deadline string `vp8:"deadline"` // Real-time encoding +// Bitrate string `vp8:"b"` // Target bitrate +// MinRate string `vp8:"minrate"` // Minimum bitrate +// MaxRate string `vp8:"maxrate"` // Maximum bitrate +// BufSize string `vp8:"bufsize"` // Buffer size +// CRF string `vp8:"crf"` // Quality setting +// CPUUsed string `vp8:"cpu-used"` // Speed preset +// } +// +// var DefaultVP8Settings = VP8Settings{ +// Deadline: "1", // Real-time +// Bitrate: "2500k", // 2.5 Mbps +// MinRate: "2000k", // Min 2 Mbps +// MaxRate: "3000k", // Max 3 Mbps +// BufSize: "500k", // 500kb buffer +// CRF: "10", // Good quality +// CPUUsed: "8", // Fastest +// } +// +// func (s VP8Settings) ForEach(fn func(key, value string) error) error { +// t := reflect.TypeOf(s) +// v := reflect.ValueOf(s) +// +// for i := 0; i < t.NumField(); i++ { +// field := t.Field(i) +// tag := field.Tag.Get("vp8") +// if tag != "" { +// if err := fn(tag, v.Field(i).String()); err != nil { +// return err +// } +// } +// } +// +// return nil +// } diff --git a/pkg/errors.go b/pkg/errors.go index 63d614d..822d42b 100644 --- a/pkg/errors.go +++ b/pkg/errors.go @@ -6,7 +6,9 @@ var ( ErrorAllocateFormatContext = errors.New("error allocate format context") ErrorOpenInputContainer = errors.New("error opening container") ErrorNoStreamFound = errors.New("error no stream found") + ErrorGeneralAllocate = errors.New("error allocating general object") ErrorNoVideoStreamFound = errors.New("no video stream found") + ErrorInterfaceMismatch = errors.New("interface mismatch") ErrorNoCodecFound = errors.New("error no codec found") ErrorAllocateCodecContext = errors.New("error allocating codec context") diff --git a/pkg/filter.go b/pkg/filter.go index f93bd49..9e9b25e 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -13,9 +13,9 @@ import ( "github.com/harshabose/simple_webrtc_comm/transcode/internal" ) -type Filter struct { +type GeneralFilter struct { content string - decoder *Decoder + decoder CanProduceMediaFrame buffer buffer.BufferWithGenerator[astiav.Frame] graph *astiav.FilterGraph input *astiav.FilterInOut @@ -25,52 +25,55 @@ type Filter struct { srcContextParams *astiav.BuffersrcFilterContextParameters // NOTE: THIS BECOMES NIL AFTER INITIALISATION mux sync.RWMutex ctx context.Context + cancel context.CancelFunc } -func CreateFilter(ctx context.Context, decoder *Decoder, filterConfig *FilterConfig, options ...FilterOption) (*Filter, error) { - var ( - filter *Filter - filterSrc *astiav.Filter - filterSink *astiav.Filter - contextOption FilterOption - err error - ) - filter = &Filter{ +func CreateGeneralFilter(ctx context.Context, canProduceMediaFrame CanProduceMediaFrame, filterConfig FilterConfig, options ...FilterOption) (*GeneralFilter, error) { + ctx2, cancel := context.WithCancel(ctx) + filter := &GeneralFilter{ graph: astiav.AllocFilterGraph(), - decoder: decoder, + decoder: canProduceMediaFrame, input: astiav.AllocFilterInOut(), output: astiav.AllocFilterInOut(), srcContextParams: astiav.AllocBuffersrcFilterContextParameters(), - ctx: ctx, + ctx: ctx2, + cancel: cancel, } // TODO: CHECK IF ALL ATTRIBUTES ARE ALLOCATED PROPERLY - if filterSrc = astiav.FindFilterByName(filterConfig.Source.String()); filterSrc == nil { - return nil, ErrorNoFilterName - } - if filterSink = astiav.FindFilterByName(filterConfig.Sink.String()); filterSink == nil { + filterSrc := astiav.FindFilterByName(filterConfig.Source.String()) + if filterSrc == nil { return nil, ErrorNoFilterName } - if filter.srcContext, err = filter.graph.NewBuffersrcFilterContext(filterSrc, "in"); err != nil { + filterSink := astiav.FindFilterByName(filterConfig.Sink.String()) + if filterSink == nil { + return nil, ErrorNoFilterName + } + + srcContext, err := filter.graph.NewBuffersrcFilterContext(filterSrc, "in") + if err != nil { return nil, ErrorAllocSrcContext } + filter.srcContext = srcContext - if filter.sinkContext, err = filter.graph.NewBuffersinkFilterContext(filterSink, "out"); err != nil { + sinkContext, err := filter.graph.NewBuffersinkFilterContext(filterSink, "out") + if err != nil { return nil, ErrorAllocSinkContext } + filter.sinkContext = sinkContext - if decoder.decoderContext.MediaType() == astiav.MediaTypeVideo { - fmt.Println("video media type detected") - contextOption = withVideoSetFilterContextParameters(decoder) + canDescribeMediaFrame, ok := canProduceMediaFrame.(CanDescribeMediaFrame) + if !ok { + return nil, ErrorInterfaceMismatch } - if decoder.decoderContext.MediaType() == astiav.MediaTypeAudio { - fmt.Println("audio media type detected") - contextOption = withAudioSetFilterContextParameters(decoder) + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeVideo { + options = append([]FilterOption{withVideoSetFilterContextParameters(canDescribeMediaFrame)}, options...) + } + if canDescribeMediaFrame.MediaType() == astiav.MediaTypeAudio { + options = append([]FilterOption{withAudioSetFilterContextParameters(canDescribeMediaFrame)}, options...) } - - options = append([]FilterOption{contextOption}, options...) for _, option := range options { if err = option(filter); err != nil { @@ -87,14 +90,10 @@ func CreateFilter(ctx context.Context, decoder *Decoder, filterConfig *FilterCon return nil, ErrorSrcContextSetParameter } - fmt.Println("check1") - if err = filter.srcContext.Initialize(astiav.NewDictionary()); err != nil { return nil, ErrorSrcContextInitialise } - fmt.Println("check2") - filter.output.SetName("in") filter.output.SetFilterContext(filter.srcContext.FilterContext()) filter.output.SetPadIdx(0) @@ -124,11 +123,19 @@ func CreateFilter(ctx context.Context, decoder *Decoder, filterConfig *FilterCon return filter, nil } -func (filter *Filter) Start() { +func (filter *GeneralFilter) Ctx() context.Context { + return filter.ctx +} + +func (filter *GeneralFilter) Start() { go filter.loop() } -func (filter *Filter) loop() { +func (filter *GeneralFilter) Stop() { + filter.cancel() +} + +func (filter *GeneralFilter) loop() { var ( err error = nil srcFrame *astiav.Frame @@ -166,29 +173,22 @@ loop1: } } -func (filter *Filter) pushFrame(frame *astiav.Frame) error { +func (filter *GeneralFilter) pushFrame(frame *astiav.Frame) error { ctx, cancel := context.WithTimeout(filter.ctx, time.Second) defer cancel() return filter.buffer.Push(ctx, frame) } -func (filter *Filter) GetFrame() (*astiav.Frame, error) { - ctx, cancel := context.WithTimeout(filter.ctx, time.Second) - defer cancel() - - return filter.buffer.Pop(ctx) -} - -func (filter *Filter) PutBack(frame *astiav.Frame) { +func (filter *GeneralFilter) PutBack(frame *astiav.Frame) { filter.buffer.PutBack(frame) } -func (filter *Filter) WaitForFrame() chan *astiav.Frame { +func (filter *GeneralFilter) WaitForFrame() chan *astiav.Frame { return filter.buffer.GetChannel() } -func (filter *Filter) close() { +func (filter *GeneralFilter) close() { if filter.graph != nil { filter.graph.Free() } @@ -199,3 +199,103 @@ func (filter *Filter) close() { filter.output.Free() } } + +func (filter *GeneralFilter) SetBuffer(buffer buffer.BufferWithGenerator[astiav.Frame]) { + filter.buffer = buffer +} + +func (filter *GeneralFilter) AddToFilterContent(content string) { + filter.content += content +} + +func (filter *GeneralFilter) SetFrameRate(describe CanDescribeFrameRate) { + filter.srcContextParams.SetFramerate(describe.FrameRate()) +} + +func (filter *GeneralFilter) SetTimeBase(describe CanDescribeTimeBase) { + filter.srcContextParams.SetTimeBase(describe.TimeBase()) +} + +func (filter *GeneralFilter) SetHeight(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetHeight(describe.Height()) +} + +func (filter *GeneralFilter) SetWidth(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetWidth(describe.Width()) +} + +func (filter *GeneralFilter) SetPixelFormat(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetPixelFormat(describe.PixelFormat()) +} + +func (filter *GeneralFilter) SetSampleAspectRatio(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetSampleAspectRatio(describe.SampleAspectRatio()) +} + +func (filter *GeneralFilter) SetColorSpace(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetColorSpace(describe.ColorSpace()) +} + +func (filter *GeneralFilter) SetColorRange(describe CanDescribeMediaVideoFrame) { + filter.srcContextParams.SetColorRange(describe.ColorRange()) +} + +func (filter *GeneralFilter) SetSampleRate(describe CanDescribeMediaAudioFrame) { + filter.srcContextParams.SetSampleRate(describe.SampleRate()) +} + +func (filter *GeneralFilter) SetSampleFormat(describe CanDescribeMediaAudioFrame) { + filter.srcContextParams.SetSampleFormat(describe.SampleFormat()) +} + +func (filter *GeneralFilter) SetChannelLayout(describe CanDescribeMediaAudioFrame) { + filter.srcContextParams.SetChannelLayout(describe.ChannelLayout()) +} + +func (filter *GeneralFilter) MediaType() astiav.MediaType { + return filter.sinkContext.MediaType() +} + +func (filter *GeneralFilter) FrameRate() astiav.Rational { + return filter.sinkContext.FrameRate() +} + +func (filter *GeneralFilter) TimeBase() astiav.Rational { + return filter.sinkContext.TimeBase() +} + +func (filter *GeneralFilter) Height() int { + return filter.sinkContext.Height() +} + +func (filter *GeneralFilter) Width() int { + return filter.sinkContext.Width() +} + +func (filter *GeneralFilter) PixelFormat() astiav.PixelFormat { + return filter.sinkContext.PixelFormat() +} + +func (filter *GeneralFilter) SampleAspectRatio() astiav.Rational { + return filter.sinkContext.SampleAspectRatio() +} + +func (filter *GeneralFilter) ColorSpace() astiav.ColorSpace { + return filter.sinkContext.ColorSpace() +} + +func (filter *GeneralFilter) ColorRange() astiav.ColorRange { + return filter.sinkContext.ColorRange() +} + +func (filter *GeneralFilter) SampleRate() int { + return filter.sinkContext.SampleRate() +} + +func (filter *GeneralFilter) SampleFormat() astiav.SampleFormat { + return filter.sinkContext.SampleFormat() +} + +func (filter *GeneralFilter) ChannelLayout() astiav.ChannelLayout { + return filter.sinkContext.ChannelLayout() +} diff --git a/pkg/filter_options.go b/pkg/filter_options.go index b5ee65b..7452285 100644 --- a/pkg/filter_options.go +++ b/pkg/filter_options.go @@ -12,7 +12,7 @@ import ( ) type ( - FilterOption func(*Filter) error + FilterOption func(Filter) error Name string ) @@ -33,136 +33,197 @@ const ( ) var ( - VideoFilters = &FilterConfig{ + VideoFilters = FilterConfig{ Source: videoBufferFilterName, Sink: videoBufferSinkFilterName, } - AudioFilters = &FilterConfig{ + AudioFilters = FilterConfig{ Source: audioBufferFilterName, Sink: audioBufferSinkFilterName, } ) func WithFilterBufferSize(size int) FilterOption { - return func(filter *Filter) error { - filter.buffer = buffer.CreateChannelBuffer(filter.ctx, size, internal.CreateFramePool()) + return func(filter Filter) error { + s, ok := filter.(CanSetBuffer[astiav.Frame]) + if !ok { + return ErrorInterfaceMismatch + } + s.SetBuffer(buffer.CreateChannelBuffer(filter.Ctx(), size, internal.CreateFramePool())) return nil } } -func withVideoSetFilterContextParameters(decoder *Decoder) func(*Filter) error { - return func(filter *Filter) error { - filter.srcContextParams.SetHeight(decoder.decoderContext.Height()) - filter.srcContextParams.SetPixelFormat(decoder.decoderContext.PixelFormat()) - filter.srcContextParams.SetSampleAspectRatio(decoder.decoderContext.SampleAspectRatio()) - filter.srcContextParams.SetTimeBase(decoder.decoderContext.TimeBase()) - filter.srcContextParams.SetWidth(decoder.decoderContext.Width()) +func withVideoSetFilterContextParameters(decoder CanDescribeMediaVideoFrame) func(Filter) error { + return func(filter Filter) error { + canSetMediaVideoFrame, ok := filter.(CanSetMediaVideoFrame) + if !ok { + return ErrorInterfaceMismatch + } - filter.srcContextParams.SetColorSpace(decoder.decoderContext.ColorSpace()) - filter.srcContextParams.SetColorRange(decoder.decoderContext.ColorRange()) + canSetMediaVideoFrame.SetFrameRate(decoder) + canSetMediaVideoFrame.SetHeight(decoder) + canSetMediaVideoFrame.SetPixelFormat(decoder) + canSetMediaVideoFrame.SetSampleAspectRatio(decoder) + canSetMediaVideoFrame.SetTimeBase(decoder) + canSetMediaVideoFrame.SetWidth(decoder) + + canSetMediaVideoFrame.SetColorSpace(decoder) + canSetMediaVideoFrame.SetColorRange(decoder) return nil } } func WithVideoScaleFilterContent(width, height uint16) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("scale=%d:%d,", width, height) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("scale=%d:%d,", width, height)) return nil } } func WithVideoPixelFormatFilterContent(pixelFormat astiav.PixelFormat) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("format=pix_fmts=%s,", pixelFormat) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + fmt.Println("pixel filter added:", pixelFormat.String()) + a.AddToFilterContent(fmt.Sprintf("format=pix_fmts=%s,", pixelFormat)) return nil } } func WithVideoFPSFilterContent(fps uint8) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("fps=%d,", fps) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("fps=%d,", fps)) return nil } } // +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -func withAudioSetFilterContextParameters(decoder *Decoder) func(*Filter) error { - return func(filter *Filter) error { - // Print parameter values before setting them - fmt.Println("Setting filter parameters with values:") - fmt.Printf(" Channel Layout: %v\n", decoder.decoderContext.ChannelLayout()) - fmt.Printf(" Sample Format: %v\n", decoder.decoderContext.SampleFormat()) - fmt.Printf(" Sample Rate: %v\n", decoder.decoderContext.SampleRate()) - fmt.Printf(" Time Base: %v\n", decoder.decoderContext.TimeBase()) - - // Set the parameters - filter.srcContextParams.SetChannelLayout(decoder.decoderContext.ChannelLayout()) - filter.srcContextParams.SetSampleFormat(decoder.decoderContext.SampleFormat()) - filter.srcContextParams.SetSampleRate(decoder.decoderContext.SampleRate()) - filter.srcContextParams.SetTimeBase(decoder.decoderContext.TimeBase()) +func withAudioSetFilterContextParameters(decoder CanDescribeMediaAudioFrame) func(Filter) error { + return func(filter Filter) error { + canSetMediaAudioFrame, ok := filter.(CanSetMediaAudioFrame) + if !ok { + return ErrorInterfaceMismatch + } + canSetMediaAudioFrame.SetChannelLayout(decoder) + canSetMediaAudioFrame.SetSampleFormat(decoder) + canSetMediaAudioFrame.SetSampleRate(decoder) + canSetMediaAudioFrame.SetTimeBase(decoder) return nil } } func WithAudioSampleFormatChannelLayoutFilter(sampleFormat astiav.SampleFormat, channelLayout astiav.ChannelLayout) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("aformat=sample_fmts=%s:channel_layouts=%s", sampleFormat.String(), channelLayout.String()) + "," + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("aformat=sample_fmts=%s:channel_layouts=%s", sampleFormat.String(), channelLayout.String()) + ",") return nil } } func WithAudioSampleRateFilter(samplerate uint32) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("aresample=%d,", samplerate) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("aresample=%d,", samplerate)) return nil } } func WithAudioSamplesPerFrameContent(nsamples uint16) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("asetnsamples=%d,", nsamples) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("asetnsamples=%d,", nsamples)) return nil } } func WithAudioCompressionContent(threshold int, ratio int, attack float64, release float64) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // NOTE: DYNAMIC RANGE COMPRESSION TO HANDLE SUDDEN VOLUME CHANGES // Possible values 'acompressor=threshold=-12dB:ratio=2:attack=0.05:release=0.2" // MOST POPULAR VALUES - filter.content += fmt.Sprintf("acompressor=threshold=%ddB:ratio=%d:attack=%.2f:release=%.2f,", - threshold, ratio, attack, release) + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("acompressor=threshold=%ddB:ratio=%d:attack=%.2f:release=%.2f,", + threshold, ratio, attack, release)) return nil } } func WithAudioHighPassFilterContent(id string, frequency float32, order uint8) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // NOTE: HIGH-PASS FILTER TO REMOVE WIND NOISE AND TURBULENCE // NOTE: 120HZ CUTOFF MIGHT PRESERVE VOICE WHILE REMOVING LOW RUMBLE; BUT MORE TESTING IS NEEDED - filter.content += fmt.Sprintf("highpass@%s=frequency=%.2f:poles=%d", id, frequency, order) + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("highpass@%s=frequency=%.2f:poles=%d", id, frequency, order)) return nil } } func WithAudioLowPassFilterContent(id string, frequency float32, order uint8) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("lowpass@%s=frequency=%.2f:poles=%d", id, frequency, order) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("lowpass@%s=frequency=%.2f:poles=%d", id, frequency, order)) return nil } } func WithAudioNotchFilterContent(id string, frequency float32, qFactor float32) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("bandreject@%s=frequency=%.2f:width_type=q:width=%.2f", id, frequency, qFactor) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("bandreject@%s=frequency=%.2f:width_type=q:width=%.2f", id, frequency, qFactor)) return nil } } func WithAudioNotchHarmonicsFilterContent(id string, fundamental float32, harmonics uint8, qFactor float32) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + var filters = make([]string, 0) for i := uint8(0); i < harmonics; i++ { @@ -170,50 +231,75 @@ func WithAudioNotchHarmonicsFilterContent(id string, fundamental float32, harmon filters = append(filters, fmt.Sprintf("bandreject@%s%d=frequency=%.2f:width_type=q:width=%.2f", id, i, harmonic, qFactor)) } - filter.content += strings.Join(filters, ",") + "," + a.AddToFilterContent(strings.Join(filters, ",") + ",") return nil } } func WithAudioEqualiserFilter(id string, frequency float32, width float32, gain float32) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // NOTE: EQUALISER CAN BE USED TO ENHANCE SPEECH BANDWIDTH (300 - 3kHz). MORE RESEARCH NEEDS TO DONE - filter.content += fmt.Sprintf("equalizer@%s=frequency=%.2f:width_type=h:width=%.2f:gain=%.2f,", id, frequency, width, gain) + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("equalizer@%s=frequency=%.2f:width_type=h:width=%.2f:gain=%.2f,", id, frequency, width, gain)) return nil } } func WithAudioSilenceGateContent(threshold int, range_ int, attack float64, release float64) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // NOTE: IF EVERYTHING WORKS, WE SHOULD HAVE LIGHT NOISE WHICH CAN BE CONSIDERED AS SILENCE. THIS GATE REMOVES SILENCE // NOTE: POSSIBLE VALUES 'agate=threshold=-30dB:range=-30dB:attack=0.01:release=0.1" // MOST POPULAR; MORE TESTING IS NEEDED - filter.content += fmt.Sprintf("agate=threshold=%ddB:range=%ddB:attack=%.2f:release=%.2f,", - threshold, range_, attack, release) + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("agate=threshold=%ddB:range=%ddB:attack=%.2f:release=%.2f,", + threshold, range_, attack, release)) return nil } } func WithAudioLoudnessNormaliseContent(intensity int, truePeak float64, range_ int) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // NOTE: NORMALISES THE FINAL AUDIO. MUST BE CALLED AT THE END // NOTE: POSSIBLE VALUES "loudnorm=I=-16:TP=-1.5:LRA=11" // MOST POPULAR - filter.content += fmt.Sprintf("loudnorm=I=%d:TP=%.1f:LRA=%d", - intensity, truePeak, range_) + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("loudnorm=I=%d:TP=%.1f:LRA=%d", + intensity, truePeak, range_)) return nil } } func WithFFTBroadBandNoiseFilter(id string, strength float32, rPatch float32, rSearch float32) FilterOption { - return func(filter *Filter) error { + return func(filter Filter) error { // TODO: NEEDS A UPDATOR TO CONTROL NOISE SAMPLING - filter.content += fmt.Sprintf("") + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("")) return nil } } func WithMeanBroadBandNoiseFilter(id string, strength float32, rPatch float32, rSearch float32) FilterOption { - return func(filter *Filter) error { - filter.content += fmt.Sprintf("anlmdn@%s=strength=%.2f:patch=%.2f:research=%.2f", id, strength, rPatch, rSearch) + return func(filter Filter) error { + a, ok := filter.(CanAddToFilterContent) + if !ok { + return ErrorInterfaceMismatch + } + + a.AddToFilterContent(fmt.Sprintf("anlmdn@%s=strength=%.2f:patch=%.2f:research=%.2f", id, strength, rPatch, rSearch)) return nil } } diff --git a/pkg/interfaces.go b/pkg/interfaces.go new file mode 100644 index 0000000..2688b4e --- /dev/null +++ b/pkg/interfaces.go @@ -0,0 +1,163 @@ +package transcode + +import ( + "context" + + "github.com/asticode/go-astiav" + + "github.com/harshabose/tools/buffer/pkg" +) + +type CanSetDemuxerInputOption interface { + SetInputOption(key, value string, flags astiav.DictionaryFlags) error +} + +type CanSetDemuxerInputFormat interface { + SetInputFormat(*astiav.InputFormat) +} + +type CanSetBuffer[T any] interface { + SetBuffer(buffer buffer.BufferWithGenerator[T]) +} + +type CanDescribeFrameRate interface { + FrameRate() astiav.Rational +} + +type CanDescribeTimeBase interface { + TimeBase() astiav.Rational +} + +type CanSetFrameRate interface { + SetFrameRate(CanDescribeFrameRate) +} + +type CanSetTimeBase interface { + SetTimeBase(CanDescribeTimeBase) +} + +type CanDescribeMediaPacket interface { + MediaType() astiav.MediaType + CodecID() astiav.CodecID + GetCodecParameters() *astiav.CodecParameters + CanDescribeFrameRate + CanDescribeTimeBase +} + +type CanProduceMediaPacket interface { + WaitForPacket() chan *astiav.Packet + PutBack(*astiav.Packet) +} + +type CanProduceMediaFrame interface { + WaitForFrame() chan *astiav.Frame + PutBack(*astiav.Frame) +} + +type CanDescribeMediaVideoFrame interface { + CanDescribeFrameRate + CanDescribeTimeBase + Height() int + Width() int + PixelFormat() astiav.PixelFormat + SampleAspectRatio() astiav.Rational + ColorSpace() astiav.ColorSpace + ColorRange() astiav.ColorRange +} + +type CanSetMediaVideoFrame interface { + CanSetFrameRate + CanSetTimeBase + SetHeight(CanDescribeMediaVideoFrame) + SetWidth(CanDescribeMediaVideoFrame) + SetPixelFormat(CanDescribeMediaVideoFrame) + SetSampleAspectRatio(CanDescribeMediaVideoFrame) + SetColorSpace(CanDescribeMediaVideoFrame) + SetColorRange(CanDescribeMediaVideoFrame) +} + +type CanDescribeMediaFrame interface { + MediaType() astiav.MediaType + CanDescribeMediaVideoFrame + CanDescribeMediaAudioFrame +} + +type CanSetMediaAudioFrame interface { + CanSetTimeBase + SetSampleRate(CanDescribeMediaAudioFrame) + SetSampleFormat(CanDescribeMediaAudioFrame) + SetChannelLayout(CanDescribeMediaAudioFrame) +} + +type CanDescribeMediaAudioFrame interface { + CanDescribeTimeBase + SampleRate() int + SampleFormat() astiav.SampleFormat + ChannelLayout() astiav.ChannelLayout +} + +type CanSetMediaPacket interface { + FillContextContent(CanDescribeMediaPacket) error + SetCodec(CanDescribeMediaPacket) error + CanSetFrameRate + CanSetTimeBase +} + +type Demuxer interface { + Ctx() context.Context + Start() + Stop() + CanProduceMediaPacket +} + +type Decoder interface { + Ctx() context.Context + Start() + Stop() + CanProduceMediaFrame +} + +type CanAddToFilterContent interface { + AddToFilterContent(string) +} + +type Filter interface { + Ctx() context.Context + Start() + Stop() + CanProduceMediaFrame +} + +type CanPauseUnPauseEncoder interface { + PauseEncoding() error + UnPauseEncoding() error +} + +type CanGetParameterSets interface { + GetParameterSets() (sps, pps []byte, err error) +} + +type Encoder interface { + Ctx() context.Context + Start() + Stop() + CanProduceMediaPacket +} + +type CanSetEncoderCodecSettings interface { + SetEncoderCodecSettings(codecSettings) error +} + +type CanUpdateBitrate interface { + UpdateBitrate(int64) error +} + +type CanGetCurrentBitrate interface { + GetCurrentBitrate() (int64, error) +} + +type UpdateBitrateCallBack func(bps int64) error + +type CanGetUpdateBitrateCallBack interface { + OnUpdateBitrate() UpdateBitrateCallBack +} diff --git a/pkg/new_encoder.go b/pkg/new_encoder.go deleted file mode 100644 index 9bca6ab..0000000 --- a/pkg/new_encoder.go +++ /dev/null @@ -1,515 +0,0 @@ -package transcode - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "math" - "strconv" - "strings" - "sync" - "time" - - "github.com/asticode/go-astiav" - - "github.com/harshabose/simple_webrtc_comm/transcode/internal" - "github.com/harshabose/tools/buffer/pkg" -) - -type Encoder struct { - buffer buffer.BufferWithGenerator[astiav.Packet] - filter *Filter - codec *astiav.Codec - codecFlags *astiav.Dictionary - copyCodecFlags *astiav.Dictionary - codecSettings codecSettings - bandwidthChan chan int64 - options []EncoderOption - sps, pps []byte - - encoderContext *astiav.CodecContext - fallbackEncoderContext *astiav.CodecContext - - ctx context.Context - mux sync.Mutex -} - -func NewEncoder(ctx context.Context, codecID astiav.CodecID, filter *Filter, options ...EncoderOption) (*Encoder, error) { - encoder := &Encoder{ - filter: filter, - codecFlags: astiav.NewDictionary(), - ctx: ctx, - } - if encoder.codec = astiav.FindEncoder(codecID); encoder.codec == nil { - return nil, ErrorNoCodecFound - } - - encoderContext, err := createNewEncoder(encoder.codec, filter) - if err != nil { - return nil, err - } - encoder.encoderContext = encoderContext - - for _, option := range options { - if err := option(encoder); err != nil { - return nil, err - } - } - - if encoder.codecSettings == nil { - fmt.Println("warn: no encoder settings were provided") - } - - copyDict, err := copyDictionary(encoder.codecFlags) - if err != nil { - return nil, err - } - encoder.copyCodecFlags = copyDict - - if err := openEncoder(encoder.encoderContext, encoder.codec, encoder.codecFlags); err != nil { - return nil, err - } - - if encoder.buffer == nil { - encoder.buffer = buffer.CreateChannelBuffer(ctx, 256, internal.CreatePacketPool()) - } - - encoder.findParameterSets(encoder.encoderContext.ExtraData()) - - return encoder, nil -} - -func (e *Encoder) Start() { - go e.loop() -} - -func (e *Encoder) GetPacket() (*astiav.Packet, error) { - ctx, cancel := context.WithTimeout(e.ctx, time.Second) // TODO: Needs to be based on something - defer cancel() - - return e.buffer.Pop(ctx) -} - -func (e *Encoder) WaitForPacket() chan *astiav.Packet { - return e.buffer.GetChannel() -} - -func (e *Encoder) pushPacket(packet *astiav.Packet) error { - ctx, cancel := context.WithTimeout(e.ctx, time.Second) - defer cancel() - - return e.buffer.Push(ctx, packet) -} - -func (e *Encoder) PutBack(packet *astiav.Packet) { - e.buffer.PutBack(packet) -} - -func (e *Encoder) GetParameterSets() (sps []byte, pps []byte) { - sps = e.sps - pps = e.pps - - return sps, pps -} - -func (e *Encoder) GetTimeBase() astiav.Rational { - e.mux.Lock() - defer e.mux.Unlock() - - if e.encoderContext != nil { - return e.encoderContext.TimeBase() - } - if e.fallbackEncoderContext != nil { - return e.fallbackEncoderContext.TimeBase() - } - - return astiav.Rational{} -} - -func (e *Encoder) GetDuration() time.Duration { - e.mux.Lock() - defer e.mux.Unlock() - - if e.encoderContext != nil { - if e.encoderContext.MediaType() == astiav.MediaTypeAudio { - return time.Duration(float64(time.Second) * float64(e.encoderContext.FrameSize()) / float64(e.encoderContext.SampleRate())) - } - return time.Duration(float64(time.Second) / e.encoderContext.Framerate().Float64()) - - } - - if e.fallbackEncoderContext != nil { - if e.fallbackEncoderContext.MediaType() == astiav.MediaTypeAudio { - return time.Duration(float64(time.Second) * float64(e.fallbackEncoderContext.FrameSize()) / float64(e.fallbackEncoderContext.SampleRate())) - } - return time.Duration(float64(time.Second) / e.fallbackEncoderContext.Framerate().Float64()) - } - - return time.Second / 30 -} - -func (e *Encoder) SetBitrateChannel(channel chan int64) { - e.mux.Lock() - defer e.mux.Unlock() - - e.bandwidthChan = channel -} - -func (e *Encoder) createNewEncoderContext() error { - e.mux.Lock() - - e.fallbackEncoderContext = e.encoderContext - e.encoderContext = nil - copyDict, err := copyDictionary(e.copyCodecFlags) - if err != nil { - e.mux.Unlock() - return err - } - - e.codecFlags.Free() - e.codecFlags = nil - e.codecFlags = copyDict - - e.mux.Unlock() - - encoderContext, err := createNewOpenEncoder(e.codec, e.filter, e.codecFlags) - if err != nil { - e.mux.Lock() - e.encoderContext = e.fallbackEncoderContext - e.fallbackEncoderContext = nil - e.mux.Unlock() - - fmt.Printf("New encoder creation failed, reverted: %v\n", err) - return err - } - - e.mux.Lock() - oldFallback := e.fallbackEncoderContext - e.encoderContext = encoderContext - e.fallbackEncoderContext = nil // Free later - e.mux.Unlock() - - if oldFallback != nil { - oldFallback.Free() - oldFallback = nil - fmt.Printf("๐Ÿงน Cleaned up fallback encoder context\n") - } - - return nil -} - -func (e *Encoder) getCurrentBitrate() (int64, error) { - // Get the x264opts string - entry := e.copyCodecFlags.Get("x264opts", nil, 0) - if entry == nil { - return 0, errors.New("error getting x264opts from the dictionary") // Default value - } - - x264opts := entry.Value() - - // Parse bitrate from "bitrate=2500:vbv-maxrate=2500:..." - parts := strings.Split(x264opts, ":") - for _, part := range parts { - if strings.HasPrefix(part, "bitrate=") { - bitrateStr := strings.TrimPrefix(part, "bitrate=") - bitrate, err := strconv.ParseInt(bitrateStr, 10, 64) - if err != nil { - return 0, err - } - return bitrate, nil - } - } - - return 2500, errors.New("cannot find bitrate in the dictionary") // Default if not found -} - -func (e *Encoder) updateX264OptsWithNewBitrate(newBitrate int64) error { - entry := e.copyCodecFlags.Get("x264opts", nil, 0) - if entry == nil { - return errors.New("x264opts not found") - } - - x264opts := entry.Value() - parts := strings.Split(x264opts, ":") - - // Enforce level 3.1 limits (14,000 kbps max) - maxAllowedBitrate := int64(12000) // Stay below 14,000 limit - - // Clamp the new bitrate - if newBitrate > maxAllowedBitrate { - fmt.Printf("โš ๏ธ Bitrate %d clamped to %d (level 3.1 limit)\n", newBitrate, maxAllowedBitrate) - newBitrate = maxAllowedBitrate - } - - // Conservative VBV calculations within level limits - vbvMaxRate := min(newBitrate+200, maxAllowedBitrate) // Small headroom - vbvBuffer := min(newBitrate/2, 5000) // Cap buffer at 5000kb - - // Ensure minimum values - if vbvBuffer < 200 { - vbvBuffer = 200 - } - - // Ensure maxrate > bitrate (but within limits) - if vbvMaxRate <= newBitrate { - vbvMaxRate = min(newBitrate+100, maxAllowedBitrate) - } - - paramsToUpdate := map[string]string{ - "bitrate": fmt.Sprintf("%d", newBitrate), - "vbv-maxrate": fmt.Sprintf("%d", vbvMaxRate), - "vbv-bufsize": fmt.Sprintf("%d", vbvBuffer), - } - - // Find and replace each parameter - for paramName, paramValue := range paramsToUpdate { - found := false - for i, part := range parts { - if strings.HasPrefix(part, paramName+"=") { - parts[i] = fmt.Sprintf("%s=%s", paramName, paramValue) - found = true - break - } - } - - if !found { - parts = append(parts, fmt.Sprintf("%s=%s", paramName, paramValue)) - } - } - - newX264opts := strings.Join(parts, ":") - fmt.Printf("๐Ÿ”ง Safe update: bitrate=%d, vbv-maxrate=%d, vbv-bufsize=%d\n", - newBitrate, vbvMaxRate, vbvBuffer) - - return e.copyCodecFlags.Set("x264opts", newX264opts, 0) -} - -// updateBitrate updates the bitrate on codecFlags. The bitrate units are kbps (kilobits per second) -func (e *Encoder) updateBitrate(bitrate int64) error { - const maxReasonableBitrate = 50_000 // 50 Mbps max - if bitrate > maxReasonableBitrate { - fmt.Printf("๐Ÿšซ Rejecting unreasonable bitrate: %d kbps (max: %d kbps)\n", - bitrate, maxReasonableBitrate) - return nil - } - - start := time.Now() - - e.mux.Lock() - - current, err := e.getCurrentBitrate() - if err != nil { - e.mux.Unlock() - fmt.Println("error getting current bitrate; err:", err.Error()) - return err - } - - change := math.Abs(float64(current)-float64(bitrate)) / math.Abs(float64(current)) - - if change < 0.1 || change > 2 { - e.mux.Unlock() - fmt.Printf("change not appropriate; current: %d; new: %d; change:%f\n", current, bitrate, change) - return nil - } - - fmt.Println("change approved!!!!!!!!!!!!!!!!!!!!!!!!; change:", change) - - // NOTE: ONLY UPDATE IF CHANGE IS MORE THAN 10% AND LESS THAN 200% - if err := e.updateX264OptsWithNewBitrate(bitrate); err != nil { - e.mux.Unlock() - fmt.Println("error while updating the bitrate; err:", err.Error()) - return err - } - - e.mux.Unlock() - if err := e.createNewEncoderContext(); err != nil { - return err - } - - updated, err := e.getCurrentBitrate() - if err != nil { - return err - } - - duration := time.Since(start) - fmt.Printf("๐Ÿ”„ Bitrate updated: %d โ†’ %d expected change: (%.1f%%) in %v\n", - current, updated, change*100, duration) - - return nil -} - -func (e *Encoder) pickContextAndProcess(frame *astiav.Frame) error { - e.mux.Lock() - defer e.mux.Unlock() - - if e.encoderContext != nil { - if err := e.sendFrameAndPutPackets(e.encoderContext, frame); err != nil { - return err - } - - return nil - } - - if e.fallbackEncoderContext != nil { - if err := e.sendFrameAndPutPackets(e.fallbackEncoderContext, frame); err != nil { - return err - } - - return nil - } - - return errors.New("invalid encoder context state") -} - -func (e *Encoder) sendFrameAndPutPackets(encoderContext *astiav.CodecContext, frame *astiav.Frame) error { - // NOTE: MUX NOT NEEDED AS BUFFER IS NON-MUX IMPLEMENTATION - // TODO: DO I NEED MUX? - // NOTE: IF THE CALLED OF THIS FUNCTION LOCKS, DOES THE LOCK STILL PERSIST HERE? - defer e.filter.PutBack(frame) - - if err := encoderContext.SendFrame(frame); err != nil { - return err - } - - for { - packet := e.buffer.Generate() - if err := encoderContext.ReceivePacket(packet); err != nil { - e.buffer.PutBack(packet) - break - } - - if err := e.pushPacket(packet); err != nil { - e.buffer.PutBack(packet) - continue - } - } - - return nil -} - -func (e *Encoder) loop() { - defer e.Close() - - for { - select { - case <-e.ctx.Done(): - return - case bitrate := <-e.bandwidthChan: - fmt.Println("updated bitrate:", bitrate) - if err := e.updateBitrate(bitrate); err != nil { - fmt.Printf("error while encoding; err: %s\n", err.Error()) - } - case frame := <-e.filter.WaitForFrame(): - if err := e.pickContextAndProcess(frame); err != nil { - if !errors.Is(err, astiav.ErrEagain) { - continue - } - } - } - } -} - -func (e *Encoder) Close() { - e.mux.Lock() - defer e.mux.Unlock() - - if e.encoderContext != nil { - e.encoderContext.Free() - e.encoderContext = nil - } - - if e.fallbackEncoderContext != nil { - e.fallbackEncoderContext.Free() - e.encoderContext = nil - } -} - -func (e *Encoder) findParameterSets(extraData []byte) { - if len(extraData) > 0 { - // Find the first start code (0x00000001) - for i := 0; i < len(extraData)-4; i++ { - if extraData[i] == 0 && extraData[i+1] == 0 && extraData[i+2] == 0 && extraData[i+3] == 1 { - // Skip start code to get the NAL type - nalType := extraData[i+4] & 0x1F - - // Find the next start code or end - nextStart := len(extraData) - for j := i + 4; j < len(extraData)-4; j++ { - if extraData[j] == 0 && extraData[j+1] == 0 && extraData[j+2] == 0 && extraData[j+3] == 1 { - nextStart = j - break - } - } - - if nalType == 7 { // SPS - e.sps = make([]byte, nextStart-i) - copy(e.sps, extraData[i:nextStart]) - } else if nalType == 8 { // PPS - e.pps = make([]byte, len(extraData)-i) - copy(e.pps, extraData[i:]) - } - - i = nextStart - 1 - } - } - fmt.Println("SPS for current encoder: ", e.sps) - fmt.Println("PPS for current encoder: ", e.pps) - - // Convert to base64 - spsBase64 := base64.StdEncoding.EncodeToString(e.sps) - ppsBase64 := base64.StdEncoding.EncodeToString(e.pps) - - fmt.Printf("DefaultSPSBase64 = \"%s\"\n", spsBase64) - fmt.Printf("DefaultPPSBase64 = \"%s\"\n", ppsBase64) - } -} - -func createNewEncoder(codec *astiav.Codec, filter *Filter) (*astiav.CodecContext, error) { - encoderContext := astiav.AllocCodecContext(codec) - if encoderContext == nil { - return nil, ErrorAllocateCodecContext - } - - if filter.sinkContext.MediaType() == astiav.MediaTypeAudio { - withAudioSetEncoderContextParameters(filter, encoderContext) - } - if filter.sinkContext.MediaType() == astiav.MediaTypeVideo { - withVideoSetEncoderContextParameter(filter, encoderContext) - } - - return encoderContext, nil -} - -func createNewOpenEncoder(codec *astiav.Codec, filter *Filter, settings *astiav.Dictionary) (*astiav.CodecContext, error) { - encoderContext, err := createNewEncoder(codec, filter) - if err != nil { - return nil, err - } - - if err := openEncoder(encoderContext, codec, settings); err != nil { - return nil, err - } - - return encoderContext, nil -} - -func openEncoder(encoderContext *astiav.CodecContext, codec *astiav.Codec, settings *astiav.Dictionary) error { - encoderContext.SetFlags(astiav.NewCodecContextFlags(astiav.CodecContextFlagGlobalHeader)) - if err := encoderContext.Open(codec, settings); err != nil { - return err - } - - return nil -} - -func copyDictionary(source *astiav.Dictionary) (*astiav.Dictionary, error) { - copyBytes := source.Pack() - newDict := astiav.NewDictionary() - - if err := newDict.Unpack(copyBytes); err != nil { - return nil, err - } - - return newDict, nil -} diff --git a/pkg/notch_updates.go b/pkg/notch_updates.go index 30f8030..f5e453c 100644 --- a/pkg/notch_updates.go +++ b/pkg/notch_updates.go @@ -32,11 +32,11 @@ package transcode // ) // // type Updator interface { -// Start(*Filter) +// Start(*GeneralFilter) // } // // func WithUpdateFilter(updator Updator) FilterOption { -// return func(filter *Filter) error { +// return func(filter *GeneralFilter) error { // filter.updators = append(filter.updators, updator) // return nil // } @@ -203,7 +203,7 @@ package transcode // } // } // -// func (update *PropNoiseFilterUpdator) loop3(filter *Filter) { +// func (update *PropNoiseFilterUpdator) loop3(filter *GeneralFilter) { // ticker := time.NewTicker(update.interval) // defer ticker.Stop() // @@ -219,13 +219,13 @@ package transcode // } // } // -// func (update *PropNoiseFilterUpdator) Start(filter *Filter) { +// func (update *PropNoiseFilterUpdator) Start(filter *GeneralFilter) { // go update.loop1() // go update.loop2() // go update.loop3(filter) // } // -// func (update *PropNoiseFilterUpdator) update(filter *Filter) error { +// func (update *PropNoiseFilterUpdator) update(filter *GeneralFilter) error { // if filter == nil { // return errors.New("filter is nil") // } diff --git a/pkg/notch_updates_test.go b/pkg/notch_updates_test.go deleted file mode 100644 index c92d579..0000000 --- a/pkg/notch_updates_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package transcode - -import "testing" - -func TestPropellerNoise(t *testing.T) { - -} diff --git a/pkg/transcode_test.go b/pkg/transcode_test.go new file mode 100644 index 0000000..a8248b3 --- /dev/null +++ b/pkg/transcode_test.go @@ -0,0 +1,421 @@ +package transcode + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/asticode/go-astiav" +) + +func TestTranscoderWithAVFoundation(t *testing.T) { + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create demuxer with AVFoundation input format + // Using "0" as input for facetime camera + demuxer, err := CreateGeneralDemuxer(ctx, "0", WithAvFoundationInputFormatOption) + if err != nil { + t.Fatalf("Failed to create demuxer: %v", err) + } + + // Create decoder + decoder, err := CreateGeneralDecoder(ctx, demuxer) + if err != nil { + t.Fatalf("Failed to create decoder: %v", err) + } + + // Create filter with video configuration + filter, err := CreateGeneralFilter(ctx, decoder, VideoFilters, + WithVideoScaleFilterContent(640, 480), + WithVideoPixelFormatFilterContent(astiav.PixelFormatYuv420P), + WithVideoFPSFilterContent(30)) + if err != nil { + t.Fatalf("Failed to create filter: %v", err) + } + + // Create encoder with H.264 codec + encoder, err := CreateGeneralEncoder(ctx, astiav.CodecIDH264, filter, WithX264LowLatencyOptions) + if err != nil { + t.Fatalf("Failed to create encoder: %v", err) + } + + // Create transcoder + transcoder := NewTranscoder(demuxer, decoder, filter, encoder) + + // Start the transcoder + transcoder.Start() + defer time.Sleep(2 * time.Second) + defer transcoder.Stop() + + // Wait for and process some packets to verify it's working + fmt.Println("Transcoder started, waiting for packets...") + + packetCount := 0 + timeout := time.After(5 * time.Second) + + for { + select { + case <-timeout: + // Test passed if we received some packets + if packetCount > 0 { + fmt.Printf("Test passed: received %d packets\n", packetCount) + return + } + t.Fatalf("Timeout reached without receiving any packets") + case packet := <-transcoder.WaitForPacket(): + packetCount++ + fmt.Printf("Received packet %d, size: %d bytes\n", packetCount, packet.Size()) + transcoder.PutBack(packet) + + // Exit after receiving a few packets + if packetCount >= 10 { + fmt.Printf("Test passed: received %d packets\n", packetCount) + return + } + } + } +} + +func TestTranscoderWithEncoderBuilder(t *testing.T) { + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create demuxer with AVFoundation input format + // Using "0" as input for facetime camera + demuxer, err := CreateGeneralDemuxer(ctx, "0", WithAvFoundationInputFormatOption) + if err != nil { + t.Fatalf("Failed to create demuxer: %v", err) + } + + // Create decoder + decoder, err := CreateGeneralDecoder(ctx, demuxer) + if err != nil { + t.Fatalf("Failed to create decoder: %v", err) + } + + // Create filter with video configuration + filter, err := CreateGeneralFilter(ctx, decoder, VideoFilters, + WithVideoScaleFilterContent(640, 480), + WithVideoPixelFormatFilterContent(astiav.PixelFormatYuv420P), + WithVideoFPSFilterContent(30)) + if err != nil { + t.Fatalf("Failed to create filter: %v", err) + } + + // Create encoder with H.264 codec using EncoderBuilder + encoderBuilder := NewEncoderBuilder(astiav.CodecIDH264, &LowLatencyX264Settings, 10, filter) + encoder, err := encoderBuilder.Build(ctx) + if err != nil { + t.Fatalf("Failed to create encoder: %v", err) + } + + // Create transcoder + transcoder := NewTranscoder(demuxer, decoder, filter, encoder) + + // Start the transcoder + transcoder.Start() + defer time.Sleep(2 * time.Second) + defer transcoder.Stop() + + // Wait for and process some packets to verify it's working + fmt.Println("Transcoder with EncoderBuilder started, waiting for packets...") + + packetCount := 0 + timeout := time.After(5 * time.Second) + + for { + select { + case <-timeout: + // Test passed if we received some packets + if packetCount > 0 { + fmt.Printf("Test passed: received %d packets\n", packetCount) + return + } + t.Fatalf("Timeout reached without receiving any packets") + case packet := <-transcoder.WaitForPacket(): + packetCount++ + fmt.Printf("Received packet %d, size: %d bytes\n", packetCount, packet.Size()) + transcoder.PutBack(packet) + + // Exit after receiving a few packets + if packetCount >= 10 { + fmt.Printf("Test passed: received %d packets\n", packetCount) + return + } + } + } +} + +func TestTranscoderWithUpdateEncoder(t *testing.T) { + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create demuxer with AVFoundation input format + // Using "0" as input for facetime camera + demuxer, err := CreateGeneralDemuxer(ctx, "0", WithAvFoundationInputFormatOption) + if err != nil { + t.Fatalf("Failed to create demuxer: %v", err) + } + + // Create decoder + decoder, err := CreateGeneralDecoder(ctx, demuxer) + if err != nil { + t.Fatalf("Failed to create decoder: %v", err) + } + + // Create filter with video configuration + filter, err := CreateGeneralFilter(ctx, decoder, VideoFilters, + WithVideoScaleFilterContent(640, 480), + WithVideoPixelFormatFilterContent(astiav.PixelFormatYuv420P), + WithVideoFPSFilterContent(30)) + if err != nil { + t.Fatalf("Failed to create filter: %v", err) + } + + // Create encoder builder with WebRTCOptimisedX264Settings + encoderBuilder := NewEncoderBuilder(astiav.CodecIDH264, &WebRTCOptimisedX264Settings, 10, filter) + + // Define min and max bitrates for testing (in bits per second) + minBitrate := int64(500_000) // 500 kbps + maxBitrate := int64(1_500_000) // 1.5 Mbps + + // Create UpdateEncoder with configuration + updateConfig := UpdateConfig{ + MinBitrate: minBitrate, + MaxBitrate: maxBitrate, + CutVideoBelowMinBitrate: false, // Don't pause when below min bitrate + } + fmt.Println("Trying to create updateEncoder...") + + updateEncoder, err := NewUpdateEncoder(ctx, updateConfig, encoderBuilder) + if err != nil { + t.Fatalf("Failed to create update encoder: %v", err) + } + + fmt.Println("Created updateEncoder successfully...") + + // Create transcoder + transcoder := NewTranscoder(demuxer, decoder, filter, updateEncoder) + + // Start the transcoder + transcoder.Start() + defer time.Sleep(2 * time.Second) + defer transcoder.Stop() + + // Wait for and process some packets to verify it's working + fmt.Println("Transcoder with UpdateEncoder started, waiting for packets...") + + // Define bitrates to test (in bits per second) + bitrateTests := []struct { + name string + bitrate int64 + }{ + {"Initial", 800_000}, // Initial bitrate (within range) + {"Within range 1", 1_000_000}, // Within range + {"Within range 2", 1_200_000}, // Within range + {"Above max", 2_000_000}, // Above max (should be capped) + {"Below min", 300_000}, // Below min (should be capped, not paused) + } + + // Function to wait for packets after bitrate change + waitForPackets := func(name string, count int) error { + receivedCount := 0 + timeout := time.After(3 * time.Second) + + for receivedCount < count { + select { + case <-timeout: + return fmt.Errorf("timeout waiting for packets after %s bitrate change", name) + case packet := <-transcoder.WaitForPacket(): + receivedCount++ + fmt.Printf("[%s] Received packet %d, size: %d bytes\n", name, receivedCount, packet.Size()) + fmt.Println("Trying to putback packet") + transcoder.PutBack(packet) + } + } + return nil + } + + // Test each bitrate + for _, test := range bitrateTests { + fmt.Printf("Updating bitrate to %d bps (%s)...\n", test.bitrate, test.name) + + err := transcoder.UpdateBitrate(test.bitrate) + if err != nil { + t.Fatalf("Failed to update bitrate to %d bps (%s): %v", test.bitrate, test.name, err) + } + fmt.Printf("Update to %d successfull\n", test.bitrate) + + // Wait for packets after bitrate change + if err := waitForPackets(test.name, 5); err != nil { + t.Fatal(err) + } + + fmt.Printf("Successfully received packets after %s bitrate change\n", test.name) + } + + fmt.Println("Test passed: successfully updated bitrate multiple times") +} + +func TestTranscoderWithUpdateEncoderAndPausing(t *testing.T) { + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Create demuxer with AVFoundation input format + // Using "0" as input for facetime camera + demuxer, err := CreateGeneralDemuxer(ctx, "0", WithAvFoundationInputFormatOption) + if err != nil { + t.Fatalf("Failed to create demuxer: %v", err) + } + + // Create decoder + decoder, err := CreateGeneralDecoder(ctx, demuxer) + if err != nil { + t.Fatalf("Failed to create decoder: %v", err) + } + + // Create filter with video configuration + filter, err := CreateGeneralFilter(ctx, decoder, VideoFilters, + WithVideoScaleFilterContent(640, 480), + WithVideoPixelFormatFilterContent(astiav.PixelFormatYuv420P), + WithVideoFPSFilterContent(30)) + if err != nil { + t.Fatalf("Failed to create filter: %v", err) + } + + // Create encoder builder with WebRTCOptimisedX264Settings + encoderBuilder := NewEncoderBuilder(astiav.CodecIDH264, &WebRTCOptimisedX264Settings, 10, filter) + + // Define min and max bitrates for testing (in bits per second) + minBitrate := int64(500_000) // 500 kbps + maxBitrate := int64(1_500_000) // 1.5 Mbps + + // Create UpdateEncoder with configuration that enables pausing + updateConfig := UpdateConfig{ + MinBitrate: minBitrate, + MaxBitrate: maxBitrate, + CutVideoBelowMinBitrate: true, // Pause when below min bitrate + } + + updateEncoder, err := NewUpdateEncoder(ctx, updateConfig, encoderBuilder) + if err != nil { + t.Fatalf("Failed to create update encoder: %v", err) + } + + // Create transcoder + transcoder := NewTranscoder(demuxer, decoder, filter, updateEncoder) + + // Start the transcoder + transcoder.Start() + defer time.Sleep(2 * time.Second) + defer transcoder.Stop() + + // Wait for and process some packets to verify it's working + fmt.Println("Transcoder with UpdateEncoder and pausing started, waiting for packets...") + + // Function to wait for packets with timeout + waitForPackets := func(name string, count int) error { + receivedCount := 0 + timeout := time.After(3 * time.Second) + + for receivedCount < count { + select { + case <-timeout: + return fmt.Errorf("timeout waiting for packets after %s bitrate change", name) + case packet := <-transcoder.WaitForPacket(): + receivedCount++ + fmt.Printf("[%s] Received packet %d, size: %d bytes\n", name, receivedCount, packet.Size()) + transcoder.PutBack(packet) + } + } + + fmt.Printf("Successfully received %d packets after %s bitrate change\n", count, name) + return nil + } + + // Function to test that WaitForPacket blocks when paused + testPauseBlocking := func(name string) error { + fmt.Printf("[%s] Testing that WaitForPacket blocks when paused...\n", name) + + // Start a goroutine to call WaitForPacket + packetReceived := make(chan bool, 1) + go func() { + select { + case <-transcoder.WaitForPacket(): + packetReceived <- true + } + }() + + // Wait briefly to see if packet is received (it shouldn't be) + select { + case <-packetReceived: + return fmt.Errorf("received packet when encoding should be paused") + case <-time.After(1 * time.Second): + fmt.Printf("[%s] Confirmed: WaitForPacket is blocking (encoder paused)\n", name) + return nil + } + } + + // First test with normal bitrate (should receive packets) + fmt.Println("Testing with normal bitrate (800 kbps)...") + if err := waitForPackets("Normal bitrate", 5); err != nil { + t.Fatal(err) + } + + // Update to below min bitrate with pausing enabled + belowMinBitrate := int64(300_000) // 300 kbps + fmt.Printf("Updating bitrate to %d bps (below min with pausing)...\n", belowMinBitrate) + + // Update bitrate in a goroutine since it might block if the encoder is processing + updateComplete := make(chan error, 1) + go func() { + updateComplete <- transcoder.UpdateBitrate(belowMinBitrate) + }() + + // Wait for update to complete + select { + case err := <-updateComplete: + if err != nil { + t.Fatalf("Failed to update bitrate to %d bps: %v", belowMinBitrate, err) + } + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for bitrate update to complete") + } + + // Test that WaitForPacket blocks when paused + if err := testPauseBlocking("Below min with pausing"); err != nil { + t.Fatal(err) + } + + // Update back to normal bitrate (this should unpause) + normalBitrate := int64(800_000) // 800 kbps + fmt.Printf("Updating bitrate back to %d bps (normal)...\n", normalBitrate) + + // Update bitrate in a goroutine + go func() { + updateComplete <- transcoder.UpdateBitrate(normalBitrate) + }() + + // Wait for update to complete + select { + case err := <-updateComplete: + if err != nil { + t.Fatalf("Failed to update bitrate to %d bps: %v", normalBitrate, err) + } + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for bitrate update to complete") + } + + // Wait for packets after bitrate change (should receive packets again) + if err := waitForPackets("Back to normal", 5); err != nil { + t.Fatal(err) + } + + fmt.Println("Test passed: successfully tested pausing with low bitrate") +} diff --git a/pkg/transcoder.go b/pkg/transcoder.go new file mode 100644 index 0000000..a6951d5 --- /dev/null +++ b/pkg/transcoder.go @@ -0,0 +1,94 @@ +package transcode + +import ( + "github.com/asticode/go-astiav" +) + +type Transcoder struct { + demuxer Demuxer + decoder Decoder + filter Filter + encoder Encoder +} + +func CreateTranscoder(options ...TranscoderOption) (*Transcoder, error) { + t := &Transcoder{} + for _, option := range options { + if err := option(t); err != nil { + return nil, err + } + } + + return t, nil +} + +func NewTranscoder(demuxer Demuxer, decoder Decoder, filter Filter, encoder Encoder) *Transcoder { + return &Transcoder{ + demuxer: demuxer, + decoder: decoder, + filter: filter, + encoder: encoder, + } +} + +func (t *Transcoder) Start() { + t.demuxer.Start() + t.decoder.Start() + t.filter.Start() + t.encoder.Start() +} + +func (t *Transcoder) Stop() { + t.encoder.Stop() + t.filter.Stop() + t.decoder.Stop() + t.demuxer.Stop() +} + +func (t *Transcoder) WaitForPacket() chan *astiav.Packet { + return t.encoder.WaitForPacket() +} + +func (t *Transcoder) PutBack(packet *astiav.Packet) { + t.encoder.PutBack(packet) +} + +func (t *Transcoder) PauseEncoding() error { + p, ok := t.encoder.(CanPauseUnPauseEncoder) + if !ok { + return ErrorInterfaceMismatch + } + + return p.PauseEncoding() +} + +func (t *Transcoder) UnPauseEncoding() error { + p, ok := t.encoder.(CanPauseUnPauseEncoder) + if !ok { + return ErrorInterfaceMismatch + } + + return p.UnPauseEncoding() +} + +func (t *Transcoder) GetParameterSets() (sps, pps []byte, err error) { + p, ok := t.encoder.(CanGetParameterSets) + if !ok { + return nil, nil, ErrorInterfaceMismatch + } + + return p.GetParameterSets() +} + +func (t *Transcoder) UpdateBitrate(bps int64) error { + u, ok := t.encoder.(CanUpdateBitrate) + if !ok { + return ErrorInterfaceMismatch + } + + return u.UpdateBitrate(bps) +} + +func (t *Transcoder) OnUpdateBitrate() UpdateBitrateCallBack { + return t.UpdateBitrate +} diff --git a/pkg/transcoder_options.go b/pkg/transcoder_options.go new file mode 100644 index 0000000..3674713 --- /dev/null +++ b/pkg/transcoder_options.go @@ -0,0 +1,70 @@ +package transcode + +import ( + "context" + + "github.com/asticode/go-astiav" +) + +type TranscoderOption = func(*Transcoder) error + +func WithGeneralDemuxer(ctx context.Context, containerAddress string, options ...DemuxerOption) TranscoderOption { + return func(transcoder *Transcoder) error { + demuxer, err := CreateGeneralDemuxer(ctx, containerAddress, options...) + if err != nil { + return err + } + + transcoder.demuxer = demuxer + return nil + } +} + +func WithGeneralDecoder(ctx context.Context, options ...DecoderOption) TranscoderOption { + return func(transcoder *Transcoder) error { + decoder, err := CreateGeneralDecoder(ctx, transcoder.demuxer, options...) + if err != nil { + return err + } + + transcoder.decoder = decoder + return nil + } +} + +func WithGeneralFilter(ctx context.Context, filterConfig FilterConfig, options ...FilterOption) TranscoderOption { + return func(transcoder *Transcoder) error { + filter, err := CreateGeneralFilter(ctx, transcoder.decoder, filterConfig, options...) + if err != nil { + return err + } + + transcoder.filter = filter + return nil + } +} + +func WithGeneralEncoder(ctx context.Context, codecID astiav.CodecID, options ...EncoderOption) TranscoderOption { + return func(transcoder *Transcoder) error { + encoder, err := CreateGeneralEncoder(ctx, codecID, transcoder.filter, options...) + if err != nil { + return err + } + + transcoder.encoder = encoder + return nil + } +} + +func WithBitrateControlEncoder(ctx context.Context, codecID astiav.CodecID, bitrateControlConfig UpdateConfig, settings codecSettings, bufferSize int) TranscoderOption { + return func(transcoder *Transcoder) error { + builder := NewEncoderBuilder(codecID, settings, bufferSize, transcoder.filter) + updateEncoder, err := NewUpdateEncoder(ctx, bitrateControlConfig, builder) + if err != nil { + return err + } + + transcoder.encoder = updateEncoder + return nil + } +} diff --git a/pkg/update_encoder_wrapper.go b/pkg/update_encoder_wrapper.go new file mode 100644 index 0000000..fa599f2 --- /dev/null +++ b/pkg/update_encoder_wrapper.go @@ -0,0 +1,201 @@ +package transcode + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/asticode/go-astiav" +) + +type UpdateConfig struct { + MaxBitrate, MinBitrate int64 + CutVideoBelowMinBitrate bool +} + +type UpdateEncoder struct { + encoder Encoder + config UpdateConfig + builder *GeneralEncoderBuilder + mux sync.RWMutex + ctx context.Context + + paused atomic.Bool + resume chan struct{} + pauseMux sync.Mutex +} + +func NewUpdateEncoder(ctx context.Context, config UpdateConfig, builder *GeneralEncoderBuilder) (*UpdateEncoder, error) { + updater := &UpdateEncoder{ + config: config, + builder: builder, + resume: make(chan struct{}), + ctx: ctx, + } + + encoder, err := builder.Build(ctx) + if err != nil { + return nil, err + } + + updater.encoder = encoder + + return updater, nil +} + +func (u *UpdateEncoder) Ctx() context.Context { + u.mux.Lock() + defer u.mux.Unlock() + + return u.encoder.Ctx() +} + +func (u *UpdateEncoder) Start() { + u.mux.Lock() + defer u.mux.Unlock() + + u.encoder.Start() +} + +func (u *UpdateEncoder) WaitForPacket() chan *astiav.Packet { + if u.paused.Load() { + <-u.resume + } + + return u.encoder.WaitForPacket() +} + +func (u *UpdateEncoder) PutBack(packet *astiav.Packet) { + u.mux.RLock() + defer u.mux.RUnlock() + + u.encoder.PutBack(packet) +} + +func (u *UpdateEncoder) Stop() { + u.mux.Lock() + defer u.mux.Unlock() + + u.encoder.Stop() +} + +// UpdateBitrate modifies the encoder's target bitrate to the specified value in bits per second. +// Returns an error if the update fails. +func (u *UpdateEncoder) UpdateBitrate(bps int64) error { + if err := u.checkPause(bps); err != nil { + return err + } + + bps = u.cutoff(bps) + + g, ok := u.encoder.(CanGetCurrentBitrate) + if !ok { + return ErrorInterfaceMismatch + } + + current, err := g.GetCurrentBitrate() + if err != nil { + return err + } + fmt.Printf("got bitrate update request (%d -> %d)\n", current, bps) + + _, change := u.calculateBitrateChange(current, bps) + if change < 5 { + return nil + } + + if err := u.builder.UpdateBitrate(bps); err != nil { + return err + } + + newEncoder, err := u.builder.Build(u.ctx) + if err != nil { + return fmt.Errorf("build new encoder: %w", err) + } + + newEncoder.Start() + + u.mux.Lock() + oldEncoder := u.encoder + u.encoder = newEncoder + u.mux.Unlock() + + if oldEncoder != nil { + oldEncoder.Stop() + } + + // Print encoder update notification + fmt.Println() + fmt.Println("โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•—") + fmt.Println("โ•‘ ๐ŸŽฅ ENCODER UPDATED ๐ŸŽฅ โ•‘") + fmt.Printf("โ•‘ New Bitrate: %6d kbps โ•‘\n", bps/1000) + fmt.Println("โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•") + fmt.Println() + + return nil +} + +func (u *UpdateEncoder) cutoff(bps int64) int64 { + if bps > u.config.MaxBitrate { + bps = u.config.MaxBitrate + } + + if bps < u.config.MinBitrate { + bps = u.config.MinBitrate + } + + return bps +} + +func (u *UpdateEncoder) shouldPause(bps int64) bool { + return bps <= u.config.MinBitrate && u.config.CutVideoBelowMinBitrate +} + +func (u *UpdateEncoder) checkPause(bps int64) error { + shouldPause := u.shouldPause(bps) + + if shouldPause { + fmt.Println("pausing video...") + return u.PauseEncoding() + } + return u.UnPauseEncoding() +} + +func (u *UpdateEncoder) PauseEncoding() error { + u.paused.Store(true) + return nil +} + +func (u *UpdateEncoder) UnPauseEncoding() error { + u.pauseMux.Lock() + defer u.pauseMux.Unlock() + + if u.paused.Swap(false) { + close(u.resume) + u.resume = make(chan struct{}) + } + return nil +} + +func (u *UpdateEncoder) GetParameterSets() (sps []byte, pps []byte, err error) { + p, ok := u.encoder.(CanGetParameterSets) + if !ok { + return nil, nil, ErrorInterfaceMismatch + } + + return p.GetParameterSets() +} + +func (u *UpdateEncoder) calculateBitrateChange(currentBps, newBps int64) (absoluteChange int64, percentageChange float64) { + absoluteChange = newBps - currentBps + if absoluteChange < 0 { + absoluteChange = -absoluteChange + } + + if currentBps > 0 { + percentageChange = (float64(absoluteChange) / float64(currentBps)) * 100 + } + + return absoluteChange, percentageChange +} diff --git a/pkg/x264options.go b/pkg/x264options.go new file mode 100644 index 0000000..af9dc15 --- /dev/null +++ b/pkg/x264options.go @@ -0,0 +1,141 @@ +package transcode + +import ( + "reflect" + "strconv" + "strings" +) + +type X264AdvancedOptions struct { + // PRIMARY OPTIONS + Bitrate string `x264-opts:"bitrate"` + VBVMaxBitrate string `x264-opts:"vbv-maxrate"` + VBVBuffer string `x264-opts:"vbv-bufsize"` + RateTolerance string `x264-opts:"ratetol"` + MaxGOP string `x264-opts:"keyint"` + MinGOP string `x264-opts:"min-keyint"` + MaxQP string `x264-opts:"qpmax"` + MinQP string `x264-opts:"qpmin"` + MaxQPStep string `x264-opts:"qpstep"` + IntraRefresh string `x264-opts:"intra-refresh"` + ConstrainedIntra string `x264-opts:"constrained-intra"` + + // SECONDARY OPTIONS; SOME OF THEM ARE ALREADY SET BY PRESET, PROFILE AND TUNE + SceneCut string `x264-opts:"scenecut"` + BFrames string `x264-opts:"bframes"` + BAdapt string `x264-opts:"b-adapt"` + Refs string `x264-opts:"ref"` + RCLookAhead string `x264-opts:"rc-lookahead"` + AQMode string `x264-opts:"aq-mode"` + NalHrd string `x264-opts:"nal-hrd"` +} + +func (o *X264AdvancedOptions) ForEach(f func(key, value string) error) error { + t := reflect.TypeOf(*o) + v := reflect.ValueOf(*o) + + // Build a single x264opts string + var optParts []string + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + tag := field.Tag.Get("x264-opts") + if tag != "" && v.Field(i).String() != "" { + optParts = append(optParts, tag+"="+v.Field(i).String()) + } + } + + // Join all options with colons + if len(optParts) > 0 { + x264optsValue := strings.Join(optParts, ":") + if err := f("x264opts", x264optsValue); err != nil { + return err + } + } + + return nil +} + +func (o *X264AdvancedOptions) UpdateBitrate(bps int64) error { + kbps := bps / 1000 + + // Core bitrate settings (strict CBR) + o.Bitrate = strconv.FormatInt(kbps, 10) + o.VBVMaxBitrate = strconv.FormatInt(kbps, 10) // Same as bitrate for CBR + + // VBV buffer: 0.5 seconds for low latency + // Formula: buffer_kb = (bitrate_kbps * buffer_duration_seconds) + // Minimum of 100 kb might be needed. TODO: do more research + bufferKb := max(kbps/2, 100) // 0.5 seconds = 1/2 second + o.VBVBuffer = strconv.FormatInt(bufferKb, 10) + + return nil +} + +func (o *X264AdvancedOptions) GetCurrentBitrate() (int64, error) { + kbps, err := strconv.ParseInt(o.Bitrate, 10, 64) + if err != nil { + return 0, err + } + return kbps * 1000, nil // Convert kbps to bps +} + +type X264Options struct { + *X264AdvancedOptions + // PRECOMPILED OPTIONS + Profile string `x264:"profile"` + Level string `x264:"level"` + Preset string `x264:"preset"` + Tune string `x264:"tune"` +} + +func (o *X264Options) ForEach(f func(key, value string) error) error { + t := reflect.TypeOf(*o) + v := reflect.ValueOf(*o) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + tag := field.Tag.Get("x264") + if tag != "" { + if err := f(tag, v.Field(i).String()); err != nil { + return err + } + } + } + + return o.X264AdvancedOptions.ForEach(f) +} + +func (o *X264Options) UpdateBitrate(bps int64) error { + return o.X264AdvancedOptions.UpdateBitrate(bps) +} + +// TODO: WARN: MAKING THIS A POINTER VARIABLE WILL MAKE ALL TRACKS WHICH USE THIS SETTINGS TO SHARE BITRATE + +var LowLatencyBitrateControlled = &X264Options{ + Profile: "baseline", + Level: "3.1", + Preset: "ultrafast", + Tune: "zerolatency", + + X264AdvancedOptions: &X264AdvancedOptions{ + Bitrate: "800", // 800kbps + VBVMaxBitrate: "800", + VBVBuffer: "400", + RateTolerance: "1", // 1% rate tolerance + MaxGOP: "25", + MinGOP: "13", + // MaxQP: "80", + // MinQP: "24", + // MaxQPStep: "80", + IntraRefresh: "1", + ConstrainedIntra: "1", + SceneCut: "0", + BFrames: "0", + BAdapt: "0", + Refs: "1", + RCLookAhead: "0", + AQMode: "1", // Not sure; do more research + NalHrd: "cbr", + }, +}