diff --git a/mediadevices.go b/mediadevices.go index 2969bb8..b4a9622 100644 --- a/mediadevices.go +++ b/mediadevices.go @@ -207,13 +207,11 @@ func selectBestDriver(filter driver.FilterFn, constraints MediaTrackConstraints) } func (m *mediaDevices) selectAudio(constraints MediaTrackConstraints) (Tracker, error) { - filter := driver.FilterAudioRecorder() + typeFilter := driver.FilterAudioRecorder() + filter := typeFilter if constraints.DeviceID != "" { - typeFilter := driver.FilterAudioRecorder() idFilter := driver.FilterID(constraints.DeviceID) - filter = func(d driver.Driver) bool { - return typeFilter(d) && idFilter(d) - } + filter = driver.FilterAnd(typeFilter, idFilter) } d, c, err := selectBestDriver(filter, constraints) @@ -225,15 +223,11 @@ func (m *mediaDevices) selectAudio(constraints MediaTrackConstraints) (Tracker, } func (m *mediaDevices) selectVideo(constraints MediaTrackConstraints) (Tracker, error) { typeFilter := driver.FilterVideoRecorder() - screenFilter := driver.FilterDeviceType(driver.Screen) - filter := func(d driver.Driver) bool { - return typeFilter(d) && !screenFilter(d) - } + notScreenFilter := driver.FilterNot(driver.FilterDeviceType(driver.Screen)) + filter := driver.FilterAnd(typeFilter, notScreenFilter) if constraints.DeviceID != "" { idFilter := driver.FilterID(constraints.DeviceID) - filter = func(d driver.Driver) bool { - return typeFilter(d) && !screenFilter(d) && idFilter(d) - } + filter = driver.FilterAnd(typeFilter, notScreenFilter, idFilter) } d, c, err := selectBestDriver(filter, constraints) @@ -247,14 +241,10 @@ func (m *mediaDevices) selectVideo(constraints MediaTrackConstraints) (Tracker, func (m *mediaDevices) selectScreen(constraints MediaTrackConstraints) (Tracker, error) { typeFilter := driver.FilterVideoRecorder() screenFilter := driver.FilterDeviceType(driver.Screen) - filter := func(d driver.Driver) bool { - return typeFilter(d) && screenFilter(d) - } + filter := driver.FilterAnd(typeFilter, screenFilter) if constraints.DeviceID != "" { idFilter := driver.FilterID(constraints.DeviceID) - filter = func(d driver.Driver) bool { - return typeFilter(d) && screenFilter(d) && idFilter(d) - } + filter = driver.FilterAnd(typeFilter, screenFilter, idFilter) } d, c, err := selectBestDriver(filter, constraints) diff --git a/pkg/driver/manager.go b/pkg/driver/manager.go index dd3f5cb..a47746e 100644 --- a/pkg/driver/manager.go +++ b/pkg/driver/manager.go @@ -34,6 +34,25 @@ func FilterDeviceType(t DeviceType) FilterFn { } } +// FilterAnd returns a filter function to take logical conjunction of given filters. +func FilterAnd(filters ...FilterFn) FilterFn { + return func(d Driver) bool { + for _, f := range filters { + if !f(d) { + return false + } + } + return true + } +} + +// FilterNot returns a filter function to take logical inverse of the given filter. +func FilterNot(filter FilterFn) FilterFn { + return func(d Driver) bool { + return !filter(d) + } +} + // Manager is a singleton to manage multiple drivers and their states type Manager struct { drivers map[string]Driver diff --git a/pkg/driver/manager_test.go b/pkg/driver/manager_test.go new file mode 100644 index 0000000..d528d51 --- /dev/null +++ b/pkg/driver/manager_test.go @@ -0,0 +1,42 @@ +package driver + +import ( + "testing" +) + +func filterTrue(d Driver) bool { + return true +} +func filterFalse(d Driver) bool { + return false +} + +func TestFilterNot(t *testing.T) { + if FilterNot(filterTrue)(nil) != false { + t.Error("FilterNot(filterTrue)() must be false") + } + if FilterNot(filterFalse)(nil) != true { + t.Error("FilterNot(filterFalse)() must be true") + } +} + +func TestFilterAnd(t *testing.T) { + if FilterAnd(filterTrue, filterTrue)(nil) != true { + t.Error("FilterAnd(filterTrue, filterTrue)() must be true") + } + if FilterAnd(filterTrue, filterFalse)(nil) != false { + t.Error("FilterAnd(filterTrue, filterFalse)() must be false") + } + if FilterAnd(filterFalse, filterTrue)(nil) != false { + t.Error("FilterAnd(filterFalse, filterTrue)() must be false") + } + if FilterAnd(filterFalse, filterFalse)(nil) != false { + t.Error("FilterAnd(filterFalse, filterFalse)() must be false") + } + if FilterAnd(filterFalse, filterTrue, filterTrue)(nil) != false { + t.Error("FilterAnd(filterFalse, filterTrue, filterTrue)() must be false") + } + if FilterAnd(filterTrue, filterTrue, filterTrue)(nil) != true { + t.Error("FilterAnd(filterTrue, filterTrue, filterTrue)() must be true") + } +}