From 786b3b10956e71603d9ac68498702e54ba35fdd5 Mon Sep 17 00:00:00 2001 From: Hans Hasselberg Date: Tue, 26 Feb 2019 16:52:07 +0100 Subject: [PATCH] Centralise tls configuration part 1 (#5366) In order to be able to reload the TLS configuration, we need one way to generate the different configurations. This PR introduces a `tlsutil.Configurator` which holds a `tlsutil.Config`. Afterwards it is responsible for rendering every `tls.Config`. In this particular PR I moved `IncomingHTTPSConfig`, `IncomingTLSConfig`, and `OutgoingTLSWrapper` into `tlsutil.Configurator`. This PR is a pure refactoring - not a single feature added. And not a single test added. I only slightly modified existing tests as necessary. --- agent/agent.go | 11 +- agent/config/runtime.go | 39 ++-- agent/config/runtime_test.go | 34 +++ agent/consul/client.go | 7 +- agent/consul/config.go | 36 ++- agent/consul/server.go | 9 +- agent/consul/server_test.go | 3 +- agent/testagent.go | 2 + tlsutil/config.go | 256 +++++++++++---------- tlsutil/config_test.go | 434 ++++++++++++++--------------------- 10 files changed, 397 insertions(+), 434 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index dd7c40399c..2fe73969dc 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -40,6 +40,7 @@ import ( "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib/file" "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" "github.com/hashicorp/consul/watch" multierror "github.com/hashicorp/go-multierror" @@ -249,6 +250,8 @@ type Agent struct { // grpcServer is the server instance used currently to serve xDS API for // Envoy. grpcServer *grpc.Server + + tlsConfigurator *tlsutil.Configurator } func New(c *config.RuntimeConfig) (*Agent, error) { @@ -383,15 +386,17 @@ func (a *Agent) Start() error { // waiting to discover a consul server consulCfg.ServerUp = a.sync.SyncFull.Trigger + a.tlsConfigurator = tlsutil.NewConfigurator(c.ToTLSUtilConfig()) + // Setup either the client or the server. if c.ServerMode { - server, err := consul.NewServerLogger(consulCfg, a.logger, a.tokens) + server, err := consul.NewServerLogger(consulCfg, a.logger, a.tokens, a.tlsConfigurator) if err != nil { return fmt.Errorf("Failed to start Consul server: %v", err) } a.delegate = server } else { - client, err := consul.NewClientLogger(consulCfg, a.logger) + client, err := consul.NewClientLogger(consulCfg, a.logger, a.tlsConfigurator) if err != nil { return fmt.Errorf("Failed to start Consul client: %v", err) } @@ -649,7 +654,7 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { var tlscfg *tls.Config _, isTCP := l.(*tcpKeepAliveListener) if isTCP && proto == "https" { - tlscfg, err = a.config.IncomingHTTPSConfig() + tlscfg, err = a.tlsConfigurator.IncomingHTTPSConfig() if err != nil { return err } diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 0978064f4f..884b19c8dd 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -1,7 +1,6 @@ package config import ( - "crypto/tls" "fmt" "net" "reflect" @@ -1440,25 +1439,6 @@ type RuntimeConfig struct { Watches []map[string]interface{} } -// IncomingHTTPSConfig returns the TLS configuration for HTTPS -// connections to consul. -func (c *RuntimeConfig) IncomingHTTPSConfig() (*tls.Config, error) { - tc := &tlsutil.Config{ - VerifyIncoming: c.VerifyIncoming || c.VerifyIncomingHTTPS, - VerifyOutgoing: c.VerifyOutgoing, - CAFile: c.CAFile, - CAPath: c.CAPath, - CertFile: c.CertFile, - KeyFile: c.KeyFile, - NodeName: c.NodeName, - ServerName: c.ServerName, - TLSMinVersion: c.TLSMinVersion, - CipherSuites: c.TLSCipherSuites, - PreferServerCipherSuites: c.TLSPreferServerCipherSuites, - } - return tc.IncomingTLSConfig() -} - func (c *RuntimeConfig) apiAddresses(maxPerType int) (unixAddrs, httpAddrs, httpsAddrs []string) { if len(c.HTTPSAddrs) > 0 { for i, addr := range c.HTTPSAddrs { @@ -1597,6 +1577,25 @@ func (c *RuntimeConfig) Sanitized() map[string]interface{} { return sanitize("rt", reflect.ValueOf(c)).Interface().(map[string]interface{}) } +func (c *RuntimeConfig) ToTLSUtilConfig() *tlsutil.Config { + return &tlsutil.Config{ + VerifyIncoming: c.VerifyIncoming, + VerifyIncomingRPC: c.VerifyIncomingRPC, + VerifyIncomingHTTPS: c.VerifyIncomingHTTPS, + VerifyOutgoing: c.VerifyOutgoing, + CAFile: c.CAFile, + CAPath: c.CAPath, + CertFile: c.CertFile, + KeyFile: c.KeyFile, + NodeName: c.NodeName, + ServerName: c.ServerName, + TLSMinVersion: c.TLSMinVersion, + CipherSuites: c.TLSCipherSuites, + PreferServerCipherSuites: c.TLSPreferServerCipherSuites, + EnableAgentTLSForChecks: c.EnableAgentTLSForChecks, + } +} + // isSecret determines whether a field name represents a field which // may contain a secret. func isSecret(name string) bool { diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 8518281459..a6723588dc 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -5426,6 +5426,40 @@ func TestRuntime_ClientAddressAnyV6(t *testing.T) { require.Equal(t, "[::1]:5688", https) } +func TestRuntime_ToTLSUtilConfig(t *testing.T) { + c := &RuntimeConfig{ + VerifyIncoming: true, + VerifyIncomingRPC: true, + VerifyIncomingHTTPS: true, + VerifyOutgoing: true, + CAFile: "a", + CAPath: "b", + CertFile: "c", + KeyFile: "d", + NodeName: "e", + ServerName: "f", + TLSMinVersion: "tls12", + TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, + TLSPreferServerCipherSuites: true, + EnableAgentTLSForChecks: true, + } + r := c.ToTLSUtilConfig() + require.Equal(t, c.VerifyIncoming, r.VerifyIncoming) + require.Equal(t, c.VerifyIncomingRPC, r.VerifyIncomingRPC) + require.Equal(t, c.VerifyIncomingHTTPS, r.VerifyIncomingHTTPS) + require.Equal(t, c.VerifyOutgoing, r.VerifyOutgoing) + require.Equal(t, c.CAFile, r.CAFile) + require.Equal(t, c.CAPath, r.CAPath) + require.Equal(t, c.CertFile, r.CertFile) + require.Equal(t, c.KeyFile, r.KeyFile) + require.Equal(t, c.NodeName, r.NodeName) + require.Equal(t, c.ServerName, r.ServerName) + require.Equal(t, c.TLSMinVersion, r.TLSMinVersion) + require.Equal(t, c.TLSCipherSuites, r.CipherSuites) + require.Equal(t, c.TLSPreferServerCipherSuites, r.PreferServerCipherSuites) + require.Equal(t, c.EnableAgentTLSForChecks, r.EnableAgentTLSForChecks) +} + func splitIPPort(hostport string) (net.IP, int) { h, p, err := net.SplitHostPort(hostport) if err != nil { diff --git a/agent/consul/client.go b/agent/consul/client.go index d2a84f3cdc..48e279a12b 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" ) @@ -88,10 +89,10 @@ type Client struct { // NewClient is used to construct a new Consul client from the // configuration, potentially returning an error func NewClient(config *Config) (*Client, error) { - return NewClientLogger(config, nil) + return NewClientLogger(config, nil, tlsutil.NewConfigurator(config.ToTLSUtilConfig())) } -func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) { +func NewClientLogger(config *Config, logger *log.Logger, tlsConfigurator *tlsutil.Configurator) (*Client, error) { // Check the protocol version if err := config.CheckProtocolVersion(); err != nil { return nil, err @@ -113,7 +114,7 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) { } // Create the tls Wrapper - tlsWrap, err := config.tlsConfig().OutgoingTLSWrapper() + tlsWrap, err := tlsConfigurator.OutgoingRPCWrapper() if err != nil { return nil, err } diff --git a/agent/consul/config.go b/agent/consul/config.go index e3ebd88627..82a05627d2 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -382,6 +382,22 @@ type Config struct { ConnectReplicationToken string } +func (c *Config) ToTLSUtilConfig() *tlsutil.Config { + return &tlsutil.Config{ + VerifyIncoming: c.VerifyIncoming, + VerifyOutgoing: c.VerifyOutgoing, + CAFile: c.CAFile, + CAPath: c.CAPath, + CertFile: c.CertFile, + KeyFile: c.KeyFile, + NodeName: c.NodeName, + ServerName: c.ServerName, + TLSMinVersion: c.TLSMinVersion, + CipherSuites: c.TLSCipherSuites, + PreferServerCipherSuites: c.TLSPreferServerCipherSuites, + } +} + // CheckProtocolVersion validates the protocol version. func (c *Config) CheckProtocolVersion() error { if c.ProtocolVersion < ProtocolVersionMin { @@ -500,23 +516,3 @@ func DefaultConfig() *Config { return conf } - -// tlsConfig maps this config into a tlsutil config. -func (c *Config) tlsConfig() *tlsutil.Config { - tlsConf := &tlsutil.Config{ - VerifyIncoming: c.VerifyIncoming, - VerifyOutgoing: c.VerifyOutgoing, - VerifyServerHostname: c.VerifyServerHostname, - UseTLS: c.UseTLS, - CAFile: c.CAFile, - CAPath: c.CAPath, - CertFile: c.CertFile, - KeyFile: c.KeyFile, - NodeName: c.NodeName, - ServerName: c.ServerName, - Domain: c.Domain, - TLSMinVersion: c.TLSMinVersion, - PreferServerCipherSuites: c.TLSPreferServerCipherSuites, - } - return tlsConf -} diff --git a/agent/consul/server.go b/agent/consul/server.go index f82331f4b8..a81c8310be 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -253,12 +253,12 @@ type Server struct { } func NewServer(config *Config) (*Server, error) { - return NewServerLogger(config, nil, new(token.Store)) + return NewServerLogger(config, nil, new(token.Store), tlsutil.NewConfigurator(config.ToTLSUtilConfig())) } // NewServer is used to construct a new Consul server from the // configuration, potentially returning an error -func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (*Server, error) { +func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tlsConfigurator *tlsutil.Configurator) (*Server, error) { // Check the protocol version. if err := config.CheckProtocolVersion(); err != nil { return nil, err @@ -297,14 +297,13 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (* } // Create the TLS wrapper for outgoing connections. - tlsConf := config.tlsConfig() - tlsWrap, err := tlsConf.OutgoingTLSWrapper() + tlsWrap, err := tlsConfigurator.OutgoingRPCWrapper() if err != nil { return nil, err } // Get the incoming TLS config. - incomingTLS, err := tlsConf.IncomingTLSConfig() + incomingTLS, err := tlsConfigurator.IncomingRPCConfig() if err != nil { return nil, err } diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index d89db26193..d486c33477 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testutil" "github.com/hashicorp/consul/testutil/retry" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" "github.com/hashicorp/go-uuid" ) @@ -176,7 +177,7 @@ func newServer(c *Config) (*Server, error) { w = os.Stderr } logger := log.New(w, c.NodeName+" - ", log.LstdFlags|log.Lmicroseconds) - srv, err := NewServerLogger(c, logger, new(token.Store)) + srv, err := NewServerLogger(c, logger, new(token.Store), tlsutil.NewConfigurator(c.ToTLSUtilConfig())) if err != nil { return nil, err } diff --git a/agent/testagent.go b/agent/testagent.go index 587ea215ab..e9749611db 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/consul/lib/freeport" "github.com/hashicorp/consul/logger" "github.com/hashicorp/consul/testutil/retry" + "github.com/hashicorp/consul/tlsutil" "github.com/stretchr/testify/require" ) @@ -148,6 +149,7 @@ func (a *TestAgent) Start(t *testing.T) *TestAgent { agent.LogWriter = a.LogWriter agent.logger = log.New(logOutput, a.Name+" - ", log.LstdFlags|log.Lmicroseconds) agent.MemSink = metrics.NewInmemSink(1*time.Second, time.Minute) + agent.tlsConfigurator = tlsutil.NewConfigurator(a.Config.ToTLSUtilConfig()) // we need the err var in the next exit condition if err := agent.Start(); err == nil { diff --git a/tlsutil/config.go b/tlsutil/config.go index 6e7a3bca8f..da78b5e80e 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -33,7 +33,9 @@ type Config struct { // VerifyIncoming is used to verify the authenticity of incoming connections. // This means that TCP requests are forbidden, only allowing for TLS. TLS connections // must match a provided certificate authority. This can be used to force client auth. - VerifyIncoming bool + VerifyIncoming bool + VerifyIncomingRPC bool + VerifyIncomingHTTPS bool // VerifyOutgoing is used to verify the authenticity of outgoing connections. // This means that TLS requests are used, and TCP requests are not made. TLS connections @@ -87,6 +89,8 @@ type Config struct { // PreferServerCipherSuites specifies whether to prefer the server's ciphersuite // over the client ciphersuites. PreferServerCipherSuites bool + + EnableAgentTLSForChecks bool } // AppendCA opens and parses the CA file and adds the certificates to @@ -125,89 +129,6 @@ func (c *Config) skipBuiltinVerify() bool { return c.VerifyServerHostname == false && c.ServerName == "" } -// OutgoingTLSConfig generates a TLS configuration for outgoing -// requests. It will return a nil config if this configuration should -// not use TLS for outgoing connections. -func (c *Config) OutgoingTLSConfig() (*tls.Config, error) { - if !c.UseTLS && !c.VerifyOutgoing { - return nil, nil - } - // Create the tlsConfig - tlsConfig := &tls.Config{ - RootCAs: x509.NewCertPool(), - InsecureSkipVerify: c.skipBuiltinVerify(), - ServerName: c.ServerName, - } - if len(c.CipherSuites) != 0 { - tlsConfig.CipherSuites = c.CipherSuites - } - if c.PreferServerCipherSuites { - tlsConfig.PreferServerCipherSuites = true - } - - // Ensure we have a CA if VerifyOutgoing is set - if c.VerifyOutgoing && c.CAFile == "" && c.CAPath == "" { - return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") - } - - // Parse the CA certs if any - rootConfig := &rootcerts.Config{ - CAFile: c.CAFile, - CAPath: c.CAPath, - } - if err := rootcerts.ConfigureTLS(tlsConfig, rootConfig); err != nil { - return nil, err - } - - // Add cert/key - cert, err := c.KeyPair() - if err != nil { - return nil, err - } else if cert != nil { - tlsConfig.Certificates = []tls.Certificate{*cert} - } - - // Check if a minimum TLS version was set - if c.TLSMinVersion != "" { - tlsvers, ok := TLSLookup[c.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [tls10,tls11,tls12]", c.TLSMinVersion) - } - tlsConfig.MinVersion = tlsvers - } - - return tlsConfig, nil -} - -// OutgoingTLSWrapper returns a a DCWrapper based on the OutgoingTLS -// configuration. If hostname verification is on, the wrapper -// will properly generate the dynamic server name for verification. -func (c *Config) OutgoingTLSWrapper() (DCWrapper, error) { - // Get the TLS config - tlsConfig, err := c.OutgoingTLSConfig() - if err != nil { - return nil, err - } - - // Check if TLS is not enabled - if tlsConfig == nil { - return nil, nil - } - - // Generate the wrapper based on hostname verification - wrapper := func(dc string, conn net.Conn) (net.Conn, error) { - if c.VerifyServerHostname { - // Strip the trailing '.' from the domain if any - domain := strings.TrimSuffix(c.Domain, ".") - tlsConfig = tlsConfig.Clone() - tlsConfig.ServerName = "server." + dc + "." + domain - } - return c.wrapTLSClient(conn, tlsConfig) - } - - return wrapper, nil -} - // SpecificDC is used to invoke a static datacenter // and turns a DCWrapper into a Wrapper type. func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper { @@ -277,71 +198,158 @@ func (c *Config) wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, return tlsConn, err } -// IncomingTLSConfig generates a TLS configuration for incoming requests -func (c *Config) IncomingTLSConfig() (*tls.Config, error) { - // Create the tlsConfig +type Configurator struct { + base *Config +} + +func NewConfigurator(config *Config) *Configurator { + return &Configurator{base: config} +} + +func (c *Configurator) commonTLSConfig() (*tls.Config, error) { + if c.base == nil { + return nil, fmt.Errorf("No config") + } + tlsConfig := &tls.Config{ - ServerName: c.ServerName, - ClientCAs: x509.NewCertPool(), - ClientAuth: tls.NoClientCert, + ServerName: c.base.ServerName, } if tlsConfig.ServerName == "" { - tlsConfig.ServerName = c.NodeName + tlsConfig.ServerName = c.base.NodeName } // Set the cipher suites - if len(c.CipherSuites) != 0 { - tlsConfig.CipherSuites = c.CipherSuites + if len(c.base.CipherSuites) != 0 { + tlsConfig.CipherSuites = c.base.CipherSuites } - if c.PreferServerCipherSuites { + if c.base.PreferServerCipherSuites { tlsConfig.PreferServerCipherSuites = true } - // Parse the CA certs if any - if c.CAFile != "" { - pool, err := rootcerts.LoadCAFile(c.CAFile) - if err != nil { - return nil, err - } - tlsConfig.ClientCAs = pool - } else if c.CAPath != "" { - pool, err := rootcerts.LoadCAPath(c.CAPath) - if err != nil { - return nil, err - } - tlsConfig.ClientCAs = pool - } - // Add cert/key - cert, err := c.KeyPair() + cert, err := c.base.KeyPair() if err != nil { return nil, err } else if cert != nil { tlsConfig.Certificates = []tls.Certificate{*cert} } - // Check if we require verification - if c.VerifyIncoming { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - if c.CAFile == "" && c.CAPath == "" { - return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") - } - if cert == nil { - return nil, fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") - } - } - // Check if a minimum TLS version was set - if c.TLSMinVersion != "" { - tlsvers, ok := TLSLookup[c.TLSMinVersion] + if c.base.TLSMinVersion != "" { + tlsvers, ok := TLSLookup[c.base.TLSMinVersion] if !ok { - return nil, fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [tls10,tls11,tls12]", c.TLSMinVersion) + return nil, fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [tls10,tls11,tls12]", c.base.TLSMinVersion) } tlsConfig.MinVersion = tlsvers } return tlsConfig, nil } +func (c *Configurator) outgoingTLSConfig() (*tls.Config, error) { + tlsConfig, err := c.commonTLSConfig() + if err != nil { + return nil, err + } + + tlsConfig.RootCAs = x509.NewCertPool() + tlsConfig.InsecureSkipVerify = c.base.skipBuiltinVerify() + + // Ensure we have a CA if VerifyOutgoing is set + if c.base.VerifyOutgoing && c.base.CAFile == "" && c.base.CAPath == "" { + return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") + } + + // Parse the CA certs if any + rootConfig := &rootcerts.Config{ + CAFile: c.base.CAFile, + CAPath: c.base.CAPath, + } + if err := rootcerts.ConfigureTLS(tlsConfig, rootConfig); err != nil { + return nil, err + } + + return tlsConfig, nil +} + +func (c *Configurator) incomingTLSConfig(verify bool) (*tls.Config, error) { + tlsConfig, err := c.commonTLSConfig() + if err != nil { + return nil, err + } + + tlsConfig.ClientCAs = x509.NewCertPool() + tlsConfig.ClientAuth = tls.NoClientCert + + // Parse the CA certs if any + if c.base.CAFile != "" { + pool, err := rootcerts.LoadCAFile(c.base.CAFile) + if err != nil { + return nil, err + } + tlsConfig.ClientCAs = pool + } else if c.base.CAPath != "" { + pool, err := rootcerts.LoadCAPath(c.base.CAPath) + if err != nil { + return nil, err + } + tlsConfig.ClientCAs = pool + } + + if verify { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + + if c.base.CAFile == "" && c.base.CAPath == "" { + return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") + } + if len(tlsConfig.Certificates) == 0 { + return nil, fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") + } + } + return tlsConfig, nil +} + +func (c *Configurator) IncomingRPCConfig() (*tls.Config, error) { + return c.incomingTLSConfig(c.base.VerifyIncoming || c.base.VerifyIncomingRPC) +} + +func (c *Configurator) IncomingHTTPSConfig() (*tls.Config, error) { + return c.incomingTLSConfig(c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS) +} + +func (c *Configurator) OutgoingRPCConfig() (*tls.Config, error) { + useTLS := c.base.CAFile != "" || c.base.CAPath != "" || c.base.VerifyOutgoing + if !useTLS { + return nil, nil + } + return c.outgoingTLSConfig() +} + +func (c *Configurator) OutgoingRPCWrapper() (DCWrapper, error) { + // Get the TLS config + tlsConfig, err := c.OutgoingRPCConfig() + if err != nil { + return nil, err + } + + // Check if TLS is not enabled + if tlsConfig == nil { + return nil, nil + } + + // Generate the wrapper based on hostname verification + wrapper := func(dc string, conn net.Conn) (net.Conn, error) { + if c.base.VerifyServerHostname { + // Strip the trailing '.' from the domain if any + domain := strings.TrimSuffix(c.base.Domain, ".") + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = "server." + dc + "." + domain + } + return c.base.wrapTLSClient(conn, tlsConfig) + } + + return wrapper, nil +} + // ParseCiphers parse ciphersuites from the comma-separated string into recognized slice func ParseCiphers(cipherStr string) ([]uint16, error) { suites := []uint16{} diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index d69cf7c987..534bbe906e 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -40,20 +40,6 @@ func TestConfig_CACertificate_Valid(t *testing.T) { } } -func TestConfig_CAPath_Valid(t *testing.T) { - conf := &Config{ - CAPath: "../test/ca_path", - } - - tlsConf, err := conf.IncomingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if len(tlsConf.ClientCAs.Subjects()) != 2 { - t.Fatalf("expected certs") - } -} - func TestConfig_KeyPair_None(t *testing.T) { conf := &Config{} cert, err := conf.KeyPair() @@ -79,53 +65,38 @@ func TestConfig_KeyPair_Valid(t *testing.T) { } } -func TestConfig_OutgoingTLS_MissingCA(t *testing.T) { +func TestConfigurator_OutgoingTLS_MissingCA(t *testing.T) { conf := &Config{ VerifyOutgoing: true, } - tls, err := conf.OutgoingTLSConfig() - if err == nil { - t.Fatalf("expected err") - } - if tls != nil { - t.Fatalf("bad: %v", tls) - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.Error(t, err) + require.Nil(t, tlsConf) } -func TestConfig_OutgoingTLS_OnlyCA(t *testing.T) { +func TestConfigurator_OutgoingTLS_OnlyCA(t *testing.T) { conf := &Config{ CAFile: "../test/ca/root.cer", } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls != nil { - t.Fatalf("expected no config") - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) } -func TestConfig_OutgoingTLS_VerifyOutgoing(t *testing.T) { +func TestConfigurator_OutgoingTLS_VerifyOutgoing(t *testing.T) { conf := &Config{ VerifyOutgoing: true, CAFile: "../test/ca/root.cer", } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if len(tls.RootCAs.Subjects()) != 1 { - t.Fatalf("expect root cert") - } - if tls.ServerName != "" { - t.Fatalf("expect no server name verification") - } - if !tls.InsecureSkipVerify { - t.Fatalf("should skip built-in verification") - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, len(tlsConf.RootCAs.Subjects()), 1) + require.Empty(t, tlsConf.ServerName) + require.True(t, tlsConf.InsecureSkipVerify) } func TestConfig_SkipBuiltinVerify(t *testing.T) { @@ -145,77 +116,51 @@ func TestConfig_SkipBuiltinVerify(t *testing.T) { } } -func TestConfig_OutgoingTLS_ServerName(t *testing.T) { +func TestConfigurator_OutgoingTLS_ServerName(t *testing.T) { conf := &Config{ VerifyOutgoing: true, CAFile: "../test/ca/root.cer", ServerName: "consul.example.com", } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if len(tls.RootCAs.Subjects()) != 1 { - t.Fatalf("expect root cert") - } - if tls.ServerName != "consul.example.com" { - t.Fatalf("expect server name") - } - if tls.InsecureSkipVerify { - t.Fatalf("should not skip built-in verification") - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, len(tlsConf.RootCAs.Subjects()), 1) + require.Equal(t, tlsConf.ServerName, "consul.example.com") + require.False(t, tlsConf.InsecureSkipVerify) } -func TestConfig_OutgoingTLS_VerifyHostname(t *testing.T) { +func TestConfigurator_OutgoingTLS_VerifyHostname(t *testing.T) { conf := &Config{ VerifyOutgoing: true, VerifyServerHostname: true, CAFile: "../test/ca/root.cer", } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if len(tls.RootCAs.Subjects()) != 1 { - t.Fatalf("expect root cert") - } - if tls.InsecureSkipVerify { - t.Fatalf("should not skip built-in verification") - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, len(tlsConf.RootCAs.Subjects()), 1) + require.False(t, tlsConf.InsecureSkipVerify) } -func TestConfig_OutgoingTLS_WithKeyPair(t *testing.T) { +func TestConfigurator_OutgoingTLS_WithKeyPair(t *testing.T) { conf := &Config{ VerifyOutgoing: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key", } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if len(tls.RootCAs.Subjects()) != 1 { - t.Fatalf("expect root cert") - } - if !tls.InsecureSkipVerify { - t.Fatalf("should skip verification") - } - if len(tls.Certificates) != 1 { - t.Fatalf("expected client cert") - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.True(t, tlsConf.InsecureSkipVerify) + require.Equal(t, len(tlsConf.Certificates), 1) } -func TestConfig_OutgoingTLS_TLSMinVersion(t *testing.T) { +func TestConfigurator_OutgoingTLS_TLSMinVersion(t *testing.T) { tlsVersions := []string{"tls10", "tls11", "tls12"} for _, version := range tlsVersions { conf := &Config{ @@ -223,23 +168,19 @@ func TestConfig_OutgoingTLS_TLSMinVersion(t *testing.T) { CAFile: "../test/ca/root.cer", TLSMinVersion: version, } - tls, err := conf.OutgoingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if tls.MinVersion != TLSLookup[version] { - t.Fatalf("expected tls min version: %v, %v", tls.MinVersion, TLSLookup[version]) - } + c := NewConfigurator(conf) + tlsConf, err := c.OutgoingRPCConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, tlsConf.MinVersion, TLSLookup[version]) } } func startTLSServer(config *Config) (net.Conn, chan error) { errc := make(chan error, 1) - tlsConfigServer, err := config.IncomingTLSConfig() + c := NewConfigurator(config) + tlsConfigServer, err := c.IncomingRPCConfig() if err != nil { errc <- err return nil, errc @@ -273,7 +214,7 @@ func startTLSServer(config *Config) (net.Conn, chan error) { return clientConn, errc } -func TestConfig_outgoingWrapper_OK(t *testing.T) { +func TestConfigurator_outgoingWrapper_OK(t *testing.T) { config := &Config{ CAFile: "../test/hostname/CertAuth.crt", CertFile: "../test/hostname/Alice.crt", @@ -288,27 +229,22 @@ func TestConfig_outgoingWrapper_OK(t *testing.T) { t.Fatalf("startTLSServer err: %v", <-errc) } - wrap, err := config.OutgoingTLSWrapper() - if err != nil { - t.Fatalf("OutgoingTLSWrapper err: %v", err) - } + c := NewConfigurator(config) + wrap, err := c.OutgoingRPCWrapper() + require.NoError(t, err) tlsClient, err := wrap("dc1", client) - if err != nil { - t.Fatalf("wrapTLS err: %v", err) - } + require.NoError(t, err) + defer tlsClient.Close() - if err := tlsClient.(*tls.Conn).Handshake(); err != nil { - t.Fatalf("write err: %v", err) - } + err = tlsClient.(*tls.Conn).Handshake() + require.NoError(t, err) err = <-errc - if err != nil { - t.Fatalf("server: %v", err) - } + require.NoError(t, err) } -func TestConfig_outgoingWrapper_BadDC(t *testing.T) { +func TestConfigurator_outgoingWrapper_BadDC(t *testing.T) { config := &Config{ CAFile: "../test/hostname/CertAuth.crt", CertFile: "../test/hostname/Alice.crt", @@ -323,25 +259,22 @@ func TestConfig_outgoingWrapper_BadDC(t *testing.T) { t.Fatalf("startTLSServer err: %v", <-errc) } - wrap, err := config.OutgoingTLSWrapper() - if err != nil { - t.Fatalf("OutgoingTLSWrapper err: %v", err) - } + c := NewConfigurator(config) + wrap, err := c.OutgoingRPCWrapper() + require.NoError(t, err) tlsClient, err := wrap("dc2", client) - if err != nil { - t.Fatalf("wrapTLS err: %v", err) - } + require.NoError(t, err) + err = tlsClient.(*tls.Conn).Handshake() - if _, ok := err.(x509.HostnameError); !ok { - t.Fatalf("should get hostname err: %v", err) - } + _, ok := err.(x509.HostnameError) + require.True(t, ok) tlsClient.Close() <-errc } -func TestConfig_outgoingWrapper_BadCert(t *testing.T) { +func TestConfigurator_outgoingWrapper_BadCert(t *testing.T) { config := &Config{ CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", @@ -356,15 +289,13 @@ func TestConfig_outgoingWrapper_BadCert(t *testing.T) { t.Fatalf("startTLSServer err: %v", <-errc) } - wrap, err := config.OutgoingTLSWrapper() - if err != nil { - t.Fatalf("OutgoingTLSWrapper err: %v", err) - } + c := NewConfigurator(config) + wrap, err := c.OutgoingRPCWrapper() + require.NoError(t, err) tlsClient, err := wrap("dc1", client) - if err != nil { - t.Fatalf("wrapTLS err: %v", err) - } + require.NoError(t, err) + err = tlsClient.(*tls.Conn).Handshake() if _, ok := err.(x509.HostnameError); !ok { t.Fatalf("should get hostname err: %v", err) @@ -374,7 +305,7 @@ func TestConfig_outgoingWrapper_BadCert(t *testing.T) { <-errc } -func TestConfig_wrapTLS_OK(t *testing.T) { +func TestConfigurator_wrapTLS_OK(t *testing.T) { config := &Config{ CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", @@ -387,24 +318,19 @@ func TestConfig_wrapTLS_OK(t *testing.T) { t.Fatalf("startTLSServer err: %v", <-errc) } - clientConfig, err := config.OutgoingTLSConfig() - if err != nil { - t.Fatalf("OutgoingTLSConfig err: %v", err) - } + c := NewConfigurator(config) + clientConfig, err := c.OutgoingRPCConfig() + require.NoError(t, err) tlsClient, err := config.wrapTLSClient(client, clientConfig) - if err != nil { - t.Fatalf("wrapTLS err: %v", err) - } else { - tlsClient.Close() - } + require.NoError(t, err) + + tlsClient.Close() err = <-errc - if err != nil { - t.Fatalf("server: %v", err) - } + require.NoError(t, err) } -func TestConfig_wrapTLS_BadCert(t *testing.T) { +func TestConfigurator_wrapTLS_BadCert(t *testing.T) { serverConfig := &Config{ CertFile: "../test/key/ssl-cert-snakeoil.pem", KeyFile: "../test/key/ssl-cert-snakeoil.key", @@ -420,114 +346,16 @@ func TestConfig_wrapTLS_BadCert(t *testing.T) { VerifyOutgoing: true, } - clientTLSConfig, err := clientConfig.OutgoingTLSConfig() - if err != nil { - t.Fatalf("OutgoingTLSConfig err: %v", err) - } + c := NewConfigurator(clientConfig) + clientTLSConfig, err := c.OutgoingRPCConfig() + require.NoError(t, err) tlsClient, err := clientConfig.wrapTLSClient(client, clientTLSConfig) - if err == nil { - t.Fatalf("wrapTLS no err") - } - if tlsClient != nil { - t.Fatalf("returned a client") - } + require.Error(t, err) + require.Nil(t, tlsClient) err = <-errc - if err != nil { - t.Fatalf("server: %v", err) - } -} - -func TestConfig_IncomingTLS(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - } - tlsC, err := conf.IncomingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tlsC == nil { - t.Fatalf("expected config") - } - if len(tlsC.ClientCAs.Subjects()) != 1 { - t.Fatalf("expect client cert") - } - if tlsC.ClientAuth != tls.RequireAndVerifyClientCert { - t.Fatalf("should not skip verification") - } - if len(tlsC.Certificates) != 1 { - t.Fatalf("expected client cert") - } -} - -func TestConfig_IncomingTLS_MissingCA(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - } - _, err := conf.IncomingTLSConfig() - if err == nil { - t.Fatalf("expected err") - } -} - -func TestConfig_IncomingTLS_MissingKey(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", - } - _, err := conf.IncomingTLSConfig() - if err == nil { - t.Fatalf("expected err") - } -} - -func TestConfig_IncomingTLS_NoVerify(t *testing.T) { - conf := &Config{} - tlsC, err := conf.IncomingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tlsC == nil { - t.Fatalf("expected config") - } - if len(tlsC.ClientCAs.Subjects()) != 0 { - t.Fatalf("do not expect client cert") - } - if tlsC.ClientAuth != tls.NoClientCert { - t.Fatalf("should skip verification") - } - if len(tlsC.Certificates) != 0 { - t.Fatalf("unexpected client cert") - } -} - -func TestConfig_IncomingTLS_TLSMinVersion(t *testing.T) { - tlsVersions := []string{"tls10", "tls11", "tls12"} - for _, version := range tlsVersions { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - TLSMinVersion: version, - } - tls, err := conf.IncomingTLSConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - if tls == nil { - t.Fatalf("expected config") - } - if tls.MinVersion != TLSLookup[version] { - t.Fatalf("expected tls min version: %v, %v", tls.MinVersion, TLSLookup[version]) - } - } + require.NoError(t, err) } func TestConfig_ParseCiphers(t *testing.T) { @@ -592,3 +420,93 @@ func TestConfig_ParseCiphers(t *testing.T) { t.Fatal("should fail on unsupported cipherX") } } + +func TestConfigurator_IncomingHTTPSConfig_CA_PATH(t *testing.T) { + conf := &Config{CAPath: "../test/ca_path"} + + c := NewConfigurator(conf) + tlsConf, err := c.IncomingHTTPSConfig() + require.NoError(t, err) + require.Equal(t, len(tlsConf.ClientCAs.Subjects()), 2) +} + +func TestConfigurator_IncomingHTTPS(t *testing.T) { + conf := &Config{ + VerifyIncoming: true, + CAFile: "../test/ca/root.cer", + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + } + c := NewConfigurator(conf) + tlsConf, err := c.IncomingHTTPSConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, len(tlsConf.ClientCAs.Subjects()), 1) + require.Equal(t, tlsConf.ClientAuth, tls.RequireAndVerifyClientCert) + require.Equal(t, len(tlsConf.Certificates), 1) +} + +func TestConfigurator_IncomingHTTPS_MissingCA(t *testing.T) { + conf := &Config{ + VerifyIncoming: true, + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + } + c := NewConfigurator(conf) + _, err := c.IncomingHTTPSConfig() + require.Error(t, err) +} + +func TestConfigurator_IncomingHTTPS_MissingKey(t *testing.T) { + conf := &Config{ + VerifyIncoming: true, + CAFile: "../test/ca/root.cer", + } + c := NewConfigurator(conf) + _, err := c.IncomingHTTPSConfig() + require.Error(t, err) +} + +func TestConfigurator_IncomingHTTPS_NoVerify(t *testing.T) { + conf := &Config{} + c := NewConfigurator(conf) + tlsConf, err := c.IncomingHTTPSConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, len(tlsConf.ClientCAs.Subjects()), 0) + require.Equal(t, tlsConf.ClientAuth, tls.NoClientCert) + require.Equal(t, len(tlsConf.Certificates), 0) +} + +func TestConfigurator_IncomingHTTPS_TLSMinVersion(t *testing.T) { + tlsVersions := []string{"tls10", "tls11", "tls12"} + for _, version := range tlsVersions { + conf := &Config{ + VerifyIncoming: true, + CAFile: "../test/ca/root.cer", + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + TLSMinVersion: version, + } + c := NewConfigurator(conf) + tlsConf, err := c.IncomingHTTPSConfig() + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.Equal(t, tlsConf.MinVersion, TLSLookup[version]) + } +} + +func TestConfigurator_IncomingHTTPSCAPath_Valid(t *testing.T) { + conf := &Config{ + CAPath: "../test/ca_path", + } + + c := NewConfigurator(conf) + tlsConf, err := c.IncomingHTTPSConfig() + if err != nil { + t.Fatalf("err: %v", err) + } + if len(tlsConf.ClientCAs.Subjects()) != 2 { + t.Fatalf("expected certs") + } +}