package loadbalancer import ( "bufio" "context" "fmt" "net" "strconv" "strings" "testing" "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/sirupsen/logrus" ) func Test_UnitLoadBalancer(t *testing.T) { _, reporterConfig := GinkgoConfiguration() reporterConfig.Verbose = testing.Verbose() RegisterFailHandler(Fail) RunSpecs(t, "LoadBalancer Suite", reporterConfig) } func init() { logrus.SetLevel(logrus.DebugLevel) } type testServer struct { address string listener net.Listener conns []net.Conn prefix string } func createServer(ctx context.Context, prefix string) (*testServer, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err } s := &testServer{ prefix: prefix, listener: listener, address: listener.Addr().String(), } go s.serve() go func() { <-ctx.Done() s.close() }() return s, nil } func (s *testServer) serve() { for { conn, err := s.listener.Accept() if err != nil { return } s.conns = append(s.conns, conn) go s.echo(conn) } } func (s *testServer) close() { logrus.Printf("testServer %s closing", s.prefix) s.address = "" s.listener.Close() for _, conn := range s.conns { conn.Close() } } func (s *testServer) echo(conn net.Conn) { for { result, err := bufio.NewReader(conn).ReadString('\n') if err != nil { return } conn.Write([]byte(s.prefix + ":" + result)) } } func ping(conn net.Conn) (string, error) { fmt.Fprintf(conn, "ping\n") result, err := bufio.NewReader(conn).ReadString('\n') if err != nil { return "", err } return strings.TrimSpace(result), nil } var _ = Describe("LoadBalancer", func() { // creates a LB using a default server (ie fixed registration endpoint) // and then adds a new server (a node). The node server is then closed, and it is confirmed // that new connections use the default server. When("loadbalancer is running", Ordered, func() { ctx, cancel := context.WithCancel(context.Background()) var defaultServer, node1Server, node2Server *testServer var conn1, conn2, conn3, conn4 net.Conn var lb *LoadBalancer var err error BeforeAll(func() { tmpDir := GinkgoT().TempDir() defaultServer, err = createServer(ctx, "default") Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") node1Server, err = createServer(ctx, "node1") Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed") node2Server, err = createServer(ctx, "node2") Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed") // start the loadbalancer with the default server as the only server lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false) Expect(err).NotTo(HaveOccurred(), "New() failed") }) AfterAll(func() { cancel() }) It("adds node1 as a server", func() { // add the node as a new server address. lb.Update([]string{node1Server.address}) lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers())) // wait for state to change Eventually(func() state { if s := lb.servers.getServer(node1Server.address); s != nil { return s.state } return stateInvalid }, 5, 1).Should(Equal(statePreferred)) }) It("connects to node1", func() { // make sure connections go to the node conn1, err = net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result") By("conn1 tested OK") }) It("changes node1 state to failed", func() { // set failing health check for node 1 lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed }) // wait for state to change Eventually(func() state { if s := lb.servers.getServer(node1Server.address); s != nil { return s.state } return stateInvalid }, 5, 1).Should(Equal(stateFailed)) }) It("disconnects from node1", func() { // Server connections are checked every second, now that node 1 is failed // the connections to it should be closed. Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") By("conn1 closed on failure OK") // connections shoould go to the default now that node 1 is failed conn2, err = net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") By("conn2 tested OK") }) It("does not close connections unexpectedly", func() { // make sure the health checks don't close the connection we just made - // connections should only be closed when it transitions from health to unhealthy. time.Sleep(2 * time.Second) Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") By("conn2 tested OK again") }) It("closes connections when dial fails", func() { // shut down the first node server to force failover to the default node1Server.close() // make sure new connections go to the default, and existing connections are closed conn3, err = net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result") By("conn3 tested OK") }) It("replaces node2 as a server", func() { // add the second node as a new server address. lb.Update([]string{node2Server.address}) lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers())) // wait for state to change Eventually(func() state { if s := lb.servers.getServer(node2Server.address); s != nil { return s.state } return stateInvalid }, 5, 1).Should(Equal(statePreferred)) }) It("connects to node2", func() { // make sure connection now goes to the second node, // and connections to the default are closed. conn4, err = net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result") By("conn4 tested OK") }) It("does not close connections unexpectedly", func() { // Server connections are checked every second, now that we have a healthy // server, connections to the default server should be closed time.Sleep(2 * time.Second) Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") By("conn2 closed on failure OK") Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") By("conn3 closed on failure OK") }) It("adds default as a server", func() { // add the default as a full server lb.Update([]string{node2Server.address, defaultServer.address}) lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) // wait for state to change Eventually(func() state { if s := lb.servers.getServer(defaultServer.address); s != nil { return s.state } return stateInvalid }, 5, 1).Should(Equal(statePreferred)) By(fmt.Sprintf("Default server added: %v", lb.servers.getServers())) }) It("returns the default server in the address list", func() { // confirm that both servers are listed in the address list Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address)) // confirm that the default is still listed as default Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") }) It("does not return the default server in the address list after removing it", func() { // remove the default as a server lb.Update([]string{node2Server.address}) By(fmt.Sprintf("Default removed: %v", lb.servers.getServers())) // confirm that it is not listed as a server Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) // but is still listed as the default Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") }) It("removes default server when no longer default", func() { // set node2 as the default lb.SetDefault(node2Server.address) By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) // confirm that it is still listed as a server Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) // and is listed as the default Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default") }) It("sets all three servers", func() { // set node2 as the default lb.SetDefault(defaultServer.address) By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address}) lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) // wait for state to change Eventually(func() state { if s := lb.servers.getServer(defaultServer.address); s != nil { return s.state } return stateInvalid }, 5, 1).Should(Equal(statePreferred)) By(fmt.Sprintf("All servers set: %v", lb.servers.getServers())) // confirm that it is still listed as a server Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address)) // and is listed as the default Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") }) }) // confirms that the loadbalancer will not dial itself When("the default server is the loadbalancer", Ordered, func() { ctx, cancel := context.WithCancel(context.Background()) var defaultServer *testServer var lb *LoadBalancer var err error BeforeAll(func() { tmpDir := GinkgoT().TempDir() defaultServer, err = createServer(ctx, "default") Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") address := defaultServer.address defaultServer.close() _, port, _ := net.SplitHostPort(address) intPort, _ := strconv.Atoi(port) lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false) Expect(err).NotTo(HaveOccurred(), "New() failed") }) AfterAll(func() { cancel() }) It("fails immediately", func() { conn, err := net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") _, err = ping(conn) Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection") }) }) // confirms that connnections to invalid addresses fail quickly When("there are no valid addresses", Ordered, func() { ctx, cancel := context.WithCancel(context.Background()) var lb *LoadBalancer var err error BeforeAll(func() { tmpDir := GinkgoT().TempDir() lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false) Expect(err).NotTo(HaveOccurred(), "New() failed") }) AfterAll(func() { cancel() }) It("fails fast", func() { conn, err := net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") done := make(chan error) go func() { _, err = ping(conn) done <- err }() timeout := time.After(10 * time.Millisecond) select { case err := <-done: if err == nil { Fail("Unexpected successful ping from invalid address") } case <-timeout: Fail("Test timed out") } }) }) // confirms that connnections to unreachable addresses do fail within the // expected duration When("the server is unreachable", Ordered, func() { ctx, cancel := context.WithCancel(context.Background()) var lb *LoadBalancer var err error BeforeAll(func() { tmpDir := GinkgoT().TempDir() lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false) Expect(err).NotTo(HaveOccurred(), "New() failed") }) AfterAll(func() { cancel() }) It("fails with the correct timeout", func() { conn, err := net.Dial("tcp", lb.localAddress) Expect(err).NotTo(HaveOccurred(), "net.Dial failed") done := make(chan error) go func() { _, err = ping(conn) done <- err }() timeout := time.After(11 * time.Second) select { case err := <-done: if err == nil { Fail("Unexpected successful ping from unreachable address") } case <-timeout: Fail("Test timed out") } }) }) })