diff --git a/agent/agent_test.go b/agent/agent_test.go index 6b3ea8cfbd..60c80b8833 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -295,10 +295,6 @@ func TestAgent_HTTPMaxHeaderBytes(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ports, err := freeport.Take(1) - require.NoError(t, err) - t.Cleanup(func() { freeport.Return(ports) }) - caConfig := tlsutil.Config{} tlsConf, err := tlsutil.NewConfigurator(caConfig, hclog.New(nil)) require.NoError(t, err) @@ -312,7 +308,7 @@ func TestAgent_HTTPMaxHeaderBytes(t *testing.T) { }, RuntimeConfig: &config.RuntimeConfig{ HTTPAddrs: []net.Addr{ - &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: freeport.Port(t)}, }, HTTPMaxHeaderBytes: tt.maxHeaderBytes, }, @@ -5281,10 +5277,7 @@ func TestAgent_ListenHTTP_MultipleAddresses(t *testing.T) { t.Skip("too slow for testing.Short") } - ports, err := freeport.Take(2) - require.NoError(t, err) - t.Cleanup(func() { freeport.Return(ports) }) - + ports := freeport.GetN(t, 2) caConfig := tlsutil.Config{} tlsConf, err := tlsutil.NewConfigurator(caConfig, hclog.New(nil)) require.NoError(t, err) diff --git a/sdk/freeport/freeport.go b/sdk/freeport/freeport.go index eab12a87f7..ffdfdb3948 100644 --- a/sdk/freeport/freeport.go +++ b/sdk/freeport/freeport.go @@ -11,8 +11,6 @@ import ( "runtime" "sync" "time" - - "github.com/mitchellh/go-testing-interface" ) const ( @@ -251,6 +249,8 @@ func alloc() (int, net.Listener) { } // MustTake is the same as Take except it panics on error. +// +// Deprecated: Use GetN or Port instead. func MustTake(n int) (ports []int) { ports, err := Take(n) if err != nil { @@ -263,6 +263,8 @@ func MustTake(n int) (ports []int) { // to call this method concurrently. Ports have been tested to be available on // 127.0.0.1 TCP but there is no guarantee that they will remain free in the // future. +// +// Most callers should prefer GetN or Port. func Take(n int) (ports []int, err error) { if n <= 0 { return nil, fmt.Errorf("freeport: cannot take %d ports", n) @@ -381,11 +383,39 @@ func logf(severity string, format string, a ...interface{}) { fmt.Fprintf(os.Stderr, "["+severity+"] freeport: "+format+"\n", a...) } +type TestingT interface { + Helper() + Fatalf(format string, args ...interface{}) + Cleanup(func()) +} + +// GetN returns n free ports from the allocated port block, and returns the +// ports to the pool when the test ends. See Take for more details. +func GetN(t TestingT, n int) []int { + t.Helper() + ports, err := Take(n) + if err != nil { + t.Fatalf("failed to take %v ports: %w", n, err) + } + t.Cleanup(func() { + Return(ports) + }) + return ports +} + +// Port returns a single free port from the allocated port block, and returns the +// port to the pool when the test ends. See Take for more details. +// Use GetN if more than a single port is required. +func Port(t TestingT) int { + t.Helper() + return GetN(t, 1)[0] +} + // Deprecated: Please use Take/Return calls instead. func Get(n int) (ports []int) { return MustTake(n) } // Deprecated: Please use Take/Return calls instead. -func GetT(t testing.T, n int) (ports []int) { return MustTake(n) } +func GetT(t TestingT, n int) (ports []int) { return MustTake(n) } // Deprecated: Please use Take/Return calls instead. func Free(n int) (ports []int, err error) { return MustTake(n), nil }