Browse Source

fix concurrent access to tls config

pull/1524/head^2
Darien Raymond 6 years ago
parent
commit
9a9b6f9077
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
  1. 16
      transport/internet/tls/config.go

16
transport/internet/tls/config.go

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"sync"
"time" "time"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
@ -77,10 +78,17 @@ func (c *Config) getCustomCA() []*Certificate {
} }
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
var access sync.RWMutex
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domain := hello.ServerName domain := hello.ServerName
certExpired := false certExpired := false
if certificate, found := c.NameToCertificate[domain]; found {
access.RLock()
certificate, found := c.NameToCertificate[domain]
access.RUnlock()
if found {
if !isCertificateExpired(certificate) { if !isCertificateExpired(certificate) {
return certificate, nil return certificate, nil
} }
@ -90,6 +98,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
if certExpired { if certExpired {
newCerts := make([]tls.Certificate, 0, len(c.Certificates)) newCerts := make([]tls.Certificate, 0, len(c.Certificates))
access.Lock()
for _, certificate := range c.Certificates { for _, certificate := range c.Certificates {
if !isCertificateExpired(&certificate) { if !isCertificateExpired(&certificate) {
newCerts = append(newCerts, certificate) newCerts = append(newCerts, certificate)
@ -97,6 +106,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
} }
c.Certificates = newCerts c.Certificates = newCerts
access.Unlock()
} }
var issuedCertificate *tls.Certificate var issuedCertificate *tls.Certificate
@ -110,8 +120,10 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
continue continue
} }
access.Lock()
c.Certificates = append(c.Certificates, *newCert) c.Certificates = append(c.Certificates, *newCert)
issuedCertificate = &c.Certificates[len(c.Certificates)-1] issuedCertificate = &c.Certificates[len(c.Certificates)-1]
access.Unlock()
break break
} }
} }
@ -120,7 +132,9 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
return nil, newError("failed to create a new certificate for ", domain) return nil, newError("failed to create a new certificate for ", domain)
} }
access.Lock()
c.BuildNameToCertificate() c.BuildNameToCertificate()
access.Unlock()
return issuedCertificate, nil return issuedCertificate, nil
} }

Loading…
Cancel
Save