diff --git a/pkg/registry/core/service/portallocator/allocator.go b/pkg/registry/core/service/portallocator/allocator.go index 8e3d42be08..e088bc91a5 100644 --- a/pkg/registry/core/service/portallocator/allocator.go +++ b/pkg/registry/core/service/portallocator/allocator.go @@ -33,6 +33,7 @@ type Interface interface { Allocate(int) error AllocateNext() (int, error) Release(int) error + ForEach(func(int)) } var ( @@ -117,6 +118,13 @@ func (r *PortAllocator) AllocateNext() (int, error) { return r.portRange.Base + offset, nil } +// ForEach calls the provided function for each allocated port. +func (r *PortAllocator) ForEach(fn func(int)) { + r.alloc.ForEach(func(offset int) { + fn(r.portRange.Base + offset) + }) +} + // Release releases the port back to the pool. Releasing an // unallocated port or a port out of the range is a no-op and // returns no error. diff --git a/pkg/registry/core/service/portallocator/allocator_test.go b/pkg/registry/core/service/portallocator/allocator_test.go index 7386172b5e..8250112042 100644 --- a/pkg/registry/core/service/portallocator/allocator_test.go +++ b/pkg/registry/core/service/portallocator/allocator_test.go @@ -103,6 +103,44 @@ func TestAllocate(t *testing.T) { } } +func TestForEach(t *testing.T) { + pr, err := net.ParsePortRange("10000-10200") + if err != nil { + t.Fatal(err) + } + + testCases := []sets.Int{ + sets.NewInt(), + sets.NewInt(10000), + sets.NewInt(10000, 10200), + sets.NewInt(10000, 10099, 10200), + } + + for i, tc := range testCases { + r := NewPortAllocator(*pr) + + for port := range tc { + if err := r.Allocate(port); err != nil { + t.Errorf("[%d] error allocating port %v: %v", i, port, err) + } + if !r.Has(port) { + t.Errorf("[%d] expected port %v allocated", i, port) + } + } + + calls := sets.NewInt() + r.ForEach(func(port int) { + calls.Insert(port) + }) + if len(calls) != len(tc) { + t.Errorf("[%d] expected %d calls, got %d", i, len(tc), len(calls)) + } + if !calls.Equal(tc) { + t.Errorf("[%d] expected calls to equal testcase: %v vs %v", i, calls.List(), tc.List()) + } + } +} + func TestSnapshot(t *testing.T) { pr, err := net.ParsePortRange("10000-10200") if err != nil {