diff --git a/command/agent/agent.go b/command/agent/agent.go index 95fcd8bec5..d0f62fc3e6 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -20,8 +20,7 @@ import ( ) const ( - SerfLANKeyring = "serf/local.keyring" - SerfWANKeyring = "serf/remote.keyring" + SerfKeyring = "serf/keyring" ) /* @@ -176,10 +175,8 @@ func (a *Agent) consulConfig() *consul.Config { base.SerfWANConfig.MemberlistConfig.SecretKey = key } if a.config.Server && a.config.keyringFilesExist() { - pathWAN := filepath.Join(base.DataDir, SerfWANKeyring) - pathLAN := filepath.Join(base.DataDir, SerfLANKeyring) - base.SerfWANConfig.KeyringFile = pathWAN - base.SerfLANConfig.KeyringFile = pathLAN + path := filepath.Join(base.DataDir, SerfKeyring) + base.SerfLANConfig.KeyringFile = path } if a.config.NodeName != "" { base.NodeName = a.config.NodeName diff --git a/command/agent/config.go b/command/agent/config.go index 9334553d00..60bd730321 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -413,14 +413,9 @@ func (c *Config) ClientListenerAddr(override string, port int) (string, error) { // keyringFilesExist checks for existence of the keyring files for Serf func (c *Config) keyringFilesExist() bool { - if _, err := os.Stat(filepath.Join(c.DataDir, SerfLANKeyring)); err != nil { + if _, err := os.Stat(filepath.Join(c.DataDir, SerfKeyring)); err != nil { return false } - if c.Server { - if _, err := os.Stat(filepath.Join(c.DataDir, SerfWANKeyring)); err != nil { - return false - } - } return true } diff --git a/command/keys.go b/command/keys.go index 8ca98bdf2f..c603a3d53f 100644 --- a/command/keys.go +++ b/command/keys.go @@ -22,7 +22,7 @@ type KeysCommand struct { func (c *KeysCommand) Run(args []string) int { var installKey, useKey, removeKey, init, dataDir string - var listKeys, wan bool + var listKeys bool cmdFlags := flag.NewFlagSet("keys", flag.ContinueOnError) cmdFlags.Usage = func() { c.Ui.Output(c.Help()) } @@ -31,7 +31,6 @@ func (c *KeysCommand) Run(args []string) int { cmdFlags.StringVar(&useKey, "use", "", "use key") cmdFlags.StringVar(&removeKey, "remove", "", "remove key") cmdFlags.BoolVar(&listKeys, "list", false, "list keys") - cmdFlags.BoolVar(&wan, "wan", false, "operate on wan keys") cmdFlags.StringVar(&init, "init", "", "initialize keyring") cmdFlags.StringVar(&dataDir, "data-dir", "", "data directory") @@ -63,39 +62,17 @@ func (c *KeysCommand) Run(args []string) int { return 1 } - var out, paths []string - var failures map[string]string - var err error - if init != "" { if dataDir == "" { c.Ui.Error("Must provide -data-dir") return 1 } - if _, err := base64.StdEncoding.DecodeString(init); err != nil { - c.Ui.Error(fmt.Sprintf("Invalid key: %s", err)) - return 1 - } - - paths = append(paths, filepath.Join(dataDir, agent.SerfLANKeyring)) - if wan { - paths = append(paths, filepath.Join(dataDir, agent.SerfWANKeyring)) - } - - keys := []string{init} - keyringBytes, err := json.MarshalIndent(keys, "", " ") - if err != nil { + path := filepath.Join(dataDir, agent.SerfKeyring) + if err := initializeKeyring(path, init); err != nil { c.Ui.Error(fmt.Sprintf("Error: %s", err)) return 1 } - for _, path := range paths { - if err := initializeKeyring(path, keyringBytes); err != nil { - c.Ui.Error("Error: %s", err) - return 1 - } - } - return 0 } @@ -107,112 +84,64 @@ func (c *KeysCommand) Run(args []string) int { } defer client.Close() - if listKeys { - var keys map[string]int - var numNodes int - - if wan { - c.Ui.Info("Asking all WAN members for installed keys...") - keys, numNodes, failures, err = client.ListKeysWAN() - } else { - c.Ui.Info("Asking all LAN members for installed keys...") - keys, numNodes, failures, err = client.ListKeysLAN() - } + // For all key-related operations, we must be querying a server node. + s, err := client.Stats() + if err != nil { + c.Ui.Error(fmt.Sprintf("Error: %s", err)) + return 1 + } + if s["consul"]["server"] != "true" { + c.Ui.Error("Error: Key modification can only be handled by a server") + return 1 + } - if err != nil { - if len(failures) > 0 { - for node, msg := range failures { - out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) - } - c.Ui.Error(columnize.SimpleFormat(out)) - } - c.Ui.Error("") - c.Ui.Error(fmt.Sprintf("Failed gathering member keys: %s", err)) - return 1 + if listKeys { + c.Ui.Info("Asking all WAN members for installed keys...") + if rval := c.listKeysOperation(client.ListKeysWAN); rval != 0 { + return rval } - - c.Ui.Info("Keys gathered, listing cluster keys...") - c.Ui.Output("") - - for key, num := range keys { - out = append(out, fmt.Sprintf("%s | [%d/%d]", key, num, numNodes)) + c.Ui.Info("Asking all LAN members for installed keys...") + if rval := c.listKeysOperation(client.ListKeysLAN); rval != 0 { + return rval } - c.Ui.Output(columnize.SimpleFormat(out)) - return 0 } if installKey != "" { - if wan { - c.Ui.Info("Installing new WAN gossip encryption key...") - failures, err = client.InstallKeyWAN(installKey) - } else { - c.Ui.Info("Installing new LAN gossip encryption key...") - failures, err = client.InstallKeyLAN(installKey) + c.Ui.Info("Installing new WAN gossip encryption key...") + if rval := c.keyOperation(installKey, client.InstallKeyWAN); rval != 0 { + return rval } - - if err != nil { - if len(failures) > 0 { - for node, msg := range failures { - out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) - } - c.Ui.Error(columnize.SimpleFormat(out)) - } - c.Ui.Error("") - c.Ui.Error(fmt.Sprintf("Error installing key: %s", err)) - return 1 + c.Ui.Info("Installing new LAN gossip encryption key...") + if rval := c.keyOperation(installKey, client.InstallKeyLAN); rval != 0 { + return rval } - c.Ui.Info("Successfully installed key!") return 0 } if useKey != "" { - if wan { - c.Ui.Info("Changing primary encryption key on WAN members...") - failures, err = client.UseKeyWAN(useKey) - } else { - c.Ui.Info("Changing primary encryption key on LAN members...") - failures, err = client.UseKeyLAN(useKey) + c.Ui.Info("Changing primary WAN gossip encryption key...") + if rval := c.keyOperation(useKey, client.UseKeyWAN); rval != 0 { + return rval } - - if err != nil { - if len(failures) > 0 { - for node, msg := range failures { - out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) - } - c.Ui.Error(columnize.SimpleFormat(out)) - } - c.Ui.Error("") - c.Ui.Error(fmt.Sprintf("Error changing primary key: %s", err)) - return 1 + c.Ui.Info("Changing primary LAN gossip encryption key...") + if rval := c.keyOperation(useKey, client.UseKeyLAN); rval != 0 { + return rval } - c.Ui.Info("Successfully changed primary key!") return 0 } if removeKey != "" { - if wan { - c.Ui.Info("Removing key from WAN members...") - failures, err = client.RemoveKeyWAN(removeKey) - } else { - c.Ui.Info("Removing key from LAN members...") - failures, err = client.RemoveKeyLAN(removeKey) + c.Ui.Info("Removing WAN gossip encryption key...") + if rval := c.keyOperation(removeKey, client.RemoveKeyWAN); rval != 0 { + return rval } - - if err != nil { - if len(failures) > 0 { - for node, msg := range failures { - out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) - } - c.Ui.Error(columnize.SimpleFormat(out)) - } - c.Ui.Error("") - c.Ui.Error(fmt.Sprintf("Error removing key: %s", err)) - return 1 + c.Ui.Info("Removing LAN gossip encryption key...") + if rval := c.keyOperation(removeKey, client.RemoveKeyLAN); rval != 0 { + return rval } - c.Ui.Info("Successfully removed key!") return 0 } @@ -221,8 +150,67 @@ func (c *KeysCommand) Run(args []string) int { return 0 } +type keyFunc func(string) (map[string]string, error) + +func (c *KeysCommand) keyOperation(key string, fn keyFunc) int { + var out []string + + failures, err := fn(key) + + if err != nil { + if len(failures) > 0 { + for node, msg := range failures { + out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) + } + c.Ui.Error(columnize.SimpleFormat(out)) + } + c.Ui.Error("") + c.Ui.Error(fmt.Sprintf("Error: %s", err)) + return 1 + } + + return 0 +} + +type listKeysFunc func() (map[string]int, int, map[string]string, error) + +func (c *KeysCommand) listKeysOperation(fn listKeysFunc) int { + var out []string + + keys, numNodes, failures, err := fn() + + if err != nil { + if len(failures) > 0 { + for node, msg := range failures { + out = append(out, fmt.Sprintf("failed: %s | %s", node, msg)) + } + c.Ui.Error(columnize.SimpleFormat(out)) + } + c.Ui.Error("") + c.Ui.Error(fmt.Sprintf("Failed gathering member keys: %s", err)) + return 1 + } + for key, num := range keys { + out = append(out, fmt.Sprintf("%s | [%d/%d]", key, num, numNodes)) + } + c.Ui.Output(columnize.SimpleFormat(out)) + + c.Ui.Output("") + return 0 +} + // initializeKeyring will create a keyring file at a given path. -func initializeKeyring(path string, key []byte) error { +func initializeKeyring(path, key string) error { + if _, err := base64.StdEncoding.DecodeString(key); err != nil { + return fmt.Errorf("Invalid key: %s", err) + } + + keys := []string{key} + keyringBytes, err := json.MarshalIndent(keys, "", " ") + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return err }