diff --git a/pkg/daemons/control/ha.go b/pkg/daemons/control/ha.go index 6f1784c616..759461e735 100644 --- a/pkg/daemons/control/ha.go +++ b/pkg/daemons/control/ha.go @@ -2,6 +2,8 @@ package control import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "io/ioutil" "os" @@ -34,11 +36,16 @@ func setHAData(cfg *config.Control) error { if cfg.StorageBackend != "etcd3" || cfg.CertStorageBackend != "etcd3" { return nil } + tlsConfig, err := genTLSConfig(cfg) + if err != nil { + return err + } endpoints := strings.Split(cfg.StorageEndpoint, ",") cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: etcdDialTimeout, + TLS: tlsConfig, }) if err != nil { return err @@ -71,10 +78,16 @@ func getHAData(cfg *config.Control) error { if cfg.StorageBackend != "etcd3" || cfg.CertStorageBackend != "etcd3" { return nil } + tlsConfig, err := genTLSConfig(cfg) + if err != nil { + return err + } + endpoints := strings.Split(cfg.StorageEndpoint, ",") cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: etcdDialTimeout, + TLS: tlsConfig, }) if err != nil { return err @@ -99,6 +112,35 @@ func getHAData(cfg *config.Control) error { return writeRuntimeCertData(cfg.Runtime, serverRuntime) } +func genTLSConfig(cfg *config.Control) (*tls.Config, error) { + tlsConfig := &tls.Config{} + if cfg.StorageCertFile != "" && cfg.StorageKeyFile != "" { + certPem, err := ioutil.ReadFile(cfg.StorageCertFile) + if err != nil { + return nil, err + } + keyPem, err := ioutil.ReadFile(cfg.StorageKeyFile) + if err != nil { + return nil, err + } + tlsCert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + return nil, err + } + tlsConfig.Certificates = []tls.Certificate{tlsCert} + } + if cfg.StorageCAFile != "" { + caData, err := ioutil.ReadFile(cfg.StorageCAFile) + if err != nil { + return nil, err + } + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(caData) + tlsConfig.RootCAs = certPool + } + return tlsConfig, nil +} + func readRuntimeCertData(runtime *config.ControlRuntime) ([]byte, error) { serverHACerts := map[string]string{ runtime.ServerCA: "",