diff --git a/pkg/bootstrap/bootstrap.go b/pkg/bootstrap/bootstrap.go index 832d1a7061..d3c9f7ee24 100644 --- a/pkg/bootstrap/bootstrap.go +++ b/pkg/bootstrap/bootstrap.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "path/filepath" + "time" "github.com/pkg/errors" "github.com/rancher/k3s/pkg/daemons/config" @@ -15,17 +16,19 @@ import ( func Handler(bootstrap *config.ControlRuntimeBootstrap) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Content-Type", "application/json") - Write(rw, bootstrap) + ReadFromDisk(rw, bootstrap) }) } -func Write(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error { - paths, err := objToMap(bootstrap) +// ReadFromDisk reads the bootstrap data from the files on disk and +// writes their content in JSON form to the given io.Writer. +func ReadFromDisk(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error { + paths, err := ObjToMap(bootstrap) if err != nil { return nil } - dataMap := map[string][]byte{} + dataMap := make(map[string]File) for pathKey, path := range paths { if path == "" { continue @@ -35,24 +38,45 @@ func Write(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error { return errors.Wrapf(err, "failed to read %s", path) } - dataMap[pathKey] = data + info, err := os.Stat(path) + if err != nil { + return err + } + + dataMap[pathKey] = File{ + Timestamp: info.ModTime(), + Content: data, + } } return json.NewEncoder(w).Encode(dataMap) } -func Read(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error { - paths, err := objToMap(bootstrap) +// File is a representation of a certificate +// or key file within the bootstrap context that contains +// the contents of the file as well as a timestamp from +// when the file was last modified. +type File struct { + Timestamp time.Time + Content []byte +} + +type PathsDataformat map[string]File + +// WriteToDiskFromStorage writes the contents of the given reader to the paths +// derived from within the ControlRuntimeBootstrap. +func WriteToDiskFromStorage(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error { + paths, err := ObjToMap(bootstrap) if err != nil { return err } - files := map[string][]byte{} + files := make(PathsDataformat) if err := json.NewDecoder(r).Decode(&files); err != nil { return err } - for pathKey, data := range files { + for pathKey, bsf := range files { path, ok := paths[pathKey] if !ok { continue @@ -61,8 +85,7 @@ func Read(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error { if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return errors.Wrapf(err, "failed to mkdir %s", filepath.Dir(path)) } - - if err := ioutil.WriteFile(path, data, 0600); err != nil { + if err := ioutil.WriteFile(path, bsf.Content, 0600); err != nil { return errors.Wrapf(err, "failed to write to %s", path) } } @@ -70,7 +93,7 @@ func Read(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error { return nil } -func objToMap(obj interface{}) (map[string]string, error) { +func ObjToMap(obj interface{}) (map[string]string, error) { bytes, err := json.Marshal(obj) if err != nil { return nil, err diff --git a/pkg/bootstrap/bootstrap_test.go b/pkg/bootstrap/bootstrap_test.go new file mode 100644 index 0000000000..d6fe3fb74f --- /dev/null +++ b/pkg/bootstrap/bootstrap_test.go @@ -0,0 +1,46 @@ +package bootstrap + +import ( + "testing" + + "github.com/rancher/k3s/pkg/daemons/config" +) + +func TestObjToMap(t *testing.T) { + type args struct { + obj interface{} + } + tests := []struct { + name string + args args + want map[string]string + wantErr bool + }{ + { + name: "Minimal Valid", + args: args{ + obj: &config.ControlRuntimeBootstrap{ + ServerCA: "/var/lib/rancher/k3s/server/tls/server-ca.crt", + ServerCAKey: "/var/lib/rancher/k3s/server/tls/server-ca.key", + }, + }, + wantErr: false, + }, + { + name: "Minimal Invalid", + args: args{ + obj: 1, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ObjToMap(tt.args.obj) + if (err != nil) != tt.wantErr { + t.Errorf("ObjToMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/pkg/cli/server/server.go b/pkg/cli/server/server.go index 77be5c2fb1..fc418aae57 100644 --- a/pkg/cli/server/server.go +++ b/pkg/cli/server/server.go @@ -176,6 +176,17 @@ func run(app *cli.Context, cfg *cmds.Server, leaderControllers server.CustomCont // delete local loadbalancers state for apiserver and supervisor servers loadbalancer.ResetLoadBalancer(filepath.Join(cfg.DataDir, "agent"), loadbalancer.SupervisorServiceName) loadbalancer.ResetLoadBalancer(filepath.Join(cfg.DataDir, "agent"), loadbalancer.APIServerServiceName) + + // at this point we're doing a restore. Check to see if we've + // passed in a token and if not, check if the token file exists. + // If it doesn't, return an error indicating the token is necessary. + if cfg.Token == "" { + if _, err := os.Stat(filepath.Join(cfg.DataDir, "server/token")); err != nil { + if os.IsNotExist(err) { + return errors.New("") + } + } + } } serverConfig.ControlConfig.ClusterReset = cfg.ClusterReset diff --git a/pkg/clientaccess/kubeconfig.go b/pkg/clientaccess/kubeconfig.go index 4aa5db4303..bcdb66f3fc 100644 --- a/pkg/clientaccess/kubeconfig.go +++ b/pkg/clientaccess/kubeconfig.go @@ -9,7 +9,7 @@ import ( ) // WriteClientKubeConfig generates a kubeconfig at destFile that can be used to connect to a server at url with the given certs and keys -func WriteClientKubeConfig(destFile string, url string, serverCAFile string, clientCertFile string, clientKeyFile string) error { +func WriteClientKubeConfig(destFile, url, serverCAFile, clientCertFile, clientKeyFile string) error { serverCA, err := ioutil.ReadFile(serverCAFile) if err != nil { return errors.Wrapf(err, "failed to read %s", serverCAFile) diff --git a/pkg/clientaccess/token.go b/pkg/clientaccess/token.go index 062593c926..5476f71874 100644 --- a/pkg/clientaccess/token.go +++ b/pkg/clientaccess/token.go @@ -16,9 +16,15 @@ import ( "github.com/sirupsen/logrus" ) -var ( - defaultClientTimeout = 10 * time.Second +const ( + tokenPrefix = "K10" + tokenFormat = "%s%s::%s:%s" + caHashLength = sha256.Size * 2 + defaultClientTimeout = 10 * time.Second +) + +var ( defaultClient = &http.Client{ Timeout: defaultClientTimeout, } @@ -32,12 +38,6 @@ var ( } ) -const ( - tokenPrefix = "K10" - tokenFormat = "%s%s::%s:%s" - caHashLength = sha256.Size * 2 -) - type OverrideURLCallback func(config []byte) (*url.URL, error) type Info struct { @@ -49,8 +49,8 @@ type Info struct { } // String returns the token data, templated according to the token format -func (info *Info) String() string { - return fmt.Sprintf(tokenFormat, tokenPrefix, hashCA(info.CACerts), info.Username, info.Password) +func (i *Info) String() string { + return fmt.Sprintf(tokenFormat, tokenPrefix, hashCA(i.CACerts), i.Username, i.Password) } // ParseAndValidateToken parses a token, downloads and validates the server's CA bundle, @@ -70,7 +70,7 @@ func ParseAndValidateToken(server string, token string) (*Info, error) { // ParseAndValidateToken parses a token with user override, downloads and // validates the server's CA bundle, and validates it according to the caHash from the token if set. -func ParseAndValidateTokenForUser(server string, token string, username string) (*Info, error) { +func ParseAndValidateTokenForUser(server, token, username string) (*Info, error) { info, err := parseToken(token) if err != nil { return nil, err @@ -86,11 +86,11 @@ func ParseAndValidateTokenForUser(server string, token string, username string) } // setAndValidateServer updates the remote server's cert info, and validates it against the provided hash -func (info *Info) setAndValidateServer(server string) error { - if err := info.setServer(server); err != nil { +func (i *Info) setAndValidateServer(server string) error { + if err := i.setServer(server); err != nil { return err } - return info.validateCAHash() + return i.validateCAHash() } // validateCACerts returns a boolean indicating whether or not a CA bundle matches the provided hash, @@ -118,7 +118,7 @@ func ParseUsernamePassword(token string) (string, string, bool) { // parseToken parses a token into an Info struct func parseToken(token string) (*Info, error) { - var info = &Info{} + var info Info if len(token) == 0 { return nil, errors.New("token must not be empty") @@ -150,7 +150,7 @@ func parseToken(token string) (*Info, error) { info.Username = parts[0] info.Password = parts[1] - return info, nil + return &info, nil } // GetHTTPClient returns a http client that validates TLS server certificates using the provided CA bundle. @@ -177,25 +177,25 @@ func GetHTTPClient(cacerts []byte) *http.Client { } // Get makes a request to a subpath of info's BaseURL -func (info *Info) Get(path string) ([]byte, error) { - u, err := url.Parse(info.BaseURL) +func (i *Info) Get(path string) ([]byte, error) { + u, err := url.Parse(i.BaseURL) if err != nil { return nil, err } u.Path = path - return get(u.String(), GetHTTPClient(info.CACerts), info.Username, info.Password) + return get(u.String(), GetHTTPClient(i.CACerts), i.Username, i.Password) } // setServer sets the BaseURL and CACerts fields of the Info by connecting to the server // and storing the CA bundle. -func (info *Info) setServer(server string) error { +func (i *Info) setServer(server string) error { url, err := url.Parse(server) if err != nil { return errors.Wrapf(err, "Invalid server url, failed to parse: %s", server) } if url.Scheme != "https" { - return fmt.Errorf("only https:// URLs are supported, invalid scheme: %s", server) + return errors.New("only https:// URLs are supported, invalid scheme: " + server) } for strings.HasSuffix(url.Path, "/") { @@ -207,25 +207,25 @@ func (info *Info) setServer(server string) error { return err } - info.BaseURL = url.String() - info.CACerts = cacerts + i.BaseURL = url.String() + i.CACerts = cacerts return nil } // ValidateCAHash validates that info's caHash matches the CACerts hash. -func (info *Info) validateCAHash() error { - if len(info.caHash) > 0 && len(info.CACerts) == 0 { +func (i *Info) validateCAHash() error { + if len(i.caHash) > 0 && len(i.CACerts) == 0 { // Warn if the user provided a CA hash but we're not going to validate because it's already trusted logrus.Warn("Cluster CA certificate is trusted by the host CA bundle. " + "Token CA hash will not be validated.") - } else if len(info.caHash) == 0 && len(info.CACerts) > 0 { + } else if len(i.caHash) == 0 && len(i.CACerts) > 0 { // Warn if the CA is self-signed but the user didn't provide a hash to validate it against logrus.Warn("Cluster CA certificate is not trusted by the host CA bundle, but the token does not include a CA hash. " + "Use the full token from the server's node-token file to enable Cluster CA validation.") - } else if len(info.CACerts) > 0 && len(info.caHash) > 0 { + } else if len(i.CACerts) > 0 && len(i.caHash) > 0 { // only verify CA hash if the server cert is not trusted by the OS CA bundle - if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok { - return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", info.caHash, serverHash) + if ok, serverHash := validateCACerts(i.CACerts, i.caHash); !ok { + return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", i.caHash, serverHash) } } return nil @@ -288,18 +288,18 @@ func get(u string, client *http.Client, username, password string) ([]byte, erro return ioutil.ReadAll(resp.Body) } -func FormatToken(token string, certFile string) (string, error) { +func FormatToken(token, certFile string) (string, error) { if len(token) == 0 { return token, nil } certHash := "" if len(certFile) > 0 { - bytes, err := ioutil.ReadFile(certFile) + b, err := ioutil.ReadFile(certFile) if err != nil { return "", nil } - digest := sha256.Sum256(bytes) + digest := sha256.Sum256(b) certHash = tokenPrefix + hex.EncodeToString(digest[:]) + "::" } return certHash + token, nil diff --git a/pkg/clientaccess/token_test.go b/pkg/clientaccess/token_test.go index 8e01664b57..488f821e72 100644 --- a/pkg/clientaccess/token_test.go +++ b/pkg/clientaccess/token_test.go @@ -320,7 +320,7 @@ func newTLSServer(t *testing.T, username, password string, sendWrongCA bool) *ht } bootstrapData := &config.ControlRuntimeBootstrap{} w.Header().Set("Content-Type", "application/json") - if err := bootstrap.Write(w, bootstrapData); err != nil { + if err := bootstrap.ReadFromDisk(w, bootstrapData); err != nil { t.Errorf("failed to write bootstrap: %v", err) } return diff --git a/pkg/cluster/bootstrap.go b/pkg/cluster/bootstrap.go index a799dbf2f6..7ca3a2c570 100644 --- a/pkg/cluster/bootstrap.go +++ b/pkg/cluster/bootstrap.go @@ -3,10 +3,17 @@ package cluster import ( "bytes" "context" + "encoding/json" "errors" + "fmt" + "io" + "io/ioutil" "os" "path/filepath" + "strings" + "time" + "github.com/k3s-io/kine/pkg/client" "github.com/rancher/k3s/pkg/bootstrap" "github.com/rancher/k3s/pkg/clientaccess" "github.com/rancher/k3s/pkg/daemons/config" @@ -28,10 +35,8 @@ func (c *Cluster) Bootstrap(ctx context.Context) error { } c.shouldBootstrap = shouldBootstrap - if shouldBootstrap { - if err := c.bootstrap(ctx); err != nil { - return err - } + if c.shouldBootstrap { + return c.bootstrap(ctx) } return nil @@ -85,47 +90,312 @@ func (c *Cluster) shouldBootstrapLoad(ctx context.Context) (bool, error) { // Check the stamp file to see if we have successfully bootstrapped using this token. // NOTE: The fact that we use a hash of the token to generate the stamp // means that it is unsafe to use the same token for multiple clusters. - stamp := c.bootstrapStamp() - if _, err := os.Stat(stamp); err == nil { - logrus.Info("Cluster bootstrap already complete") - return false, nil - } + // stamp := c.bootstrapStamp() + // if _, err := os.Stat(stamp); err == nil { + // logrus.Info("Cluster bootstrap already complete") + // return false, nil + // } // No errors and no bootstrap stamp, need to bootstrap. return true, nil } -// bootstrapped touches a file to indicate that bootstrap has been completed. -func (c *Cluster) bootstrapped() error { - stamp := c.bootstrapStamp() - if err := os.MkdirAll(filepath.Dir(stamp), 0700); err != nil { +// isDirEmpty checks to see if the given directory +// is empty. +func isDirEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + return false, err + } + defer f.Close() + + _, err = f.Readdir(1) + if err == io.EOF { + return true, nil + } + + return false, err +} + +// certDirsExist checks to see if the directories +// that contain the needed certificates exist. +func (c *Cluster) certDirsExist() error { + bootstrapDirs := []string{ + "tls", + "tls/etcd", + } + + const ( + missingDir = "missing %s directory from ${data-dir}" + emptyDir = "%s directory is empty" + ) + + for _, dir := range bootstrapDirs { + d := filepath.Join(c.config.DataDir, dir) + if _, err := os.Stat(d); os.IsNotExist(err) { + errMsg := fmt.Sprintf(missingDir, d) + logrus.Debug(errMsg) + return errors.New(errMsg) + } + + ok, err := isDirEmpty(d) + if err != nil { + return err + } + + if ok { + errMsg := fmt.Sprintf(emptyDir, d) + logrus.Debug(errMsg) + return errors.New(errMsg) + } + } + + return nil +} + +// migrateBootstrapData migrates bootstrap data from the old format to the new format. +func migrateBootstrapData(ctx context.Context, data *bytes.Buffer, files bootstrap.PathsDataformat) error { + logrus.Info("Migrating bootstrap data to new format") + + var oldBootstrapData map[string][]byte + if err := json.NewDecoder(data).Decode(&oldBootstrapData); err != nil { + // if this errors here, we can assume that the error being thrown + // is not related to needing to perform a migration. return err } - // return if file already exists - if _, err := os.Stat(stamp); err == nil { - return nil + // iterate through the old bootstrap data structure + // and copy into the new bootstrap data structure + for k, v := range oldBootstrapData { + files[k] = bootstrap.File{ + Content: v, + } } - // otherwise try to create it - f, err := os.Create(stamp) + return nil +} + +const systemTimeSkew = int64(3) + +// ReconcileBootstrapData is called before any data is saved to the +// datastore or locally. It checks to see if the contents of the +// bootstrap data in the datastore is newer than on disk or different +// and dependingon where the difference is, the newer data is written +// to the older. +func (c *Cluster) ReconcileBootstrapData(ctx context.Context, buf *bytes.Buffer, crb *config.ControlRuntimeBootstrap) error { + logrus.Info("Reconciling bootstrap data between datastore and disk") + + if err := c.certDirsExist(); err != nil { + logrus.Warn(err.Error()) + return bootstrap.WriteToDiskFromStorage(buf, crb) + } + + token := c.config.Token + if token == "" { + tokenFromFile, err := readTokenFromFile(c.runtime.ServerToken, c.runtime.ServerCA, c.config.DataDir) + if err != nil { + return err + } + if tokenFromFile == "" { + // at this point this is a fresh start in a non-managed environment + c.saveBootstrap = true + return nil + } + token = tokenFromFile + } + normalizedToken, err := normalizeToken(token) if err != nil { return err } - return f.Close() + var value *client.Value + + storageClient, err := client.New(c.etcdConfig) + if err != nil { + return err + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + +RETRY: + for { + value, err = c.getBootstrapKeyFromStorage(ctx, storageClient, normalizedToken, token) + if err != nil { + if strings.Contains(err.Error(), "not supported for learner") { + for range ticker.C { + continue RETRY + } + + } + return err + } + if value == nil { + return nil + } + + break + } + + paths, err := bootstrap.ObjToMap(crb) + if err != nil { + return err + } + + files := make(bootstrap.PathsDataformat) + if err := json.NewDecoder(buf).Decode(&files); err != nil { + // This will fail if data is being pulled from old an cluster since + // older clusters used a map[string][]byte for the data structure. + // Therefore, we need to perform a migration to the newer bootstrap + // format; bootstrap.BootstrapFile. + if err := migrateBootstrapData(ctx, buf, files); err != nil { + return err + } + } + + type update struct { + db, disk, conflict bool + } + + var updateDatastore, updateDisk bool + + results := make(map[string]update) + + for pathKey, fileData := range files { + path, ok := paths[pathKey] + if !ok { + continue + } + + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + logrus.Warn(path + " doesn't exist. continuing...") + updateDisk = true + continue + } + return err + } + defer f.Close() + + fData, err := ioutil.ReadAll(f) + if err != nil { + return err + } + + if !bytes.Equal(fileData.Content, fData) { + logrus.Warnf("%s is out of sync with datastore", path) + + info, err := os.Stat(path) + if err != nil { + return err + } + + switch { + case info.ModTime().Unix()-files[pathKey].Timestamp.Unix() >= systemTimeSkew: + if _, ok := results[path]; !ok { + results[path] = update{ + db: true, + } + } + + for pk := range files { + p, ok := paths[pk] + if !ok { + continue + } + + if filepath.Base(p) == info.Name() { + continue + } + + i, err := os.Stat(p) + if err != nil { + return err + } + + if i.ModTime().Unix()-files[pk].Timestamp.Unix() >= systemTimeSkew { + if _, ok := results[path]; !ok { + results[path] = update{ + conflict: true, + } + } + } + } + case info.ModTime().Unix()-files[pathKey].Timestamp.Unix() <= systemTimeSkew: + if _, ok := results[info.Name()]; !ok { + results[path] = update{ + disk: true, + } + } + + for pk := range files { + p, ok := paths[pk] + if !ok { + continue + } + + if filepath.Base(p) == info.Name() { + continue + } + + i, err := os.Stat(p) + if err != nil { + return err + } + + if i.ModTime().Unix()-files[pk].Timestamp.Unix() <= systemTimeSkew { + if _, ok := results[path]; !ok { + results[path] = update{ + conflict: true, + } + } + } + } + default: + if _, ok := results[path]; ok { + results[path] = update{} + } + } + } + } + + for path, res := range results { + if res.db { + updateDatastore = true + logrus.Warn(path + " newer than datastore") + } else if res.disk { + updateDisk = true + logrus.Warn("datastore newer than " + path) + } else if res.conflict { + logrus.Warnf("datastore / disk conflict: %s newer than in the datastore", path) + } + } + + switch { + case updateDatastore: + logrus.Warn("updating bootstrap data in datastore from disk") + return c.save(ctx, true) + case updateDisk: + logrus.Warn("updating bootstrap data on disk from datastore") + return bootstrap.WriteToDiskFromStorage(buf, crb) + default: + // on disk certificates match timestamps in storage. noop. + } + + return nil } // httpBootstrap retrieves bootstrap data (certs and keys, etc) from the remote server via HTTP // and loads it into the ControlRuntimeBootstrap struct. Unlike the storage bootstrap path, // this data does not need to be decrypted since it is generated on-demand by an existing server. -func (c *Cluster) httpBootstrap() error { +func (c *Cluster) httpBootstrap(ctx context.Context) error { content, err := c.clientAccessInfo.Get("/v1-" + version.Program + "/server-bootstrap") if err != nil { return err } - return bootstrap.Read(bytes.NewBuffer(content), &c.runtime.ControlRuntimeBootstrap) + return c.ReconcileBootstrapData(ctx, bytes.NewBuffer(content), &c.config.Runtime.ControlRuntimeBootstrap) } // bootstrap performs cluster bootstrapping, either via HTTP (for managed databases) or direct load from datastore. @@ -134,20 +404,13 @@ func (c *Cluster) bootstrap(ctx context.Context) error { // bootstrap managed database via HTTPS if c.runtime.HTTPBootstrap { - return c.httpBootstrap() + return c.httpBootstrap(ctx) } // Bootstrap directly from datastore return c.storageBootstrap(ctx) } -// bootstrapStamp returns the path to a file in datadir/db that is used to record -// that a cluster has been joined. The filename is based on a portion of the sha256 hash of the token. -// We hash the token value exactly as it is provided by the user, NOT the normalized version. -func (c *Cluster) bootstrapStamp() string { - return filepath.Join(c.config.DataDir, "db/joined-"+keyHash(c.config.Token)) -} - // Snapshot is a proxy method to call the snapshot method on the managedb // interface for etcd clusters. func (c *Cluster) Snapshot(ctx context.Context, config *config.Control) error { diff --git a/pkg/cluster/bootstrap_test.go b/pkg/cluster/bootstrap_test.go new file mode 100644 index 0000000000..4250c8c8ae --- /dev/null +++ b/pkg/cluster/bootstrap_test.go @@ -0,0 +1,254 @@ +package cluster + +import ( + "bytes" + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/k3s-io/kine/pkg/endpoint" + "github.com/rancher/k3s/pkg/bootstrap" + "github.com/rancher/k3s/pkg/clientaccess" + "github.com/rancher/k3s/pkg/cluster/managed" + "github.com/rancher/k3s/pkg/daemons/config" +) + +func Test_isDirEmpty(t *testing.T) { + const tmpDir = "test_dir" + + type args struct { + name string + } + tests := []struct { + name string + args args + setup func() error + teardown func() error + want bool + wantErr bool + }{ + { + name: "is empty", + args: args{ + name: tmpDir, + }, + setup: func() error { + return os.Mkdir(tmpDir, 0700) + }, + teardown: func() error { + return os.RemoveAll(tmpDir) + }, + want: true, + wantErr: false, + }, + { + name: "is not empty", + args: args{ + name: tmpDir, + }, + setup: func() error { + os.Mkdir(tmpDir, 0700) + _, _ = os.Create(filepath.Join(filepath.Dir(tmpDir), tmpDir, "test_file")) + return nil + }, + teardown: func() error { + return os.RemoveAll(tmpDir) + }, + want: false, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.teardown() + if err := tt.setup(); err != nil { + t.Errorf("Setup for isDirEmpty() failed = %v", err) + return + } + got, err := isDirEmpty(tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("isDirEmpty() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("isDirEmpty() = %+v\nWant = %+v", got, tt.want) + } + }) + } +} + +func TestCluster_certDirsExist(t *testing.T) { + const testDataDir = "/tmp/k3s/" + + testTLSDir := filepath.Join(testDataDir, "server", "tls") + testTLSEtcdDir := filepath.Join(testDataDir, "server", "tls", "etcd") + + type fields struct { + clientAccessInfo *clientaccess.Info + config *config.Control + runtime *config.ControlRuntime + managedDB managed.Driver + etcdConfig endpoint.ETCDConfig + shouldBootstrap bool + storageStarted bool + saveBootstrap bool + } + tests := []struct { + name string + fields fields + setup func() error + teardown func() error + wantErr bool + }{ + { + name: "exists", + fields: fields{ + config: &config.Control{ + DataDir: filepath.Join(testDataDir, "server"), + }, + }, + setup: func() error { + os.MkdirAll(testTLSEtcdDir, 0700) + + _, _ = os.Create(filepath.Join(testTLSDir, "test_file")) + _, _ = os.Create(filepath.Join(testTLSEtcdDir, "test_file")) + + return nil + }, + teardown: func() error { + return os.RemoveAll(testDataDir) + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Cluster{ + clientAccessInfo: tt.fields.clientAccessInfo, + config: tt.fields.config, + runtime: tt.fields.runtime, + managedDB: tt.fields.managedDB, + etcdConfig: tt.fields.etcdConfig, + storageStarted: tt.fields.storageStarted, + saveBootstrap: tt.fields.saveBootstrap, + } + defer tt.teardown() + if err := tt.setup(); err != nil { + t.Errorf("Setup for Cluster.certDirsExist() failed = %v", err) + return + } + if err := c.certDirsExist(); (err != nil) != tt.wantErr { + t.Errorf("Cluster.certDirsExist() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCluster_migrateBootstrapData(t *testing.T) { + type fields struct { + clientAccessInfo *clientaccess.Info + config *config.Control + runtime *config.ControlRuntime + managedDB managed.Driver + etcdConfig endpoint.ETCDConfig + joining bool + storageStarted bool + saveBootstrap bool + shouldBootstrap bool + } + type args struct { + ctx context.Context + data *bytes.Buffer + files bootstrap.PathsDataformat + } + tests := []struct { + name string + args args + setup func() error // Optional, delete if unused + teardown func() error // Optional, delete if unused + wantErr bool + }{ + { + name: "Success", + args: args{ + ctx: context.Background(), + data: bytes.NewBuffer([]byte(`{"ServerCA": "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURSBSRVFVRVNULS0tLS0KTUlJQ3ZEQ0NBYVFDQVFBd2R6RUxNQWtHQTFVRUJoTUNWVk14RFRBTEJnTlZCQWdNQkZWMFlXZ3hEekFOQmdOVgpCQWNNQmt4cGJtUnZiakVXTUJRR0ExVUVDZ3dOUkdsbmFVTmxjblFnU1c1akxqRVJNQThHQTFVRUN3d0lSR2xuCmFVTmxjblF4SFRBYkJnTlZCQU1NRkdWNFlXMXdiR1V1WkdsbmFXTmxjblF1WTI5dE1JSUJJakFOQmdrcWhraUcKOXcwQkFRRUZBQU9DQVE4QU1JSUJDZ0tDQVFFQTgrVG83ZCsya1BXZUJ2L29yVTNMVmJKd0RyU1FiZUthbUNtbwp3cDVicUR4SXdWMjB6cVJiN0FQVU9LWW9WRUZGT0VRczZUNmdJbW5Jb2xoYmlINm00emdaL0NQdldCT2taYytjCjFQbzJFbXZCeitBRDVzQmRUNWt6R1FBNk5iV3laR2xkeFJ0aE5MT3MxZWZPaGRuV0Z1aEkxNjJxbWNmbGdwaUkKV0R1d3E0QzlmK1lrZUpoTm45ZEY1K293bThjT1FtRHJWOE5OZGlUcWluOHEzcVlBSEhKUlcyOGdsSlVDWmtUWgp3SWFTUjZjckJROFRiWU5FMGRjK0NhYTNET0lrejFFT3NIV3pUeCtuMHpLZnFjYmdYaTRESngrQzFianB0WVBSCkJQWkw4REFlV3VBOGVidWRWVDQ0eUVwODJHOTYvR2djZjdGMzN4TXhlMHljK1hhNm93SURBUUFCb0FBd0RRWUoKS29aSWh2Y05BUUVGQlFBRGdnRUJBQjBrY3JGY2NTbUZEbXhveDBOZTAxVUlxU3NEcUhnTCtYbUhUWEp3cmU2RApoSlNad2J2RXRPSzBHMytkcjRGczExV3VVTnQ1cWNMc3g1YTh1azRHNkFLSE16dWhMc0o3WFpqZ21RWEdFQ3BZClE0bUMzeVQzWm9DR3BJWGJ3K2lQM2xtRUVYZ2FRTDBUeDVMRmwvb2tLYktZd0lxTml5S1dPTWo3WlIvd3hXZy8KWkRHUnM1NXh1b2VMREovWlJGZjliSStJYUNVZDFZcmZZY0hJbDNHODdBdityNDlZVndxUkRUMFZEVjd1TGdxbgoyOVhJMVBwVlVOQ1BRR245cC9lWDZRbzd2cERhUHliUnRBMlI3WExLalFhRjlvWFdlQ1VxeTFodkphYzlRRk8yCjk3T2IxYWxwSFBvWjdtV2lFdUp3akJQaWk2YTlNOUczMG5VbzM5bEJpMXc9Ci0tLS0tRU5EIENFUlRJRklDQVRFIFJFUVVFU1QtLS0tLQ=="}`)), + files: make(bootstrap.PathsDataformat), + }, + wantErr: false, + }, + { + name: "Invalid Old Format", + args: args{ + ctx: context.Background(), + data: &bytes.Buffer{}, + files: bootstrap.PathsDataformat{ + "ServerCA": bootstrap.File{ + Timestamp: time.Now(), + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := migrateBootstrapData(tt.args.ctx, tt.args.data, tt.args.files); (err != nil) != tt.wantErr { + t.Errorf("Cluster.migrateBootstrapData() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCluster_Snapshot(t *testing.T) { + type fields struct { + clientAccessInfo *clientaccess.Info + config *config.Control + runtime *config.ControlRuntime + managedDB managed.Driver + etcdConfig endpoint.ETCDConfig + joining bool + storageStarted bool + saveBootstrap bool + shouldBootstrap bool + } + type args struct { + ctx context.Context + config *config.Control + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "Fail on non etcd cluster", + fields: fields{}, + args: args{ + ctx: context.Background(), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Cluster{ + clientAccessInfo: tt.fields.clientAccessInfo, + config: tt.fields.config, + runtime: tt.fields.runtime, + managedDB: tt.fields.managedDB, + etcdConfig: tt.fields.etcdConfig, + joining: tt.fields.joining, + storageStarted: tt.fields.storageStarted, + saveBootstrap: tt.fields.saveBootstrap, + shouldBootstrap: tt.fields.shouldBootstrap, + } + if err := c.Snapshot(tt.args.ctx, tt.args.config); (err != nil) != tt.wantErr { + t.Errorf("Cluster.Snapshot() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 8708668c2d..c4ac801500 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -20,11 +20,11 @@ type Cluster struct { config *config.Control runtime *config.ControlRuntime managedDB managed.Driver - shouldBootstrap bool - storageStarted bool etcdConfig endpoint.ETCDConfig joining bool + storageStarted bool saveBootstrap bool + shouldBootstrap bool } // Start creates the dynamic tls listener, http request handler, @@ -81,14 +81,7 @@ func (c *Cluster) Start(ctx context.Context) (<-chan struct{}, error) { // if necessary, store bootstrap data to datastore if c.saveBootstrap { - if err := c.save(ctx); err != nil { - return nil, err - } - } - - // if necessary, record successful bootstrap - if c.shouldBootstrap { - if err := c.bootstrapped(); err != nil { + if err := c.save(ctx, false); err != nil { return nil, err } } @@ -106,7 +99,7 @@ func (c *Cluster) Start(ctx context.Context) (<-chan struct{}, error) { for { select { case <-ready: - if err := c.save(ctx); err != nil { + if err := c.save(ctx, false); err != nil { panic(err) } @@ -153,7 +146,7 @@ func (c *Cluster) startStorage(ctx context.Context) error { return nil } -// New creates an initial cluster using the provided configuration +// New creates an initial cluster using the provided configuration. func New(config *config.Control) *Cluster { return &Cluster{ config: config, diff --git a/pkg/cluster/storage.go b/pkg/cluster/storage.go index a84a7f0886..7827be71c2 100644 --- a/pkg/cluster/storage.go +++ b/pkg/cluster/storage.go @@ -19,9 +19,9 @@ import ( // snapshot of the cluster's CA certs and keys, encryption passphrases, etc - encrypted with the join token. // This is used when bootstrapping a cluster from a managed database or external etcd cluster. // This is NOT used with embedded etcd, which bootstraps over HTTP. -func (c *Cluster) save(ctx context.Context) error { +func (c *Cluster) save(ctx context.Context, override bool) error { buf := &bytes.Buffer{} - if err := bootstrap.Write(buf, &c.runtime.ControlRuntimeBootstrap); err != nil { + if err := bootstrap.ReadFromDisk(buf, &c.runtime.ControlRuntimeBootstrap); err != nil { return err } token := c.config.Token @@ -47,14 +47,20 @@ func (c *Cluster) save(ctx context.Context) error { return err } - _, err = c.getBootstrapKeyFromStorage(ctx, storageClient, normalizedToken, token) - if err != nil { + if _, err := c.getBootstrapKeyFromStorage(ctx, storageClient, normalizedToken, token); err != nil { return err } if err := storageClient.Create(ctx, storageKey(normalizedToken), data); err != nil { if err.Error() == "key exists" { - logrus.Warnln("bootstrap key already exists") + logrus.Warn("bootstrap key already exists") + if override { + bsd, err := c.bootstrapKeyData(ctx, storageClient) + if err != nil { + return err + } + return storageClient.Update(ctx, storageKey(normalizedToken), bsd.Modified, data) + } return nil } else if strings.Contains(err.Error(), "not supported for learner") { logrus.Debug("skipping bootstrap data save on learner") @@ -66,9 +72,25 @@ func (c *Cluster) save(ctx context.Context) error { return nil } +// bootstrapKeyData lists keys stored in the datastore with the prefix "/bootstrap", and +// will return the first such key. It will return an error if not exactly one key is found. +func (c *Cluster) bootstrapKeyData(ctx context.Context, storageClient client.Client) (*client.Value, error) { + bootstrapList, err := storageClient.List(ctx, "/bootstrap", 0) + if err != nil { + return nil, err + } + if len(bootstrapList) == 0 { + return nil, errors.New("no bootstrap data found") + } + if len(bootstrapList) > 1 { + return nil, errors.New("found multiple bootstrap keys in storage") + } + return &bootstrapList[0], nil +} + // storageBootstrap loads data from the datastore into the ControlRuntimeBootstrap struct. // The storage key and encryption passphrase are both derived from the join token. -// token is either passed +// token is either passed. func (c *Cluster) storageBootstrap(ctx context.Context) error { if err := c.startStorage(ctx); err != nil { return err @@ -110,7 +132,8 @@ func (c *Cluster) storageBootstrap(ctx context.Context) error { return err } - return bootstrap.Read(bytes.NewBuffer(data), &c.runtime.ControlRuntimeBootstrap) + return c.ReconcileBootstrapData(ctx, bytes.NewBuffer(data), &c.config.Runtime.ControlRuntimeBootstrap) + //return bootstrap.WriteToDiskFromStorage(bytes.NewBuffer(data), &c.runtime.ControlRuntimeBootstrap) } // getBootstrapKeyFromStorage will list all keys that has prefix /bootstrap and will check for key that is @@ -157,6 +180,7 @@ func (c *Cluster) getBootstrapKeyFromStorage(ctx context.Context, storageClient // found then it will still strip the token from any additional info func readTokenFromFile(serverToken, certs, dataDir string) (string, error) { tokenFile := filepath.Join(dataDir, "token") + b, err := ioutil.ReadFile(tokenFile) if err != nil { if os.IsNotExist(err) { @@ -168,6 +192,7 @@ func readTokenFromFile(serverToken, certs, dataDir string) (string, error) { } return "", err } + // strip the token from any new line if its read from file return string(bytes.TrimRight(b, "\n")), nil } @@ -178,6 +203,7 @@ func normalizeToken(token string) (string, error) { if !ok { return password, errors.New("failed to normalize token") } + return password, nil } @@ -186,6 +212,7 @@ func normalizeToken(token string) (string, error) { // then migrate those and resave only with the normalized token func (c *Cluster) migrateOldTokens(ctx context.Context, bootstrapList []client.Value, storageClient client.Client, emptyStringKey, tokenKey, token, oldToken string) error { oldTokenKey := storageKey(oldToken) + for _, bootstrapKV := range bootstrapList { // checking for empty string bootstrap key if string(bootstrapKV.Key) == emptyStringKey { @@ -200,6 +227,7 @@ func (c *Cluster) migrateOldTokens(ctx context.Context, bootstrapList []client.V } } } + return nil } @@ -209,10 +237,12 @@ func doMigrateToken(ctx context.Context, storageClient client.Client, keyValue c if err != nil { return err } + encryptedData, err := encrypt(newToken, data) if err != nil { return err } + // saving the new encrypted data with the right token key if err := storageClient.Create(ctx, newTokenKey, encryptedData); err != nil { if err.Error() == "key exists" { @@ -224,10 +254,12 @@ func doMigrateToken(ctx context.Context, storageClient client.Client, keyValue c return err } } + logrus.Infof("created bootstrap key %s", newTokenKey) // deleting the old key if err := storageClient.Delete(ctx, oldTokenKey, keyValue.Modified); err != nil { logrus.Warnf("failed to delete old bootstrap key %s", oldTokenKey) } + return nil } diff --git a/pkg/daemons/control/deps/deps.go b/pkg/daemons/control/deps/deps.go index 22bbeb93f9..347a13385c 100644 --- a/pkg/daemons/control/deps/deps.go +++ b/pkg/daemons/control/deps/deps.go @@ -89,9 +89,9 @@ func KubeConfig(dest, url, caCert, clientCert, clientKey string) error { return kubeconfigTemplate.Execute(output, &data) } -// FillRuntimeCerts is responsible for filling out all the +// CreateRuntimeCertFiles is responsible for filling out all the // .crt and .key filenames for a ControlRuntime. -func FillRuntimeCerts(config *config.Control, runtime *config.ControlRuntime) { +func CreateRuntimeCertFiles(config *config.Control, runtime *config.ControlRuntime) { runtime.ClientCA = filepath.Join(config.DataDir, "tls", "client-ca.crt") runtime.ClientCAKey = filepath.Join(config.DataDir, "tls", "client-ca.key") runtime.ServerCA = filepath.Join(config.DataDir, "tls", "server-ca.crt") diff --git a/pkg/daemons/control/server.go b/pkg/daemons/control/server.go index d0a98be54b..25f5dfd368 100644 --- a/pkg/daemons/control/server.go +++ b/pkg/daemons/control/server.go @@ -242,7 +242,7 @@ func prepare(ctx context.Context, config *config.Control, runtime *config.Contro os.MkdirAll(filepath.Join(config.DataDir, "tls"), 0700) os.MkdirAll(filepath.Join(config.DataDir, "cred"), 0700) - deps.FillRuntimeCerts(config, runtime) + deps.CreateRuntimeCertFiles(config, runtime) cluster := cluster.New(config) diff --git a/pkg/etcd/etcd.go b/pkg/etcd/etcd.go index b0f5ab8e95..09aeb8bdc9 100644 --- a/pkg/etcd/etcd.go +++ b/pkg/etcd/etcd.go @@ -184,7 +184,7 @@ func (e *ETCD) IsInitialized(ctx context.Context, config *config.Control) (bool, } else if os.IsNotExist(err) { return false, nil } else { - return false, errors.Wrapf(err, "invalid state for wal directory %s", dir) + return false, errors.Wrap(err, "invalid state for wal directory "+dir) } } diff --git a/tests/util/runtime.go b/tests/util/runtime.go index e9c12215ed..102fc169e4 100644 --- a/tests/util/runtime.go +++ b/tests/util/runtime.go @@ -51,7 +51,7 @@ func GenerateRuntime(cnf *config.Control) error { os.MkdirAll(filepath.Join(cnf.DataDir, "tls"), 0700) os.MkdirAll(filepath.Join(cnf.DataDir, "cred"), 0700) - deps.FillRuntimeCerts(cnf, runtime) + deps.CreateRuntimeCertFiles(cnf, runtime) if err := deps.GenServerDeps(cnf, runtime); err != nil { return err