diff --git a/agent/agent_test.go b/agent/agent_test.go index 7b5d77e455..887ba9b5f1 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1229,8 +1229,11 @@ func TestAgent_RestoreServiceWithAliasCheck(t *testing.T) { testCtx, testCancel := context.WithCancel(context.Background()) defer testCancel() - testHTTPServer := launchHTTPCheckServer(t, testCtx) - defer testHTTPServer.Close() + testHTTPServer, returnPort := launchHTTPCheckServer(t, testCtx) + defer func() { + testHTTPServer.Close() + returnPort() + }() registerServicesAndChecks := func(t *testing.T, a *TestAgent) { // add one persistent service with a simple check @@ -1338,8 +1341,8 @@ node_name = "` + a.Config.NodeName + `" } } -func launchHTTPCheckServer(t *testing.T, ctx context.Context) *httptest.Server { - ports := freeport.GetT(t, 1) +func launchHTTPCheckServer(t *testing.T, ctx context.Context) (srv *httptest.Server, returnPortsFn func()) { + ports := freeport.MustTake(1) port := ports[0] addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) @@ -1353,12 +1356,12 @@ func launchHTTPCheckServer(t *testing.T, ctx context.Context) *httptest.Server { _, _ = w.Write([]byte("OK\n")) }) - srv := &httptest.Server{ + srv = &httptest.Server{ Listener: listener, Config: &http.Server{Handler: handler}, } srv.Start() - return srv + return srv, func() { freeport.Return(ports) } } func TestAgent_AddCheck_Alias(t *testing.T) { diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 1b3728c282..acc5cc6a2e 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -13,7 +13,7 @@ import ( "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/serf/serf" "github.com/stretchr/testify/require" "golang.org/x/time/rate" @@ -22,15 +22,27 @@ import ( func testClientConfig(t *testing.T) (string, *Config) { dir := testutil.TempDir(t, "consul") config := DefaultConfig() + + ports := freeport.MustTake(2) + + returnPortsFn := func() { + // The method of plumbing this into the client shutdown hook doesn't + // cover all exit points, so we insulate this against multiple + // invocations and then it's safe to call it a bunch of times. + freeport.Return(ports) + config.NotifyShutdown = nil // self-erasing + } + config.NotifyShutdown = returnPortsFn + config.Datacenter = "dc1" config.DataDir = dir config.NodeName = uniqueNodeName(t.Name()) config.RPCAddr = &net.TCPAddr{ IP: []byte{127, 0, 0, 1}, - Port: freeport.Get(1)[0], + Port: ports[0], } config.SerfLANConfig.MemberlistConfig.BindAddr = "127.0.0.1" - config.SerfLANConfig.MemberlistConfig.BindPort = freeport.Get(1)[0] + config.SerfLANConfig.MemberlistConfig.BindPort = ports[1] config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond config.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond @@ -59,6 +71,7 @@ func testClientWithConfig(t *testing.T, cb func(c *Config)) (string, *Client) { } client, err := NewClient(config) if err != nil { + config.NotifyShutdown() t.Fatalf("err: %v", err) } return dir, client @@ -416,6 +429,7 @@ func TestClient_RPC_TLS(t *testing.T) { defer s1.Shutdown() dir2, conf2 := testClientConfig(t) + defer conf2.NotifyShutdown() conf2.VerifyOutgoing = true configureTLS(conf2) c1, err := NewClient(conf2) @@ -460,6 +474,7 @@ func TestClient_RPC_RateLimit(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") dir2, conf2 := testClientConfig(t) + defer conf2.NotifyShutdown() conf2.RPCRate = 2 conf2.RPCMaxBurst = 2 c1, err := NewClient(conf2) @@ -527,6 +542,7 @@ func TestClient_SnapshotRPC_RateLimit(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") dir2, conf1 := testClientConfig(t) + defer conf1.NotifyShutdown() conf1.RPCRate = 2 conf1.RPCMaxBurst = 2 c1, err := NewClient(conf1) @@ -569,6 +585,7 @@ func TestClient_SnapshotRPC_TLS(t *testing.T) { defer s1.Shutdown() dir2, conf2 := testClientConfig(t) + defer conf2.NotifyShutdown() conf2.VerifyOutgoing = true configureTLS(conf2) c1, err := NewClient(conf2) diff --git a/agent/consul/config.go b/agent/consul/config.go index ad5b75a4ab..96183c53e4 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -110,6 +110,9 @@ type Config struct { // configured at this point. NotifyListen func() + // NotifyShutdown is called after Server is completely Shutdown. + NotifyShutdown func() + // RPCAddr is the RPC address used by Consul. This should be reachable // by the WAN and LAN RPCAddr *net.TCPAddr diff --git a/agent/consul/operator_raft_endpoint_test.go b/agent/consul/operator_raft_endpoint_test.go index 23b1da7924..4f89f63a3d 100644 --- a/agent/consul/operator_raft_endpoint_test.go +++ b/agent/consul/operator_raft_endpoint_test.go @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/testrpc" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/raft" "github.com/pascaldekloe/goe/verify" ) @@ -145,10 +145,13 @@ func TestOperator_RaftRemovePeerByAddress(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") + ports := freeport.MustTake(1) + defer freeport.Return(ports) + // Try to remove a peer that's not there. arg := structs.RaftRemovePeerRequest{ Datacenter: "dc1", - Address: raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", freeport.Get(1)[0])), + Address: raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", ports[0])), } var reply struct{} err := msgpackrpc.CallWithCodec(codec, "Operator.RaftRemovePeerByAddress", &arg, &reply) @@ -277,7 +280,10 @@ func TestOperator_RaftRemovePeerByID(t *testing.T) { // Add it manually to Raft. { - future := s1.raft.AddVoter(arg.ID, raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", freeport.Get(1)[0])), 0, 0) + ports := freeport.MustTake(1) + defer freeport.Return(ports) + + future := s1.raft.AddVoter(arg.ID, raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", ports[0])), 0, 0) if err := future.Error(); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/server.go b/agent/consul/server.go index 895f9cc04c..82ccb3be0c 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -843,6 +843,10 @@ func (s *Server) Shutdown() error { // Close the connection pool s.connPool.Shutdown() + if s.config.NotifyShutdown != nil { + s.config.NotifyShutdown() + } + return nil } diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 018113deb3..6d2edeba3d 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -42,7 +42,17 @@ func testServerConfig(t *testing.T) (string, *Config) { dir := testutil.TempDir(t, "consul") config := DefaultConfig() - ports := freeport.Get(3) + ports := freeport.MustTake(3) + + returnPortsFn := func() { + // The method of plumbing this into the server shutdown hook doesn't + // cover all exit points, so we insulate this against multiple + // invocations and then it's safe to call it a bunch of times. + freeport.Return(ports) + config.NotifyShutdown = nil // self-erasing + } + config.NotifyShutdown = returnPortsFn + config.NodeName = uniqueNodeName(t.Name()) config.Bootstrap = true config.Datacenter = "dc1" @@ -56,6 +66,7 @@ func testServerConfig(t *testing.T) (string, *Config) { nodeID, err := uuid.GenerateUUID() if err != nil { + returnPortsFn() t.Fatal(err) } config.NodeID = types.NodeID(nodeID) @@ -112,6 +123,8 @@ func testServerConfig(t *testing.T) (string, *Config) { }, } + config.NotifyShutdown = returnPortsFn + return dir, config } @@ -168,6 +181,7 @@ func testServerWithConfig(t *testing.T, cb func(*Config)) (string, *Server) { srv, err = newServer(config) if err != nil { + config.NotifyShutdown() os.RemoveAll(dir) r.Fatalf("err: %v", err) } diff --git a/agent/http_test.go b/agent/http_test.go index 08f7fa60cd..37ec2db201 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -738,6 +738,7 @@ func TestParseWait(t *testing.T) { t.Fatalf("Bad: %v", b) } } + func TestPProfHandlers_EnableDebug(t *testing.T) { t.Parallel() require := require.New(t) @@ -751,6 +752,7 @@ func TestPProfHandlers_EnableDebug(t *testing.T) { require.Equal(http.StatusOK, resp.Code) } + func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { t.Parallel() require := require.New(t) diff --git a/agent/testagent.go b/agent/testagent.go index b01d4e5d5e..e0e34e7ce0 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -52,6 +52,10 @@ type TestAgent struct { // when Shutdown() is called. Config *config.RuntimeConfig + // returnPortsFn will put the ports claimed for the test back into the + // general freeport pool + returnPortsFn func() + // LogOutput is the sink for the logs. If nil, logs are written // to os.Stderr. LogOutput io.Writer @@ -150,12 +154,21 @@ func (a *TestAgent) Start() (err error) { hclDataDir = `data_dir = "` + d + `"` } + portsConfig, returnPortsFn := randomPortsSource(a.UseTLS) + a.returnPortsFn = returnPortsFn a.Config = TestConfig( - randomPortsSource(a.UseTLS), + portsConfig, config.Source{Name: a.Name, Format: "hcl", Data: a.HCL}, config.Source{Name: a.Name + ".data_dir", Format: "hcl", Data: hclDataDir}, ) + defer func() { + if err != nil && a.returnPortsFn != nil { + a.returnPortsFn() + a.returnPortsFn = nil + } + }() + // write the keyring if a.Key != "" { writeKey := func(key, filename string) error { @@ -286,6 +299,14 @@ func (a *TestAgent) Shutdown() error { return nil } + // Return ports last of all + defer func() { + if a.returnPortsFn != nil { + a.returnPortsFn() + a.returnPortsFn = nil + } + }() + // shutdown agent before endpoints defer a.Agent.ShutdownEndpoints() if err := a.Agent.ShutdownAgent(); err != nil { @@ -350,27 +371,32 @@ func (a *TestAgent) consulConfig() *consul.Config { // chance of port conflicts for concurrently executed test binaries. // Instead of relying on one set of ports to be sufficient we retry // starting the agent with different ports on port conflict. -func randomPortsSource(tls bool) config.Source { - ports := freeport.Get(6) +func randomPortsSource(tls bool) (src config.Source, returnPortsFn func()) { + ports := freeport.MustTake(6) + + var http, https int if tls { - ports[1] = -1 + http = -1 + https = ports[2] } else { - ports[2] = -1 + http = ports[1] + https = -1 } + return config.Source{ Name: "ports", Format: "hcl", Data: ` ports = { dns = ` + strconv.Itoa(ports[0]) + ` - http = ` + strconv.Itoa(ports[1]) + ` - https = ` + strconv.Itoa(ports[2]) + ` + http = ` + strconv.Itoa(http) + ` + https = ` + strconv.Itoa(https) + ` serf_lan = ` + strconv.Itoa(ports[3]) + ` serf_wan = ` + strconv.Itoa(ports[4]) + ` server = ` + strconv.Itoa(ports[5]) + ` } `, - } + }, func() { freeport.Return(ports) } } func NodeID() string { diff --git a/connect/proxy/listener_test.go b/connect/proxy/listener_test.go index 358bb191eb..aee123444b 100644 --- a/connect/proxy/listener_test.go +++ b/connect/proxy/listener_test.go @@ -4,13 +4,14 @@ import ( "bytes" "context" "fmt" - "github.com/hashicorp/consul/connect" "log" "net" "os" "testing" "time" + "github.com/hashicorp/consul/connect" + metrics "github.com/armon/go-metrics" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -110,7 +111,8 @@ func TestPublicListener(t *testing.T) { // Can't enable t.Parallel since we rely on the global metrics instance. ca := agConnect.TestCA(t, nil) - ports := freeport.GetT(t, 1) + ports := freeport.MustTake(1) + defer freeport.Return(ports) testApp := NewTestTCPServer(t) defer testApp.Close() @@ -162,7 +164,8 @@ func TestUpstreamListener(t *testing.T) { // Can't enable t.Parallel since we rely on the global metrics instance. ca := agConnect.TestCA(t, nil) - ports := freeport.GetT(t, 1) + ports := freeport.MustTake(1) + defer freeport.Return(ports) // Run a test server that we can dial. testSvr := connect.NewTestServer(t, "db", ca) diff --git a/connect/proxy/proxy_test.go b/connect/proxy/proxy_test.go index 339f935d54..10480e0a03 100644 --- a/connect/proxy/proxy_test.go +++ b/connect/proxy/proxy_test.go @@ -22,7 +22,9 @@ func TestProxy_public(t *testing.T) { t.Parallel() require := require.New(t) - ports := freeport.GetT(t, 1) + + ports := freeport.MustTake(1) + defer freeport.Return(ports) a := agent.NewTestAgent(t, t.Name(), "") defer a.Shutdown() diff --git a/connect/proxy/testing.go b/connect/proxy/testing.go index 6dc6838edb..cad243478b 100644 --- a/connect/proxy/testing.go +++ b/connect/proxy/testing.go @@ -24,21 +24,23 @@ type TestTCPServer struct { l net.Listener stopped int32 accepted, closed, active int32 + returnPortsFn func() } // NewTestTCPServer opens as a listening socket on the given address and returns // a TestTCPServer serving requests to it. The server is already started and can // be stopped by calling Close(). func NewTestTCPServer(t testing.T) *TestTCPServer { - port := freeport.GetT(t, 1) - addr := TestLocalAddr(port[0]) + ports := freeport.MustTake(1) + addr := TestLocalAddr(ports[0]) l, err := net.Listen("tcp", addr) require.NoError(t, err) log.Printf("test tcp server listening on %s", addr) s := &TestTCPServer{ - l: l, + l: l, + returnPortsFn: func() { freeport.Return(ports) }, } go s.accept() @@ -51,6 +53,10 @@ func (s *TestTCPServer) Close() { if s.l != nil { s.l.Close() } + if s.returnPortsFn != nil { + s.returnPortsFn() + s.returnPortsFn = nil + } } // Addr returns the address that this server is listening on. diff --git a/connect/testing.go b/connect/testing.go index 78fc6d2c8b..6bf2fca061 100644 --- a/connect/testing.go +++ b/connect/testing.go @@ -94,22 +94,24 @@ type TestServer struct { // Listening is closed when the listener is run. Listening chan struct{} - l net.Listener - stopFlag int32 - stopChan chan struct{} + l net.Listener + returnPortsFn func() + stopFlag int32 + stopChan chan struct{} } // NewTestServer returns a TestServer. It should be closed when test is // complete. func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer { - ports := freeport.GetT(t, 1) + ports := freeport.MustTake(1) return &TestServer{ - Service: service, - CA: ca, - stopChan: make(chan struct{}), - TLSCfg: TestTLSConfig(t, service, ca), - Addr: fmt.Sprintf("127.0.0.1:%d", ports[0]), - Listening: make(chan struct{}), + Service: service, + CA: ca, + stopChan: make(chan struct{}), + TLSCfg: TestTLSConfig(t, service, ca), + Addr: fmt.Sprintf("127.0.0.1:%d", ports[0]), + Listening: make(chan struct{}), + returnPortsFn: func() { freeport.Return(ports) }, } } @@ -186,6 +188,10 @@ func (s *TestServer) Close() error { if s.l != nil { s.l.Close() } + if s.returnPortsFn != nil { + s.returnPortsFn() + s.returnPortsFn = nil + } close(s.stopChan) } return nil diff --git a/sdk/freeport/ephemeral_fallback.go b/sdk/freeport/ephemeral_fallback.go new file mode 100644 index 0000000000..740791cc87 --- /dev/null +++ b/sdk/freeport/ephemeral_fallback.go @@ -0,0 +1,7 @@ +//+build !linux + +package freeport + +func getEphemeralPortRange() (int, int, error) { + return 0, 0, nil +} diff --git a/sdk/freeport/ephemeral_linux.go b/sdk/freeport/ephemeral_linux.go new file mode 100644 index 0000000000..00b8815e1b --- /dev/null +++ b/sdk/freeport/ephemeral_linux.go @@ -0,0 +1,36 @@ +//+build linux + +package freeport + +import ( + "fmt" + "os/exec" + "regexp" + "strconv" +) + +const ephemeralPortRangeSysctlKey = "net.ipv4.ip_local_port_range" + +var ephemeralPortRangePatt = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s*$`) + +func getEphemeralPortRange() (int, int, error) { + cmd := exec.Command("sysctl", "-n", ephemeralPortRangeSysctlKey) + out, err := cmd.Output() + if err != nil { + return 0, 0, err + } + + val := string(out) + + m := ephemeralPortRangePatt.FindStringSubmatch(val) + if m != nil { + min, err1 := strconv.Atoi(m[1]) + max, err2 := strconv.Atoi(m[2]) + + if err1 == nil && err2 == nil { + return min, max, nil + } + } + + return 0, 0, fmt.Errorf("unexpected sysctl value %q for key %q", val, ephemeralPortRangeSysctlKey) +} diff --git a/sdk/freeport/ephemeral_linux_test.go b/sdk/freeport/ephemeral_linux_test.go new file mode 100644 index 0000000000..2d9385df4b --- /dev/null +++ b/sdk/freeport/ephemeral_linux_test.go @@ -0,0 +1,18 @@ +//+build linux + +package freeport + +import ( + "testing" +) + +func TestGetEphemeralPortRange(t *testing.T) { + min, max, err := getEphemeralPortRange() + if err != nil { + t.Fatalf("err: %v", err) + } + if min <= 0 || max <= 0 || min > max { + t.Fatalf("unexpected values: min=%d, max=%d", min, max) + } + t.Logf("min=%d, max=%d", min, max) +} diff --git a/sdk/freeport/freeport.go b/sdk/freeport/freeport.go index 806449ba4a..576f39f2fa 100644 --- a/sdk/freeport/freeport.go +++ b/sdk/freeport/freeport.go @@ -3,9 +3,12 @@ package freeport import ( + "container/list" "fmt" "math/rand" "net" + "os" + "runtime" "sync" "time" @@ -14,12 +17,10 @@ import ( const ( // blockSize is the size of the allocated port block. ports are given out - // consecutively from that block with roll-over for the lifetime of the - // application/test run. + // consecutively from that block and after that point in a LRU fashion. blockSize = 1500 - // maxBlocks is the number of available port blocks. - // lowPort + maxBlocks * blockSize must be less than 65535. + // maxBlocks is the number of available port blocks before exclusions. maxBlocks = 30 // lowPort is the lowest port number that should be used. @@ -31,31 +32,158 @@ const ( ) var ( + // effectiveMaxBlocks is the number of available port blocks. + // lowPort + effectiveMaxBlocks * blockSize must be less than 65535. + effectiveMaxBlocks int + // firstPort is the first port of the allocated block. firstPort int // lockLn is the system-wide mutex for the port block. lockLn net.Listener - // mu guards nextPort + // mu guards: + // - pendingPorts + // - freePorts + // - total mu sync.Mutex // once is used to do the initialization on the first call to retrieve free // ports once sync.Once - // port is the last allocated port. - port int + // condNotEmpty is a condition variable to wait for freePorts to be not + // empty. Linked to 'mu' + condNotEmpty *sync.Cond + + // freePorts is a FIFO of all currently free ports. Take from the front, + // and return to the back. + freePorts *list.List + + // pendingPorts is a FIFO of recently freed ports that have not yet passed + // the not-in-use check. + pendingPorts *list.List + + // total is the total number of available ports in the block for use. + total int ) // initialize is used to initialize freeport. func initialize() { - if lowPort+maxBlocks*blockSize > 65535 { + var err error + effectiveMaxBlocks, err = adjustMaxBlocks() + if err != nil { + panic("freeport: ephemeral port range detection failed: " + err.Error()) + } + if effectiveMaxBlocks < 0 { + panic("freeport: no blocks of ports available outside of ephemeral range") + } + if lowPort+effectiveMaxBlocks*blockSize > 65535 { panic("freeport: block size too big or too many blocks requested") } rand.Seed(time.Now().UnixNano()) firstPort, lockLn = alloc() + + condNotEmpty = sync.NewCond(&mu) + freePorts = list.New() + pendingPorts = list.New() + + // fill with all available free ports + for port := firstPort + 1; port < firstPort+blockSize; port++ { + if used := isPortInUse(port); !used { + freePorts.PushBack(port) + } + } + total = freePorts.Len() + + go checkFreedPorts() +} + +// reset will reverse the setup from initialize() and then redo it (for tests) +func reset() { + mu.Lock() + defer mu.Unlock() + + logf("INFO", "resetting the freeport package state") + + effectiveMaxBlocks = 0 + firstPort = 0 + if lockLn != nil { + lockLn.Close() + lockLn = nil + } + + once = sync.Once{} + + freePorts = nil + pendingPorts = nil + total = 0 +} + +func checkFreedPorts() { + ticker := time.NewTicker(250 * time.Millisecond) + for { + <-ticker.C + checkFreedPortsOnce() + } +} + +func checkFreedPortsOnce() { + mu.Lock() + defer mu.Unlock() + + pending := pendingPorts.Len() + remove := make([]*list.Element, 0, pending) + for elem := pendingPorts.Front(); elem != nil; elem = elem.Next() { + port := elem.Value.(int) + if used := isPortInUse(port); !used { + freePorts.PushBack(port) + remove = append(remove, elem) + } + } + + retained := pending - len(remove) + + if retained > 0 { + logf("WARN", "%d out of %d pending ports are still in use; something probably didn't wait around for the port to be closed!", retained, pending) + } + + if len(remove) == 0 { + return + } + + for _, elem := range remove { + pendingPorts.Remove(elem) + } + + condNotEmpty.Broadcast() +} + +// adjustMaxBlocks avoids having the allocation ranges overlap the ephemeral +// port range. +func adjustMaxBlocks() (int, error) { + ephemeralPortMin, ephemeralPortMax, err := getEphemeralPortRange() + if err != nil { + return 0, err + } + + if ephemeralPortMin <= 0 || ephemeralPortMax <= 0 { + logf("INFO", "ephemeral port range detection not configured for GOOS=%q", runtime.GOOS) + return maxBlocks, nil + } + + logf("INFO", "detected ephemeral port range of [%d, %d]", ephemeralPortMin, ephemeralPortMax) + for block := 0; block < maxBlocks; block++ { + min := lowPort + block*blockSize + max := min + blockSize + overlap := intervalOverlap(min, max-1, ephemeralPortMin, ephemeralPortMax) + if overlap { + logf("INFO", "reducing max blocks from %d to %d to avoid the ephemeral port range", maxBlocks, block) + return block, nil + } + } + return maxBlocks, nil } // alloc reserves a port block for exclusive use for the lifetime of the @@ -64,76 +192,154 @@ func initialize() { // be automatically released when the application terminates. func alloc() (int, net.Listener) { for i := 0; i < attempts; i++ { - block := int(rand.Int31n(int32(maxBlocks))) + block := int(rand.Int31n(int32(effectiveMaxBlocks))) firstPort := lowPort + block*blockSize ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", firstPort)) if err != nil { continue } - // log.Printf("[DEBUG] freeport: allocated port block %d (%d-%d)", block, firstPort, firstPort+blockSize-1) + // logf("DEBUG", "allocated port block %d (%d-%d)", block, firstPort, firstPort+blockSize-1) return firstPort, ln } panic("freeport: cannot allocate port block") } +// MustTake is the same as Take except it panics on error. +func MustTake(n int) (ports []int) { + ports, err := Take(n) + if err != nil { + panic(err) + } + return ports +} + +// Take returns a list of free ports from the allocated port block. It is safe +// to call this method concurrently. Ports have been tested to be available on +// 127.0.0.1 TCP but there is no guarantee that they will remain free in the +// future. +func Take(n int) (ports []int, err error) { + if n <= 0 { + return nil, fmt.Errorf("freeport: cannot take %d ports", n) + } + + mu.Lock() + defer mu.Unlock() + + // Reserve a port block + once.Do(initialize) + + if n > total { + return nil, fmt.Errorf("freeport: block size too small") + } + + for len(ports) < n { + for freePorts.Len() == 0 { + if total == 0 { + return nil, fmt.Errorf("freeport: impossible to satisfy request; there are no actual free ports in the block anymore") + } + condNotEmpty.Wait() + } + + elem := freePorts.Front() + freePorts.Remove(elem) + port := elem.Value.(int) + + if used := isPortInUse(port); used { + // Something outside of the test suite has stolen this port, possibly + // due to assignment to an ephemeral port, remove it completely. + logf("WARN", "leaked port %d due to theft; removing from circulation", port) + total-- + continue + } + + ports = append(ports, port) + } + + // logf("DEBUG", "free ports: %v", ports) + return ports, nil +} + +// peekFree returns the next port that will be returned by Take to aid in testing. +func peekFree() int { + mu.Lock() + defer mu.Unlock() + return freePorts.Front().Value.(int) +} + +// peekAllFree returns all free ports that could be returned by Take to aid in testing. +func peekAllFree() []int { + mu.Lock() + defer mu.Unlock() + + var out []int + for elem := freePorts.Front(); elem != nil; elem = elem.Next() { + port := elem.Value.(int) + out = append(out, port) + } + + return out +} + +// stats returns diagnostic data to aid in testing +func stats() (numTotal, numPending, numFree int) { + mu.Lock() + defer mu.Unlock() + return total, pendingPorts.Len(), freePorts.Len() +} + +// Return returns a block of ports back to the general pool. These ports should +// have been returned from a call to Take(). +func Return(ports []int) { + if len(ports) == 0 { + return // convenience short circuit for test ergonomics + } + + mu.Lock() + defer mu.Unlock() + + for _, port := range ports { + if port > firstPort && port < firstPort+blockSize { + pendingPorts.PushBack(port) + } + } +} + +func isPortInUse(port int) bool { + ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port)) + if err != nil { + return true + } + ln.Close() + return false +} + func tcpAddr(ip string, port int) *net.TCPAddr { return &net.TCPAddr{IP: net.ParseIP(ip), Port: port} } -// Get wraps the Free function and panics on any failure retrieving ports. -func Get(n int) (ports []int) { - ports, err := Free(n) - if err != nil { - panic(err) +// intervalOverlap returns true if the doubly-inclusive integer intervals +// represented by [min1, max1] and [min2, max2] overlap. +func intervalOverlap(min1, max1, min2, max2 int) bool { + if min1 > max1 { + logf("WARN", "interval1 is not ordered [%d, %d]", min1, max1) + return false } - - return ports + if min2 > max2 { + logf("WARN", "interval2 is not ordered [%d, %d]", min2, max2) + return false + } + return min1 <= max2 && min2 <= max1 } -// GetT is suitable for use when retrieving unused ports in tests. If there is -// an error retrieving free ports, the test will be failed. -func GetT(t testing.T, n int) (ports []int) { - ports, err := Free(n) - if err != nil { - t.Fatalf("Failed retrieving free port: %v", err) - } - - return ports +func logf(severity string, format string, a ...interface{}) { + fmt.Fprintf(os.Stderr, "["+severity+"] freeport: "+format+"\n", a...) } -// Free returns a list of free ports from the allocated port block. It is safe -// to call this method concurrently. Ports have been tested to be available on -// 127.0.0.1 TCP but there is no guarantee that they will remain free in the -// future. -func Free(n int) (ports []int, err error) { - mu.Lock() - defer mu.Unlock() +// Deprecated: Please use Take/Return calls instead. +func Get(n int) (ports []int) { return MustTake(n) } - if n > blockSize-1 { - return nil, fmt.Errorf("freeport: block size too small") - } +// Deprecated: Please use Take/Return calls instead. +func GetT(t testing.T, n int) (ports []int) { return MustTake(n) } - // Reserve a port block - once.Do(initialize) - - for len(ports) < n { - port++ - - // roll-over the port - if port < firstPort+1 || port >= firstPort+blockSize { - port = firstPort + 1 - } - - // if the port is in use then skip it - ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port)) - if err != nil { - // log.Println("[DEBUG] freeport: port already in use: ", port) - continue - } - ln.Close() - - ports = append(ports, port) - } - // log.Println("[DEBUG] freeport: free ports:", ports) - return ports, nil -} +// Deprecated: Please use Take/Return calls instead. +func Free(n int) (ports []int, err error) { return MustTake(n), nil } diff --git a/sdk/freeport/freeport_test.go b/sdk/freeport/freeport_test.go new file mode 100644 index 0000000000..598ac789c1 --- /dev/null +++ b/sdk/freeport/freeport_test.go @@ -0,0 +1,231 @@ +package freeport + +import ( + "fmt" + "io" + "net" + "testing" + + "github.com/hashicorp/consul/sdk/testutil/retry" +) + +func TestTakeReturn(t *testing.T) { + // NOTE: for global var reasons this cannot execute in parallel + // t.Parallel() + + // Since this test is destructive (i.e. it leaks all ports) it means that + // any other test cases in this package will not function after it runs. To + // help out we reset the global state after we run this test. + defer reset() + + // OK: do a simple take/return cycle to trigger the package initialization + func() { + ports, err := Take(1) + if err != nil { + t.Fatalf("err: %v", err) + } + defer Return(ports) + + if len(ports) != 1 { + t.Fatalf("expected %d but got %d ports", 1, len(ports)) + } + }() + + waitForStatsReset := func() (numTotal int) { + t.Helper() + numTotal, numPending, numFree := stats() + if numTotal != numFree+numPending { + t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending) + } + retry.Run(t, func(r *retry.R) { + numTotal, numPending, numFree = stats() + if numPending != 0 { + r.Fatalf("pending is still non zero: %d", numPending) + } + if numTotal != numFree { + r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree) + } + }) + return numTotal + } + + // Reset + numTotal := waitForStatsReset() + + // -------------------- + // OK: take the max + func() { + ports, err := Take(numTotal) + if err != nil { + t.Fatalf("err: %v", err) + } + defer Return(ports) + + if len(ports) != numTotal { + t.Fatalf("expected %d but got %d ports", numTotal, len(ports)) + } + }() + + // Reset + numTotal = waitForStatsReset() + + expectError := func(expected string, got error) { + t.Helper() + if got == nil { + t.Fatalf("expected error but was nil") + } + if got.Error() != expected { + t.Fatalf("expected error %q but got %q", expected, got.Error()) + } + } + + // -------------------- + // ERROR: take too many ports + func() { + ports, err := Take(numTotal + 1) + defer Return(ports) + expectError("freeport: block size too small", err) + }() + + // -------------------- + // ERROR: invalid ports request (negative) + func() { + _, err := Take(-1) + expectError("freeport: cannot take -1 ports", err) + }() + + // -------------------- + // ERROR: invalid ports request (zero) + func() { + _, err := Take(0) + expectError("freeport: cannot take 0 ports", err) + }() + + // -------------------- + // OK: Steal a port under the covers and let freeport detect the theft and compensate + leakedPort := peekFree() + func() { + leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort)) + if err != nil { + t.Fatalf("err: %v", err) + } + defer leakyListener.Close() + + func() { + ports, err := Take(3) + if err != nil { + t.Fatalf("err: %v", err) + } + defer Return(ports) + + if len(ports) != 3 { + t.Fatalf("expected %d but got %d ports", 3, len(ports)) + } + + for _, port := range ports { + if port == leakedPort { + t.Fatalf("did not expect for Take to return the leaked port") + } + } + }() + + newNumTotal := waitForStatsReset() + if newNumTotal != numTotal-1 { + t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal) + } + numTotal = newNumTotal // update outer variable for later tests + }() + + // -------------------- + // OK: sequence it so that one Take must wait on another Take to Return. + func() { + mostPorts, err := Take(numTotal - 5) + if err != nil { + t.Fatalf("err: %v", err) + } + + type reply struct { + ports []int + err error + } + ch := make(chan reply, 1) + go func() { + ports, err := Take(10) + ch <- reply{ports: ports, err: err} + }() + + Return(mostPorts) + + r := <-ch + if r.err != nil { + t.Fatalf("err: %v", r.err) + } + defer Return(r.ports) + + if len(r.ports) != 10 { + t.Fatalf("expected %d ports but got %d", 10, len(r.ports)) + } + }() + + // Reset + numTotal = waitForStatsReset() + + // -------------------- + // ERROR: Now we end on the crazy "Ocean's 11" level port theft where we + // orchestrate a situation where all ports are stolen and we don't find out + // until Take. + func() { + // 1. Grab all of the ports. + allPorts := peekAllFree() + + // 2. Leak all of the ports + leaked := make([]io.Closer, 0, len(allPorts)) + defer func() { + for _, c := range leaked { + c.Close() + } + }() + for _, port := range allPorts { + ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port)) + if err != nil { + t.Fatalf("err: %v", err) + } + leaked = append(leaked, ln) + } + + // 3. Request 1 port which will detect the leaked ports and fail. + _, err := Take(1) + expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err) + + // 4. Wait for the block to zero out. + newNumTotal := waitForStatsReset() + if newNumTotal != 0 { + t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal) + } + }() +} + +func TestIntervalOverlap(t *testing.T) { + cases := []struct { + min1, max1, min2, max2 int + overlap bool + }{ + {0, 0, 0, 0, true}, + {1, 1, 1, 1, true}, + {1, 3, 1, 3, true}, // same + {1, 3, 4, 6, false}, // serial + {1, 4, 3, 6, true}, // inner overlap + {1, 6, 3, 4, true}, // nest + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%d:%d vs %d:%d", tc.min1, tc.max1, tc.min2, tc.max2), func(t *testing.T) { + if tc.overlap != intervalOverlap(tc.min1, tc.max1, tc.min2, tc.max2) { // 1 vs 2 + t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap) + } + if tc.overlap != intervalOverlap(tc.min2, tc.max2, tc.min1, tc.max1) { // 2 vs 1 + t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap) + } + }) + } +} diff --git a/sdk/testutil/server.go b/sdk/testutil/server.go index d10a7bc5ae..80600730bc 100644 --- a/sdk/testutil/server.go +++ b/sdk/testutil/server.go @@ -104,6 +104,7 @@ type TestServerConfig struct { ReadyTimeout time.Duration `json:"-"` Stdout, Stderr io.Writer `json:"-"` Args []string `json:"-"` + ReturnPorts func() `json:"-"` } type TestACLs struct { @@ -138,7 +139,8 @@ func defaultServerConfig() *TestServerConfig { panic(err) } - ports := freeport.Get(6) + ports := freeport.MustTake(6) + return &TestServerConfig{ NodeName: "node-" + nodeID, NodeID: nodeID, @@ -167,6 +169,9 @@ func defaultServerConfig() *TestServerConfig { "cluster_id": "11111111-2222-3333-4444-555555555555", }, }, + ReturnPorts: func() { + freeport.Return(ports) + }, } } @@ -244,6 +249,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e } cfg := defaultServerConfig() + cfg.DataDir = filepath.Join(tmpdir, "data") if cb != nil { cb(cfg) @@ -251,6 +257,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e b, err := json.Marshal(cfg) if err != nil { + cfg.ReturnPorts() os.RemoveAll(tmpdir) return nil, errors.Wrap(err, "failed marshaling json") } @@ -258,6 +265,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e log.Printf("CONFIG JSON: %s", string(b)) configFile := filepath.Join(tmpdir, "config.json") if err := ioutil.WriteFile(configFile, b, 0644); err != nil { + cfg.ReturnPorts() os.RemoveAll(tmpdir) return nil, errors.Wrap(err, "failed writing config content") } @@ -278,6 +286,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e cmd.Stdout = stdout cmd.Stderr = stderr if err := cmd.Start(); err != nil { + cfg.ReturnPorts() os.RemoveAll(tmpdir) return nil, errors.Wrap(err, "failed starting command") } @@ -319,6 +328,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e // Stop stops the test Consul server, and removes the Consul data // directory once we are done. func (s *TestServer) Stop() error { + defer s.Config.ReturnPorts() defer os.RemoveAll(s.tmpdir) // There was no process