From e3046120b359d48e1b32f2e9d977571cd5f3068f Mon Sep 17 00:00:00 2001 From: "Chris S. Kim" Date: Tue, 9 Aug 2022 12:22:39 -0400 Subject: [PATCH] Close active listeners on error If startListeners successfully created listeners for some of its input addresses but eventually failed, the function would return an error and existing listeners would not be cleaned up. --- .changelog/14081.txt | 3 ++ agent/agent.go | 19 +++++++++++-- agent/agent_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 .changelog/14081.txt diff --git a/.changelog/14081.txt b/.changelog/14081.txt new file mode 100644 index 0000000000..ccb03ffb03 --- /dev/null +++ b/.changelog/14081.txt @@ -0,0 +1,3 @@ +```release-note:bug +agent: Fixes an issue where an agent that fails to start due to bad addresses won't clean up any existing listeners +``` diff --git a/agent/agent.go b/agent/agent.go index e087af5187..8a263647e9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -863,8 +863,18 @@ func (a *Agent) listenAndServeDNS() error { return merr.ErrorOrNil() } +// startListeners will return a net.Listener for every address unless an +// error is encountered, in which case it will close all previously opened +// listeners and return the error. func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { - var ln []net.Listener + var lns []net.Listener + + closeAll := func() { + for _, l := range lns { + l.Close() + } + } + for _, addr := range addrs { var l net.Listener var err error @@ -873,22 +883,25 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { case *net.UnixAddr: l, err = a.listenSocket(x.Name) if err != nil { + closeAll() return nil, err } case *net.TCPAddr: l, err = net.Listen("tcp", x.String()) if err != nil { + closeAll() return nil, err } l = &tcpKeepAliveListener{l.(*net.TCPListener)} default: + closeAll() return nil, fmt.Errorf("unsupported address type %T", addr) } - ln = append(ln, l) + lns = append(lns, l) } - return ln, nil + return lns, nil } // listenHTTP binds listeners to the provided addresses and also returns diff --git a/agent/agent_test.go b/agent/agent_test.go index d7b118fcba..8bae81ce45 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -5857,6 +5857,73 @@ func Test_coalesceTimerTwoPeriods(t *testing.T) { } +func TestAgent_startListeners(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + t.Parallel() + + ports := freeport.GetN(t, 3) + bd := BaseDeps{ + Deps: consul.Deps{ + Logger: hclog.NewInterceptLogger(nil), + Tokens: new(token.Store), + GRPCConnPool: &fakeGRPCConnPool{}, + }, + RuntimeConfig: &config.RuntimeConfig{ + HTTPAddrs: []net.Addr{}, + }, + Cache: cache.New(cache.Options{}), + } + + bd, err := initEnterpriseBaseDeps(bd, nil) + require.NoError(t, err) + + agent, err := New(bd) + require.NoError(t, err) + + // use up an address + used := net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]} + l, err := net.Listen("tcp", used.String()) + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + + var lns []net.Listener + t.Cleanup(func() { + for _, ln := range lns { + ln.Close() + } + }) + + // first two addresses open listeners but third address should fail + lns, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]}, + }) + require.Contains(t, err.Error(), "address already in use") + + // first two ports should be freed up + retry.Run(t, func(r *retry.R) { + lns, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + }) + require.NoError(r, err) + require.Len(r, lns, 2) + }) + + // first two ports should be in use + retry.Run(t, func(r *retry.R) { + _, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + }) + require.Contains(r, err.Error(), "address already in use") + }) + +} + func getExpectedCaPoolByFile(t *testing.T) *x509.CertPool { pool := x509.NewCertPool() data, err := ioutil.ReadFile("../test/ca/root.cer")