diff --git a/transport/internet/tls/config_other.go b/transport/internet/tls/config_other.go index 76bae39a..40845f8d 100644 --- a/transport/internet/tls/config_other.go +++ b/transport/internet/tls/config_other.go @@ -2,20 +2,62 @@ package tls -import "crypto/x509" +import ( + "crypto/x509" + "sync" -func (c *Config) getCertPool() *x509.CertPool { - pool, err := x509.SystemCertPool() - if err != nil { - newError("failed to get system cert pool.").Base(err).WriteToLog() + "v2ray.com/core/common/compare" +) + +type certPoolCache struct { + sync.Mutex + once sync.Once + pool *x509.CertPool + extraCerts [][]byte +} + +func (c *certPoolCache) hasCert(cert []byte) bool { + for _, xCert := range c.extraCerts { + if compare.BytesEqual(xCert, cert) { + return true + } + } + return false +} + +func (c *certPoolCache) get(extraCerts []*Certificate) *x509.CertPool { + c.once.Do(func() { + pool, err := x509.SystemCertPool() + if err != nil { + newError("failed to get system cert pool.").Base(err).WriteToLog() + return + } + c.pool = pool + }) + + if c.pool == nil { return nil } - if pool != nil { - for _, cert := range c.Certificate { - if cert.Usage == Certificate_AUTHORITY_VERIFY { - pool.AppendCertsFromPEM(cert.Certificate) - } + + if len(extraCerts) == 0 { + return c.pool + } + + c.Lock() + defer c.Unlock() + + for _, cert := range extraCerts { + if !c.hasCert(cert.Certificate) { + c.pool.AppendCertsFromPEM(cert.Certificate) + c.extraCerts = append(c.extraCerts, cert.Certificate) } } - return pool + + return c.pool +} + +var combineCertPool certPoolCache + +func (c *Config) getCertPool() *x509.CertPool { + return combineCertPool.get(c.Certificate) }