k3s/pkg/agent/loadbalancer/loadbalancer_test.go

430 lines
13 KiB
Go
Raw Normal View History

2019-07-24 07:22:31 +00:00
package loadbalancer
import (
	"bufio"
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"github.com/sirupsen/logrus"
)
// Test_UnitLoadBalancer is the `go test` entry point for this package. It
// routes gomega failures to Ginkgo and runs every registered spec as the
// "LoadBalancer Suite", with verbose reporting mirroring `go test -v`.
func Test_UnitLoadBalancer(t *testing.T) {
	RegisterFailHandler(Fail)
	_, reporterCfg := GinkgoConfiguration()
	reporterCfg.Verbose = testing.Verbose()
	RunSpecs(t, "LoadBalancer Suite", reporterCfg)
}
// init turns on debug-level logrus output for the whole test binary.
func init() {
	const verbosity = logrus.DebugLevel
	logrus.SetLevel(verbosity)
}
type testServer struct {
address string
2019-07-24 07:22:31 +00:00
listener net.Listener
conns []net.Conn
prefix string
}
func createServer(ctx context.Context, prefix string) (*testServer, error) {
2019-07-24 07:22:31 +00:00
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
}
s := &testServer{
2019-07-24 07:22:31 +00:00
prefix: prefix,
listener: listener,
address: listener.Addr().String(),
2019-07-24 07:22:31 +00:00
}
go s.serve()
go func() {
<-ctx.Done()
s.close()
}()
2019-07-24 07:22:31 +00:00
return s, nil
}
func (s *testServer) serve() {
2019-07-24 07:22:31 +00:00
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
s.conns = append(s.conns, conn)
go s.echo(conn)
}
}
func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix)
s.address = ""
2019-07-24 07:22:31 +00:00
s.listener.Close()
for _, conn := range s.conns {
conn.Close()
}
}
// echo reads newline-terminated messages from conn and writes each one back
// prefixed with "<prefix>:". It returns as soon as a read or write fails,
// e.g. because either side closed the connection.
func (s *testServer) echo(conn net.Conn) {
	// One reader for the lifetime of the connection: creating a fresh
	// bufio.Reader per message would silently discard any bytes it had
	// already buffered past the first newline.
	reader := bufio.NewReader(conn)
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			return
		}
		if _, err := conn.Write([]byte(s.prefix + ":" + line)); err != nil {
			return
		}
	}
}
// ping writes "ping\n" to conn and returns the first line received in reply,
// with surrounding whitespace trimmed; against a testServer the reply is
// "<prefix>:ping". A failed write is returned immediately, and a failed read
// (including EOF before a newline) is returned with an empty result.
func ping(conn net.Conn) (string, error) {
	if _, err := fmt.Fprint(conn, "ping\n"); err != nil {
		return "", err
	}
	result, err := bufio.NewReader(conn).ReadString('\n')
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(result), nil
}
var _ = Describe("LoadBalancer", func() {
// creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
When("loadbalancer is running", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var defaultServer, node1Server, node2Server *testServer
var conn1, conn2, conn3, conn4 net.Conn
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
defaultServer, err = createServer(ctx, "default")
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
node1Server, err = createServer(ctx, "node1")
Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
node2Server, err = createServer(ctx, "node2")
Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
// start the loadbalancer with the default server as the only server
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("adds node1 as a server", func() {
// add the node as a new server address.
lb.Update([]string{node1Server.address})
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node1Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
It("connects to node1", func() {
// make sure connections go to the node
conn1, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
2019-07-24 07:22:31 +00:00
By("conn1 tested OK")
})
It("changes node1 state to failed", func() {
// set failing health check for node 1
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node1Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(stateFailed))
})
It("disconnects from node1", func() {
// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
By("conn1 closed on failure OK")
// connections shoould go to the default now that node 1 is failed
conn2, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
By("conn2 tested OK")
})
It("does not close connections unexpectedly", func() {
// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
By("conn2 tested OK again")
})
It("closes connections when dial fails", func() {
// shut down the first node server to force failover to the default
node1Server.close()
// make sure new connections go to the default, and existing connections are closed
conn3, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
By("conn3 tested OK")
})
It("replaces node2 as a server", func() {
// add the second node as a new server address.
lb.Update([]string{node2Server.address})
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node2Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
It("connects to node2", func() {
// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result")
By("conn4 tested OK")
})
It("does not close connections unexpectedly", func() {
// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
By("conn2 closed on failure OK")
Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
By("conn3 closed on failure OK")
})
It("adds default as a server", func() {
// add the default as a full server
lb.Update([]string{node2Server.address, defaultServer.address})
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(defaultServer.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
})
It("returns the default server in the address list", func() {
// confirm that both servers are listed in the address list
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
// confirm that the default is still listed as default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
It("does not return the default server in the address list after removing it", func() {
// remove the default as a server
lb.Update([]string{node2Server.address})
By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
// confirm that it is not listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
// but is still listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
It("removes default server when no longer default", func() {
// set node2 as the default
lb.SetDefault(node2Server.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
})
It("sets all three servers", func() {
// set node2 as the default
lb.SetDefault(defaultServer.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(defaultServer.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
})
// confirms that the loadbalancer will not dial itself
When("the default server is the loadbalancer", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var defaultServer *testServer
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
defaultServer, err = createServer(ctx, "default")
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
address := defaultServer.address
defaultServer.close()
_, port, _ := net.SplitHostPort(address)
intPort, _ := strconv.Atoi(port)
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails immediately", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
_, err = ping(conn)
Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
})
})
// confirms that connnections to invalid addresses fail quickly
When("there are no valid addresses", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails fast", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from invalid address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
// confirms that connnections to unreachable addresses do fail within the
// expected duration
When("the server is unreachable", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails with the correct timeout", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from unreachable address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
})