feat: adding GroupByErr+CountByErr helper

This commit is contained in:
Samuel Berthe
2026-02-28 17:34:46 +01:00
parent 153f867680
commit 75474ff444
7 changed files with 363 additions and 1 deletions
+22
View File
@@ -697,6 +697,17 @@ groups := lo.GroupBy([]int{0, 1, 2, 3, 4, 5}, func(i int) int {
// map[int][]int{0: []int{0, 3}, 1: []int{1, 4}, 2: []int{2, 5}}
```
```go
// Use GroupByErr when the iteratee function can return an error
result, err := lo.GroupByErr([]int{0, 1, 2, 3, 4, 5}, func(i int) (int, error) {
if i == 3 {
return 0, fmt.Errorf("number 3 is not allowed")
}
return i % 3, nil
})
// map[int][]int(nil), error("number 3 is not allowed")
```
[[play](https://go.dev/play/p/XnQBd_v6brd)]
Parallel processing: like `lo.GroupBy()`, but callback is called in goroutine.
@@ -1184,6 +1195,17 @@ count := lo.CountBy([]int{1, 5, 1}, func(i int) bool {
// 2
```
```go
// Use CountByErr when the predicate can return an error
count, err := lo.CountByErr([]int{1, 5, 1}, func(i int) (bool, error) {
if i == 5 {
return false, fmt.Errorf("5 not allowed")
}
return i < 4, nil
})
// 0, error("5 not allowed")
```
[[play](https://go.dev/play/p/ByQbNYQQi4X)]
### CountValues
+2 -1
View File
@@ -1,7 +1,7 @@
---
name: CountBy
slug: countby
sourceRef: slice.go#L596
sourceRef: slice.go#L849
category: core
subCategory: slice
playUrl: https://go.dev/play/p/ByQbNYQQi4X
@@ -13,6 +13,7 @@ similarHelpers:
- core#slice#some
- core#slice#filter
- core#slice#find
- core#slice#countbyerr
position: 0
signatures:
- "func CountBy[T any](collection []T, predicate func(item T) bool) int"
+36
View File
@@ -0,0 +1,36 @@
---
name: CountByErr
slug: countbyerr
sourceRef: slice.go#L863
category: core
subCategory: slice
signatures:
- "func CountByErr[T any](collection []T, predicate func(item T) (bool, error)) (int, error)"
variantHelpers:
- core#slice#countbyerr
similarHelpers:
- core#slice#countby
- core#slice#count
- core#slice#everybyerr
- core#slice#somebyerr
position: 5
---
Counts the number of elements for which the predicate is true. Returns an error if the predicate function fails, stopping iteration immediately.
```go
count, err := lo.CountByErr([]int{1, 5, 1}, func(i int) (bool, error) {
if i == 5 {
return false, fmt.Errorf("5 not allowed")
}
return i < 4, nil
})
// 0, error("5 not allowed")
```
```go
count, err := lo.CountByErr([]int{1, 5, 1}, func(i int) (bool, error) {
return i < 4, nil
})
// 2, nil
```
+1
View File
@@ -8,6 +8,7 @@ playUrl: https://go.dev/play/p/XnQBd_v6brd
variantHelpers:
- core#slice#groupby
similarHelpers:
- core#slice#groupbyerr
- core#slice#groupbymap
- core#slice#partitionby
- core#slice#keyby
+39
View File
@@ -0,0 +1,39 @@
---
name: GroupByErr
slug: groupbyerr
sourceRef: slice.go#L279
category: core
subCategory: slice
signatures:
- "func GroupByErr[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) (U, error)) (map[U]Slice, error)"
variantHelpers:
- core#slice#groupbyerr
similarHelpers:
- core#slice#groupby
- core#slice#groupbymap
- core#slice#partitionby
- core#slice#keyby
- parallel#slice#groupby
position: 121
---
Groups elements by a key computed from each element using an iteratee that can return an error. Stops iteration immediately when an error is encountered. The result is a map keyed by the group key with slices of original elements.
```go
// Error case - stops on first error
result, err := lo.GroupByErr([]int{0, 1, 2, 3, 4, 5}, func(i int) (int, error) {
if i == 3 {
return 0, fmt.Errorf("number 3 is not allowed")
}
return i % 3, nil
})
// map[int][]int(nil), error("number 3 is not allowed")
```
```go
// Success case
result, err := lo.GroupByErr([]int{0, 1, 2, 3, 4, 5}, func(i int) (int, error) {
return i % 3, nil
})
// map[int][]int{0: {0, 3}, 1: {1, 4}, 2: {2, 5}}, nil
```
+35
View File
@@ -276,6 +276,23 @@ func GroupBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(it
return result
}
// GroupByErr returns an object composed of keys generated from the results of running each element of collection through iteratee.
// It returns the first error returned by the iteratee function.
func GroupByErr[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) (U, error)) (map[U]Slice, error) {
result := map[U]Slice{}
for i := range collection {
key, err := iteratee(collection[i])
if err != nil {
return nil, err
}
result[key] = append(result[key], collection[i])
}
return result, nil
}
// GroupByMap returns an object composed of keys generated from the results of running each element of collection through transform.
// Play: https://go.dev/play/p/iMeruQ3_W80
func GroupByMap[T any, K comparable, V any](collection []T, transform func(item T) (K, V)) map[K][]V {
@@ -841,6 +858,24 @@ func CountBy[T any](collection []T, predicate func(item T) bool) int {
return count
}
// CountByErr counts the number of elements in the collection for which predicate is true.
// It returns the first error returned by the predicate.
func CountByErr[T any](collection []T, predicate func(item T) (bool, error)) (int, error) {
var count int
for i := range collection {
ok, err := predicate(collection[i])
if err != nil {
return 0, err
}
if ok {
count++
}
}
return count, nil
}
// CountValues counts the number of each element in the collection.
// Play: https://go.dev/play/p/-p-PyLT4dfy
func CountValues[T comparable](collection []T) map[T]int {
+228
View File
@@ -847,6 +847,128 @@ func TestGroupBy(t *testing.T) {
is.IsType(nonempty[42], allStrings, "type preserved")
}
func TestGroupByErr(t *testing.T) {
t.Parallel()
is := assert.New(t)
tests := []struct {
name string
input []int
iteratee func(item int) (int, error)
wantResult map[int][]int
wantErr bool
errMsg string
expectedCallbackCount int
}{
{
name: "successful grouping",
input: []int{0, 1, 2, 3, 4, 5},
iteratee: func(i int) (int, error) {
return i % 3, nil
},
wantResult: map[int][]int{
0: {0, 3},
1: {1, 4},
2: {2, 5},
},
wantErr: false,
expectedCallbackCount: 6,
},
{
name: "error at fourth element stops iteration",
input: []int{0, 1, 2, 3, 4, 5},
iteratee: func(i int) (int, error) {
if i == 3 {
return 0, fmt.Errorf("number 3 is not allowed")
}
return i % 3, nil
},
wantResult: nil,
wantErr: true,
errMsg: "number 3 is not allowed",
expectedCallbackCount: 4,
},
{
name: "error at first element stops iteration immediately",
input: []int{0, 1, 2, 3, 4, 5},
iteratee: func(i int) (int, error) {
if i == 0 {
return 0, fmt.Errorf("number 0 is not allowed")
}
return i % 3, nil
},
wantResult: nil,
wantErr: true,
errMsg: "number 0 is not allowed",
expectedCallbackCount: 1,
},
{
name: "error at last element",
input: []int{0, 1, 2, 3, 4, 5},
iteratee: func(i int) (int, error) {
if i == 5 {
return 0, fmt.Errorf("number 5 is not allowed")
}
return i % 3, nil
},
wantResult: nil,
wantErr: true,
errMsg: "number 5 is not allowed",
expectedCallbackCount: 6,
},
{
name: "empty input slice",
input: []int{},
iteratee: func(i int) (int, error) {
return i % 3, nil
},
wantResult: map[int][]int{},
wantErr: false,
expectedCallbackCount: 0,
},
{
name: "all elements in same group",
input: []int{3, 6, 9, 12},
iteratee: func(i int) (int, error) {
return 0, nil
},
wantResult: map[int][]int{
0: {3, 6, 9, 12},
},
wantErr: false,
expectedCallbackCount: 4,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Track callback count to test early return
callbackCount := 0
wrappedIteratee := func(item int) (int, error) {
callbackCount++
return tt.iteratee(item)
}
result, err := GroupByErr(tt.input, wrappedIteratee)
if tt.wantErr {
is.Error(err)
is.Equal(tt.errMsg, err.Error())
is.Nil(result)
} else {
is.NoError(err)
is.Equal(tt.wantResult, result)
}
// Verify callback count matches expected
is.Equal(tt.expectedCallbackCount, callbackCount, "callback count should match expected")
})
}
}
func TestGroupByMap(t *testing.T) {
t.Parallel()
is := assert.New(t)
@@ -1743,6 +1865,112 @@ func TestCountBy(t *testing.T) {
is.Zero(count3)
}
func TestCountByErr(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []int
predicate func(int) (bool, error)
want int
wantErr string
wantCallCount int
}{
{
name: "count elements less than 2",
input: []int{1, 2, 1},
predicate: func(i int) (bool, error) {
return i < 2, nil
},
want: 2,
wantErr: "",
wantCallCount: 3,
},
{
name: "count elements greater than 2",
input: []int{1, 2, 1},
predicate: func(i int) (bool, error) {
return i > 2, nil
},
want: 0,
wantErr: "",
wantCallCount: 3,
},
{
name: "empty slice",
input: []int{},
predicate: func(i int) (bool, error) {
return i <= 2, nil
},
want: 0,
wantErr: "",
wantCallCount: 0,
},
{
name: "error on third element",
input: []int{1, 2, 3, 4, 5},
predicate: func(i int) (bool, error) {
if i == 3 {
return false, fmt.Errorf("error at %d", i)
}
return i < 3, nil
},
want: 0,
wantErr: "error at 3",
wantCallCount: 3, // stops early at error
},
{
name: "error on first element",
input: []int{1, 2, 3},
predicate: func(i int) (bool, error) {
return false, fmt.Errorf("first element error")
},
want: 0,
wantErr: "first element error",
wantCallCount: 1,
},
{
name: "all match",
input: []int{1, 2, 3},
predicate: func(i int) (bool, error) {
return i > 0, nil
},
want: 3,
wantErr: "",
wantCallCount: 3,
},
}
for _, tt := range tests {
tt := tt // capture range variable
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
is := assert.New(t)
callCount := 0
wrappedPredicate := func(i int) (bool, error) {
callCount++
return tt.predicate(i)
}
got, err := CountByErr(tt.input, wrappedPredicate)
if tt.wantErr != "" {
is.Error(err)
is.Equal(tt.wantErr, err.Error())
is.Equal(tt.want, got)
if tt.wantCallCount > 0 {
is.Equal(tt.wantCallCount, callCount, "should stop early on error")
}
} else {
is.NoError(err)
is.Equal(tt.want, got)
is.Equal(tt.wantCallCount, callCount)
}
})
}
}
func TestCountValues(t *testing.T) {
t.Parallel()
is := assert.New(t)