diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 405f4bf8..86f6d96c 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "sync" "time" "v2ray.com/core/common/net" @@ -77,10 +78,17 @@ func (c *Config) getCustomCA() []*Certificate { } func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + var access sync.RWMutex + return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { domain := hello.ServerName certExpired := false - if certificate, found := c.NameToCertificate[domain]; found { + + access.RLock() + certificate, found := c.NameToCertificate[domain] + access.RUnlock() + + if found { if !isCertificateExpired(certificate) { return certificate, nil } @@ -90,6 +98,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli if certExpired { newCerts := make([]tls.Certificate, 0, len(c.Certificates)) + access.Lock() for _, certificate := range c.Certificates { if !isCertificateExpired(&certificate) { newCerts = append(newCerts, certificate) @@ -97,6 +106,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli } c.Certificates = newCerts + access.Unlock() } var issuedCertificate *tls.Certificate @@ -110,8 +120,10 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli continue } + access.Lock() c.Certificates = append(c.Certificates, *newCert) issuedCertificate = &c.Certificates[len(c.Certificates)-1] + access.Unlock() break } } @@ -120,7 +132,9 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli return nil, newError("failed to create a new certificate for ", domain) } + access.Lock() c.BuildNameToCertificate() + access.Unlock() return issuedCertificate, nil }