diff --git a/cmd/kubeadm/app/cmd/phases/certs.go b/cmd/kubeadm/app/cmd/phases/certs.go index 36fe4ff67f..778cd6c368 100644 --- a/cmd/kubeadm/app/cmd/phases/certs.go +++ b/cmd/kubeadm/app/cmd/phases/certs.go @@ -91,7 +91,6 @@ func newCertSubPhases() []workflow.Phase { Short: "Generates all certificates", InheritFlags: getCertPhaseFlags("all"), RunAllSiblings: true, - LocalFlags: localFlags(), } subPhases = append(subPhases, allPhase) @@ -104,6 +103,7 @@ func newCertSubPhases() []workflow.Phase { for _, cert := range certList { certPhase := newCertSubPhase(cert, runCertPhase(cert, ca)) + certPhase.LocalFlags = localFlags() subPhases = append(subPhases, certPhase) } } @@ -133,7 +133,6 @@ func newCertSubPhase(certSpec *certsphase.KubeadmCert, run func(c workflow.RunDa ), Run: run, InheritFlags: getCertPhaseFlags(certSpec.Name), - LocalFlags: localFlags(), } return phase } diff --git a/cmd/kubeadm/test/cmd/BUILD b/cmd/kubeadm/test/cmd/BUILD index d7abaae45b..fa944de0ef 100644 --- a/cmd/kubeadm/test/cmd/BUILD +++ b/cmd/kubeadm/test/cmd/BUILD @@ -36,6 +36,7 @@ go_test( "//cmd/kubeadm/app/phases/certs:go_default_library", "//cmd/kubeadm/app/util/pkiutil:go_default_library", "//cmd/kubeadm/test:go_default_library", + "//vendor/github.com/pkg/errors:go_default_library", "//vendor/github.com/renstrom/dedent:go_default_library", "//vendor/sigs.k8s.io/yaml:go_default_library", ], diff --git a/cmd/kubeadm/test/cmd/init_test.go b/cmd/kubeadm/test/cmd/init_test.go index f005be75d2..e986047ec6 100644 --- a/cmd/kubeadm/test/cmd/init_test.go +++ b/cmd/kubeadm/test/cmd/init_test.go @@ -17,8 +17,11 @@ limitations under the License. package kubeadm import ( + "os/exec" + "strings" "testing" + "github.com/pkg/errors" "github.com/renstrom/dedent" "k8s.io/kubernetes/cmd/kubeadm/app/phases/certs" "k8s.io/kubernetes/cmd/kubeadm/app/util/pkiutil" @@ -200,24 +203,62 @@ func TestCmdInitCertPhaseCSR(t *testing.T) { t.Skip() } - csrDir := testutil.SetupTempDir(t) - - cert := &certs.KubeadmCertKubeletClient - kubeadmPath := getKubeadmPath() - _, _, err := RunCmd(kubeadmPath, - "init", - "phase", - "certs", - cert.BaseName, - "--csr-only", - "--csr-dir="+csrDir, - ) - if err != nil { - t.Fatalf("couldn't run kubeadm: %v", err) + tests := []struct { + name string + baseName string + expectedError string + }{ + { + name: "generate CSR", + baseName: certs.KubeadmCertKubeletClient.BaseName, + }, + { + name: "fails on CSR", + baseName: certs.KubeadmCertRootCA.BaseName, + expectedError: "unknown flag: --csr-only", + }, + { + name: "fails on all", + baseName: "all", + expectedError: "unknown flag: --csr-only", + }, } - if _, _, err := pkiutil.TryLoadCSRAndKeyFromDisk(csrDir, cert.BaseName); err != nil { - t.Fatalf("couldn't load certificate %q: %v", cert.BaseName, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + csrDir := testutil.SetupTempDir(t) + cert := &certs.KubeadmCertKubeletClient + kubeadmPath := getKubeadmPath() + _, stderr, err := RunCmd(kubeadmPath, + "init", + "phase", + "certs", + test.baseName, + "--csr-only", + "--csr-dir="+csrDir, + ) + + if test.expectedError != "" { + cause := errors.Cause(err) + _, ok := cause.(*exec.ExitError) + if !ok { + t.Fatalf("expected exitErr: got %T (%v)", cause, err) + } + + if !strings.Contains(stderr, test.expectedError) { + t.Errorf("expected %q to contain %q", stderr, test.expectedError) + } + return + } + + if err != nil { + t.Fatalf("couldn't run kubeadm: %v", err) + } + + if _, _, err := pkiutil.TryLoadCSRAndKeyFromDisk(csrDir, cert.BaseName); err != nil { + t.Fatalf("couldn't load certificate %q: %v", cert.BaseName, err) + } + }) } } diff --git a/cmd/kubeadm/test/cmd/util.go b/cmd/kubeadm/test/cmd/util.go index 0e9656c44d..ecc1951b5d 100644 --- a/cmd/kubeadm/test/cmd/util.go +++ b/cmd/kubeadm/test/cmd/util.go @@ -43,7 +43,7 @@ func runCmdNoWrap(command string, args ...string) (string, string, error) { func RunCmd(command string, args ...string) (string, string, error) { stdout, stderr, err := runCmdNoWrap(command, args...) if err != nil { - return "", "", errors.Wrapf(err, "error running %s %v; \nstdout %q, \nstderr %q, \ngot error", + return stdout, stderr, errors.Wrapf(err, "error running %s %v; \nstdout %q, \nstderr %q, \ngot error", command, args, stdout, stderr) } return stdout, stderr, nil