mirror of https://github.com/v2ray/v2ray-core
fix concurrent access to tls config
parent
5e25741742
commit
9a9b6f9077
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
|
@ -77,10 +78,17 @@ func (c *Config) getCustomCA() []*Certificate {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
var access sync.RWMutex
|
||||||
|
|
||||||
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
domain := hello.ServerName
|
domain := hello.ServerName
|
||||||
certExpired := false
|
certExpired := false
|
||||||
if certificate, found := c.NameToCertificate[domain]; found {
|
|
||||||
|
access.RLock()
|
||||||
|
certificate, found := c.NameToCertificate[domain]
|
||||||
|
access.RUnlock()
|
||||||
|
|
||||||
|
if found {
|
||||||
if !isCertificateExpired(certificate) {
|
if !isCertificateExpired(certificate) {
|
||||||
return certificate, nil
|
return certificate, nil
|
||||||
}
|
}
|
||||||
|
@ -90,6 +98,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||||
if certExpired {
|
if certExpired {
|
||||||
newCerts := make([]tls.Certificate, 0, len(c.Certificates))
|
newCerts := make([]tls.Certificate, 0, len(c.Certificates))
|
||||||
|
|
||||||
|
access.Lock()
|
||||||
for _, certificate := range c.Certificates {
|
for _, certificate := range c.Certificates {
|
||||||
if !isCertificateExpired(&certificate) {
|
if !isCertificateExpired(&certificate) {
|
||||||
newCerts = append(newCerts, certificate)
|
newCerts = append(newCerts, certificate)
|
||||||
|
@ -97,6 +106,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Certificates = newCerts
|
c.Certificates = newCerts
|
||||||
|
access.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var issuedCertificate *tls.Certificate
|
var issuedCertificate *tls.Certificate
|
||||||
|
@ -110,8 +120,10 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
access.Lock()
|
||||||
c.Certificates = append(c.Certificates, *newCert)
|
c.Certificates = append(c.Certificates, *newCert)
|
||||||
issuedCertificate = &c.Certificates[len(c.Certificates)-1]
|
issuedCertificate = &c.Certificates[len(c.Certificates)-1]
|
||||||
|
access.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -120,7 +132,9 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||||
return nil, newError("failed to create a new certificate for ", domain)
|
return nil, newError("failed to create a new certificate for ", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
access.Lock()
|
||||||
c.BuildNameToCertificate()
|
c.BuildNameToCertificate()
|
||||||
|
access.Unlock()
|
||||||
|
|
||||||
return issuedCertificate, nil
|
return issuedCertificate, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue