diff --git a/concurrency.go b/concurrency.go index 2d6fd87..ac4ff1c 100644 --- a/concurrency.go +++ b/concurrency.go @@ -1,6 +1,10 @@ package lo -import "sync" +import ( + "context" + "sync" + "time" +) type synchronize struct { locker sync.Locker @@ -93,3 +97,59 @@ func Async6[A any, B any, C any, D any, E any, F any](f func() (A, B, C, D, E, F }() return ch } + +// Timeout returns an error if `callback` runs longer than `duration`. +// A return in the callback is equivalent to calling `done()`. +// Warning: you whould use `context.WithTimeout` when available. +func Timeout(duration time.Duration, callback func(done func())) error { + release := make(chan struct{}) // channel will be closed by garbage collector + done := func() { release <- struct{}{} } + + go func() { + callback(done) + done() + }() + + select { + case <-release: + return nil + case <-time.After(duration): + return context.DeadlineExceeded + } +} + +// Deadline returns an error if `callback` runs longer than `duration`. +// A return in the callback is equivalent to calling `done()`. +// Warning: you whould use `context.WithDeadline` when available. +func Deadline(t time.Time, callback func(done func())) error { + release := make(chan struct{}) // channel will be closed by garbage collector + done := func() { release <- struct{}{} } + + go func() { + callback(done) + done() + }() + + select { + case <-release: + return nil + case <-time.After(time.Until(t)): + return context.DeadlineExceeded + } +} + +// Race blocks until a callback is processed. +// `done()` can be used to unlock race before leaving callback function. +func Race(callbacks ...func(done func())) { + release := make(chan struct{}) // channel will be closed by garbage collector + done := func() { release <- struct{}{} } + + for i := range callbacks { + go func(index int) { + callbacks[index](done) + done() + }(i) + } + + <-release +} diff --git a/concurrency_test.go b/concurrency_test.go index ae65efd..5efc1a6 100644 --- a/concurrency_test.go +++ b/concurrency_test.go @@ -1,7 +1,9 @@ package lo import ( + "context" "sync" + "sync/atomic" "testing" "time" @@ -212,3 +214,68 @@ func TestAsyncX(t *testing.T) { } } } + +func TestTimeout(t *testing.T) { + t.Parallel() + testWithTimeout(t, 100*time.Millisecond) + is := assert.New(t) + + err := Timeout(10*time.Millisecond, func(done func()) { + done() + }) + is.Nil(err) + + err = Timeout(10*time.Millisecond, func(done func()) { + time.Sleep(20 * time.Millisecond) + done() + }) + is.Error(err) + is.Equal(err, context.DeadlineExceeded) +} + +func TestDeadline(t *testing.T) { + t.Parallel() + testWithTimeout(t, 100*time.Millisecond) + is := assert.New(t) + + err := Deadline(time.Now().Add(10*time.Millisecond), func(done func()) { + done() + }) + is.Nil(err) + + err = Deadline(time.Now().Add(10*time.Millisecond), func(done func()) { + time.Sleep(20 * time.Millisecond) + done() + }) + is.Error(err) + is.Equal(err, context.DeadlineExceeded) +} + +func TestRace(t *testing.T) { + t.Parallel() + testWithTimeout(t, 100*time.Millisecond) + is := assert.New(t) + + var wonRace int32 + + func1 := func(done func()) { + time.Sleep(5 * time.Millisecond) + atomic.CompareAndSwapInt32(&wonRace, 0, 1) + done() + } + + func2 := func(done func()) { + time.Sleep(30 * time.Millisecond) + atomic.CompareAndSwapInt32(&wonRace, 0, 2) + done() + } + + func3 := func(done func()) { + time.Sleep(50 * time.Millisecond) + atomic.CompareAndSwapInt32(&wonRace, 0, 3) + done() + } + + Race(func1, func2, func3) + is.EqualValues(1, atomic.LoadInt32(&wonRace)) +}