diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index e91f52760e..1cb26736e0 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -3,18 +3,21 @@ package loadbalancer import ( "bufio" "context" - "errors" "fmt" "net" "net/url" - "os" "strings" "testing" "time" "github.com/k3s-io/k3s/pkg/cli/cmds" + "github.com/sirupsen/logrus" ) +func init() { + logrus.SetLevel(logrus.DebugLevel) +} + type testServer struct { listener net.Listener conns []net.Conn @@ -71,33 +74,17 @@ func ping(conn net.Conn) (string, error) { return strings.TrimSpace(result), nil } -func assertEqual(t *testing.T, a interface{}, b interface{}) { - if a != b { - t.Fatalf("[ %v != %v ]", a, b) - } -} - -func assertNotEqual(t *testing.T, a interface{}, b interface{}) { - if a == b { - t.Fatalf("[ %v == %v ]", a, b) - } -} - func Test_UnitFailOver(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lb-test") - if err != nil { - assertEqual(t, err, nil) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() ogServe, err := createServer("og") if err != nil { - assertEqual(t, err, nil) + t.Fatalf("createServer(og) failed: %v", err) } lbServe, err := createServer("lb") if err != nil { - assertEqual(t, err, nil) + t.Fatalf("createServer(lb) failed: %v", err) } cfg := cmds.Agent{ @@ -107,12 +94,12 @@ func Test_UnitFailOver(t *testing.T) { lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("New() failed: %v", err) } parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("url.Parse failed: %v", err) } localAddress := parsedURL.Host @@ -120,36 +107,39 @@ func Test_UnitFailOver(t *testing.T) { conn1, err := net.Dial("tcp", localAddress) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("net.Dial failed: %v", err) } result1, err := ping(conn1) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("ping(conn1) failed: %v", err) + } + if result1 != "lb:ping" { + t.Fatalf("Unexpected ping result: %v", result1) } - assertEqual(t, result1, "lb:ping") lbServe.close() _, err = ping(conn1) - assertNotEqual(t, err, nil) + if err == nil { + t.Fatal("Unexpected successful ping on closed connection conn1") + } conn2, err := net.Dial("tcp", localAddress) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("net.Dial failed: %v", err) + } result2, err := ping(conn2) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("ping(conn2) failed: %v", err) + } + if result2 != "og:ping" { + t.Fatalf("Unexpected ping result: %v", result2) } - assertEqual(t, result2, "og:ping") } func Test_UnitFailFast(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lb-test") - if err != nil { - assertEqual(t, err, nil) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() cfg := cmds.Agent{ ServerURL: "http://127.0.0.1:0/", @@ -158,12 +148,12 @@ func Test_UnitFailFast(t *testing.T) { lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("New() failed: %v", err) } conn, err := net.Dial("tcp", lb.localAddress) if err != nil { - assertEqual(t, err, nil) + t.Fatalf("net.Dial failed: %v", err) } done := make(chan error) @@ -175,8 +165,10 @@ func Test_UnitFailFast(t *testing.T) { select { case err := <-done: - assertNotEqual(t, err, nil) + if err == nil { + t.Fatal("Unexpected successful ping from invalid address") + } case <-timeout: - t.Fatal(errors.New("time out")) + t.Fatal("Test timed out") } }