diff --git a/stream/queue.go b/stream/queue.go index 7f189c5dd1..79e046a9cb 100644 --- a/stream/queue.go +++ b/stream/queue.go @@ -2,6 +2,7 @@ package stream import ( + "context" "errors" "runtime" "sync" @@ -170,7 +171,7 @@ func (q *Queue) Put(items ...interface{}) error { // parameter. If no items are in the queue, this method will pause // until items are added to the queue. func (q *Queue) Get(number int64) ([]interface{}, error) { - return q.Poll(number, 0) + return q.Poll(context.Background(), number, 0) } // Poll retrieves items from the queue. If there are some items in the queue, @@ -178,7 +179,7 @@ func (q *Queue) Get(number int64) ([]interface{}, error) { // items are in the queue, this method will pause until items are added to the // queue or the provided timeout is reached. A non-positive timeout will block // until items are added. If a timeout occurs, ErrTimeout is returned. -func (q *Queue) Poll(number int64, timeout time.Duration) ([]interface{}, error) { +func (q *Queue) Poll(ctx context.Context, number int64, timeout time.Duration) ([]interface{}, error) { if number < 1 { // thanks again go return []interface{}{}, nil @@ -225,6 +226,8 @@ func (q *Queue) Poll(number int64, timeout time.Duration) ([]interface{}, error) sema.response.Done() } return nil, ErrTimeout + case <-ctx.Done(): + return nil, ctx.Err() } } diff --git a/stream/stream.go b/stream/stream.go index b576dfeb9d..89fefee770 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -50,8 +50,8 @@ func (b *streamBuffer) push(in interface{}) error { return nil } -func (b *streamBuffer) poll(wait time.Duration) (interface{}, error) { - results, err := b.q.Poll(1, wait) +func (b *streamBuffer) poll(ctx context.Context, wait time.Duration) (interface{}, error) { + results, err := b.q.Poll(ctx, 1, wait) if err != nil { return nil, err } @@ -114,7 +114,7 @@ func (s *VideoStream) ReadRTMPFromStream(ctx context.Context, dst av.MuxCloser) //TODO: Make sure to listen to ctx.Done() for { - item, err := s.buffer.poll(s.RTMPTimeout) + item, err := s.buffer.poll(ctx, s.RTMPTimeout) if err != nil { return err } @@ -220,26 +220,31 @@ func (s *VideoStream) WriteHLSSegmentToStream(seg HLSSegment) error { //ReadHLSFromStream reads an HLS stream into an HLSBuffer func (s *VideoStream) ReadHLSFromStream(ctx context.Context, mux HLSMuxer) error { - for { - // glog.Info("HLS Stream Buffer Len: %v\n", s.buffer.len()) - item, err := s.buffer.poll(s.HLSTimeout) - if err != nil { - return err - } + ec := make(chan error, 1) + go func() { + ec <- func() error { + for { + // glog.Info("HLS Stream Buffer Len: %v\n", s.buffer.len()) + item, err := s.buffer.poll(ctx, s.HLSTimeout) + if err != nil { + return err + } - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + switch item.(type) { + case m3u8.MediaPlaylist: + mux.WritePlaylist(item.(m3u8.MediaPlaylist)) + case HLSSegment: + mux.WriteSegment(item.(HLSSegment).Name, item.(HLSSegment).Data) + default: + return ErrBufferItemType + } + } + }() + }() - switch item.(type) { - case m3u8.MediaPlaylist: - mux.WritePlaylist(item.(m3u8.MediaPlaylist)) - case HLSSegment: - mux.WriteSegment(item.(HLSSegment).Name, item.(HLSSegment).Data) - default: - return ErrBufferItemType - } + select { + case err := <-ec: + glog.Errorf("Got error reading HLS: %v", err) + return err } } diff --git a/stream/stream_subscriber.go b/stream/stream_subscriber.go index 14a61ab89f..7c069e44f4 100644 --- a/stream/stream_subscriber.go +++ b/stream/stream_subscriber.go @@ -52,6 +52,13 @@ func (s *StreamSubscriber) UnsubscribeRTMP(muxID string) error { return nil } +func (s *StreamSubscriber) HasSubscribers() bool { + rs := len(s.rtmpSubscribers) + hs := len(s.hlsSubscribers) + + return rs+hs > 0 +} + func (s *StreamSubscriber) StartRTMPWorker(ctx context.Context) error { // glog.Infof("Starting RTMP worker") q := pubsub.NewQueue() @@ -118,13 +125,15 @@ func (s *StreamSubscriber) UnsubscribeHLS(muxID string) error { func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error { // fmt.Println("Kicking off HLS worker thread") b := NewHLSBuffer() - go s.stream.ReadHLSFromStream(ctx, b) + readCtx, readCancel := context.WithCancel(context.Background()) + go s.stream.ReadHLSFromStream(readCtx, b) segments := map[string]bool{} for { // glog.Infof("Waiting for pl") - pl, err := b.WaitAndPopPlaylist(ctx) + popPlCtx, _ := context.WithCancel(context.Background()) + pl, err := b.WaitAndPopPlaylist(popPlCtx) if err != nil { glog.Errorf("Error loading playlist: %v", err) return err @@ -149,7 +158,8 @@ func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error { if segments[segName] { continue } - seg, err := b.WaitAndPopSegment(ctx, segName) + popSegCtx, _ := context.WithCancel(context.Background()) + seg, err := b.WaitAndPopSegment(popSegCtx, segName) if err != nil { glog.Errorf("Error loading seg: %v", err) } @@ -163,6 +173,8 @@ func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error { select { case <-ctx.Done(): + readCancel() + glog.Errorf("Canceling HLS Worker.") return ctx.Err() default: } diff --git a/stream/stream_test.go b/stream/stream_test.go index d213854325..c878da609a 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -189,15 +189,6 @@ func TestWriteHLS(t *testing.T) { } } -// struct TestHLSBuffer struct{} -// func (b *TestHLSBuffer) WritePlaylist(m3u8.MediaPlaylist) error { - -// } - -// func (b *TestHLSBuffer) WriteSegment(name string, s []byte) error { - -// } - func TestReadHLS(t *testing.T) { stream := NewVideoStream("test", HLS) stream.HLSTimeout = time.Millisecond * 100 @@ -227,6 +218,36 @@ func TestReadHLS(t *testing.T) { } } +func TestReadHLSCancel(t *testing.T) { + stream := NewVideoStream("test", HLS) + stream.HLSTimeout = time.Millisecond * 100 + buffer := NewHLSBuffer() + grBefore := runtime.NumGoroutine() + stream.WriteHLSPlaylistToStream(m3u8.MediaPlaylist{SeqNo: 100}) + for i := 0; i < 9; i++ { + stream.WriteHLSSegmentToStream(HLSSegment{Name: "test" + string(i), Data: []byte{0}}) + } + + ec := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + go func() { ec <- stream.ReadHLSFromStream(ctx, buffer) }() + + // time.Sleep(time.Millisecond * 100) + cancel() + + err := <-ec + + if err != context.Canceled { + t.Errorf("Expecting canceled, got %v", err) + } + + time.Sleep(time.Millisecond * 100) + grAfter := runtime.NumGoroutine() + if grBefore != grAfter { + t.Errorf("Should have %v Go routines, but have %v", grBefore, grAfter) + } +} + // type GoodHLSDemux struct{} // func (d GoodHLSDemux) WaitAndPopPlaylist(ctx context.Context) (m3u8.MediaPlaylist, error) {