From da0568d8d0202334320c9e0443b6f565f6111dd5 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 14 Apr 2018 13:28:57 +0200 Subject: [PATCH] refine cert generation --- transport/internet/tls/config.go | 107 ++++++++++++++++--------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 355f096e..81ad40a4 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -58,13 +58,64 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro return &cert, err } -func (c *Config) hasCustomCA() bool { +func (c *Config) getCustomCA() []*Certificate { + certs := make([]*Certificate, 0, len(c.Certificate)) for _, certificate := range c.Certificate { if certificate.Usage == Certificate_AUTHORITY_ISSUE { - return true + certs = append(certs, certificate) } } - return false + return certs +} + +func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + domain := hello.ServerName + certExpired := false + if certificate, found := c.NameToCertificate[domain]; found { + if !isCertificateExpired(certificate) { + return certificate, nil + } + certExpired = true + } + + if certExpired { + newCerts := make([]tls.Certificate, 0, len(c.Certificates)) + + for _, certificate := range c.Certificates { + if !isCertificateExpired(&certificate) { + newCerts = append(newCerts, certificate) + } + } + + c.Certificates = newCerts + } + + var issuedCertificate *tls.Certificate + + // Create a new certificate from existing CA if possible + for _, rawCert := range ca { + if rawCert.Usage == Certificate_AUTHORITY_ISSUE { + newCert, err := issueCertificate(rawCert, domain) + if err != nil { + newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() + continue + } + + c.Certificates = append(c.Certificates, *newCert) + issuedCertificate = &c.Certificates[len(c.Certificates)-1] + break + } + } + + if issuedCertificate == nil { + return nil, newError("failed to create a new certificate for ", domain) + } + + c.BuildNameToCertificate() + + return issuedCertificate, nil + } } func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { @@ -83,54 +134,10 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { config.InsecureSkipVerify = c.AllowInsecure config.Certificates = c.BuildCertificates() config.BuildNameToCertificate() - if c.hasCustomCA() { - config.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - domain := hello.ServerName - certExpired := false - if certificate, found := config.NameToCertificate[domain]; found { - if !isCertificateExpired(certificate) { - return certificate, nil - } - certExpired = true - } - if certExpired { - newCerts := make([]tls.Certificate, 0, len(config.Certificates)) - - for _, certificate := range config.Certificates { - if !isCertificateExpired(&certificate) { - newCerts = append(newCerts, certificate) - } - } - - config.Certificates = newCerts - } - - var issuedCertificate *tls.Certificate - - // Create a new certificate from existing CA if possible - for _, rawCert := range c.Certificate { - if rawCert.Usage == Certificate_AUTHORITY_ISSUE { - newCert, err := issueCertificate(rawCert, domain) - if err != nil { - newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() - continue - } - - config.Certificates = append(config.Certificates, *newCert) - issuedCertificate = &config.Certificates[len(config.Certificates)-1] - break - } - } - - if issuedCertificate == nil { - return nil, newError("failed to create a new certificate for ", domain) - } - - config.BuildNameToCertificate() - - return issuedCertificate, nil - } + caCerts := c.getCustomCA() + if len(caCerts) > 0 { + config.GetCertificate = getGetCertificateFunc(config, caCerts) } if len(c.ServerName) > 0 {