fixing bug for unsubscribing HLS streams - make sure go routine gets canceled

This commit is contained in:
Eric Tang
2017-04-27 10:56:09 -04:00
parent c553fcf3ae
commit f89a49f5b6
4 changed files with 77 additions and 36 deletions
+5 -2
View File
@@ -2,6 +2,7 @@
package stream
import (
"context"
"errors"
"runtime"
"sync"
@@ -170,7 +171,7 @@ func (q *Queue) Put(items ...interface{}) error {
// parameter. If no items are in the queue, this method will pause
// until items are added to the queue.
func (q *Queue) Get(number int64) ([]interface{}, error) {
return q.Poll(number, 0)
return q.Poll(context.Background(), number, 0)
}
// Poll retrieves items from the queue. If there are some items in the queue,
@@ -178,7 +179,7 @@ func (q *Queue) Get(number int64) ([]interface{}, error) {
// items are in the queue, this method will pause until items are added to the
// queue or the provided timeout is reached. A non-positive timeout will block
// until items are added. If a timeout occurs, ErrTimeout is returned.
func (q *Queue) Poll(number int64, timeout time.Duration) ([]interface{}, error) {
func (q *Queue) Poll(ctx context.Context, number int64, timeout time.Duration) ([]interface{}, error) {
if number < 1 {
// thanks again go
return []interface{}{}, nil
@@ -225,6 +226,8 @@ func (q *Queue) Poll(number int64, timeout time.Duration) ([]interface{}, error)
sema.response.Done()
}
return nil, ErrTimeout
case <-ctx.Done():
return nil, ctx.Err()
}
}
+27 -22
View File
@@ -50,8 +50,8 @@ func (b *streamBuffer) push(in interface{}) error {
return nil
}
func (b *streamBuffer) poll(wait time.Duration) (interface{}, error) {
results, err := b.q.Poll(1, wait)
func (b *streamBuffer) poll(ctx context.Context, wait time.Duration) (interface{}, error) {
results, err := b.q.Poll(ctx, 1, wait)
if err != nil {
return nil, err
}
@@ -114,7 +114,7 @@ func (s *VideoStream) ReadRTMPFromStream(ctx context.Context, dst av.MuxCloser)
//TODO: Make sure to listen to ctx.Done()
for {
item, err := s.buffer.poll(s.RTMPTimeout)
item, err := s.buffer.poll(ctx, s.RTMPTimeout)
if err != nil {
return err
}
@@ -220,26 +220,31 @@ func (s *VideoStream) WriteHLSSegmentToStream(seg HLSSegment) error {
//ReadHLSFromStream reads an HLS stream into an HLSBuffer
func (s *VideoStream) ReadHLSFromStream(ctx context.Context, mux HLSMuxer) error {
for {
// glog.Info("HLS Stream Buffer Len: %v\n", s.buffer.len())
item, err := s.buffer.poll(s.HLSTimeout)
if err != nil {
return err
}
ec := make(chan error, 1)
go func() {
ec <- func() error {
for {
// glog.Info("HLS Stream Buffer Len: %v\n", s.buffer.len())
item, err := s.buffer.poll(ctx, s.HLSTimeout)
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
switch item.(type) {
case m3u8.MediaPlaylist:
mux.WritePlaylist(item.(m3u8.MediaPlaylist))
case HLSSegment:
mux.WriteSegment(item.(HLSSegment).Name, item.(HLSSegment).Data)
default:
return ErrBufferItemType
}
}
}()
}()
switch item.(type) {
case m3u8.MediaPlaylist:
mux.WritePlaylist(item.(m3u8.MediaPlaylist))
case HLSSegment:
mux.WriteSegment(item.(HLSSegment).Name, item.(HLSSegment).Data)
default:
return ErrBufferItemType
}
select {
case err := <-ec:
glog.Errorf("Got error reading HLS: %v", err)
return err
}
}
+15 -3
View File
@@ -52,6 +52,13 @@ func (s *StreamSubscriber) UnsubscribeRTMP(muxID string) error {
return nil
}
func (s *StreamSubscriber) HasSubscribers() bool {
rs := len(s.rtmpSubscribers)
hs := len(s.hlsSubscribers)
return rs+hs > 0
}
func (s *StreamSubscriber) StartRTMPWorker(ctx context.Context) error {
// glog.Infof("Starting RTMP worker")
q := pubsub.NewQueue()
@@ -118,13 +125,15 @@ func (s *StreamSubscriber) UnsubscribeHLS(muxID string) error {
func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error {
// fmt.Println("Kicking off HLS worker thread")
b := NewHLSBuffer()
go s.stream.ReadHLSFromStream(ctx, b)
readCtx, readCancel := context.WithCancel(context.Background())
go s.stream.ReadHLSFromStream(readCtx, b)
segments := map[string]bool{}
for {
// glog.Infof("Waiting for pl")
pl, err := b.WaitAndPopPlaylist(ctx)
popPlCtx, _ := context.WithCancel(context.Background())
pl, err := b.WaitAndPopPlaylist(popPlCtx)
if err != nil {
glog.Errorf("Error loading playlist: %v", err)
return err
@@ -149,7 +158,8 @@ func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error {
if segments[segName] {
continue
}
seg, err := b.WaitAndPopSegment(ctx, segName)
popSegCtx, _ := context.WithCancel(context.Background())
seg, err := b.WaitAndPopSegment(popSegCtx, segName)
if err != nil {
glog.Errorf("Error loading seg: %v", err)
}
@@ -163,6 +173,8 @@ func (s *StreamSubscriber) StartHLSWorker(ctx context.Context) error {
select {
case <-ctx.Done():
readCancel()
glog.Errorf("Canceling HLS Worker.")
return ctx.Err()
default:
}
+30 -9
View File
@@ -189,15 +189,6 @@ func TestWriteHLS(t *testing.T) {
}
}
// struct TestHLSBuffer struct{}
// func (b *TestHLSBuffer) WritePlaylist(m3u8.MediaPlaylist) error {
// }
// func (b *TestHLSBuffer) WriteSegment(name string, s []byte) error {
// }
func TestReadHLS(t *testing.T) {
stream := NewVideoStream("test", HLS)
stream.HLSTimeout = time.Millisecond * 100
@@ -227,6 +218,36 @@ func TestReadHLS(t *testing.T) {
}
}
func TestReadHLSCancel(t *testing.T) {
stream := NewVideoStream("test", HLS)
stream.HLSTimeout = time.Millisecond * 100
buffer := NewHLSBuffer()
grBefore := runtime.NumGoroutine()
stream.WriteHLSPlaylistToStream(m3u8.MediaPlaylist{SeqNo: 100})
for i := 0; i < 9; i++ {
stream.WriteHLSSegmentToStream(HLSSegment{Name: "test" + string(i), Data: []byte{0}})
}
ec := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
go func() { ec <- stream.ReadHLSFromStream(ctx, buffer) }()
// time.Sleep(time.Millisecond * 100)
cancel()
err := <-ec
if err != context.Canceled {
t.Errorf("Expecting canceled, got %v", err)
}
time.Sleep(time.Millisecond * 100)
grAfter := runtime.NumGoroutine()
if grBefore != grAfter {
t.Errorf("Should have %v Go routines, but have %v", grBefore, grAfter)
}
}
// type GoodHLSDemux struct{}
// func (d GoodHLSDemux) WaitAndPopPlaylist(ctx context.Context) (m3u8.MediaPlaylist, error) {