|
|
@ -4,6 +4,7 @@ import ( |
|
|
|
"context" |
|
|
|
"context" |
|
|
|
"crypto/tls" |
|
|
|
"crypto/tls" |
|
|
|
"crypto/x509" |
|
|
|
"crypto/x509" |
|
|
|
|
|
|
|
"sync" |
|
|
|
"time" |
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
|
|
"v2ray.com/core/common/net" |
|
|
|
"v2ray.com/core/common/net" |
|
|
@ -77,10 +78,17 @@ func (c *Config) getCustomCA() []*Certificate { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { |
|
|
|
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { |
|
|
|
|
|
|
|
var access sync.RWMutex |
|
|
|
|
|
|
|
|
|
|
|
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { |
|
|
|
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { |
|
|
|
domain := hello.ServerName |
|
|
|
domain := hello.ServerName |
|
|
|
certExpired := false |
|
|
|
certExpired := false |
|
|
|
if certificate, found := c.NameToCertificate[domain]; found { |
|
|
|
|
|
|
|
|
|
|
|
access.RLock() |
|
|
|
|
|
|
|
certificate, found := c.NameToCertificate[domain] |
|
|
|
|
|
|
|
access.RUnlock() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if found { |
|
|
|
if !isCertificateExpired(certificate) { |
|
|
|
if !isCertificateExpired(certificate) { |
|
|
|
return certificate, nil |
|
|
|
return certificate, nil |
|
|
|
} |
|
|
|
} |
|
|
@ -90,6 +98,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli |
|
|
|
if certExpired { |
|
|
|
if certExpired { |
|
|
|
newCerts := make([]tls.Certificate, 0, len(c.Certificates)) |
|
|
|
newCerts := make([]tls.Certificate, 0, len(c.Certificates)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
access.Lock() |
|
|
|
for _, certificate := range c.Certificates { |
|
|
|
for _, certificate := range c.Certificates { |
|
|
|
if !isCertificateExpired(&certificate) { |
|
|
|
if !isCertificateExpired(&certificate) { |
|
|
|
newCerts = append(newCerts, certificate) |
|
|
|
newCerts = append(newCerts, certificate) |
|
|
@ -97,6 +106,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
c.Certificates = newCerts |
|
|
|
c.Certificates = newCerts |
|
|
|
|
|
|
|
access.Unlock() |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var issuedCertificate *tls.Certificate |
|
|
|
var issuedCertificate *tls.Certificate |
|
|
@ -110,8 +120,10 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli |
|
|
|
continue |
|
|
|
continue |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
access.Lock() |
|
|
|
c.Certificates = append(c.Certificates, *newCert) |
|
|
|
c.Certificates = append(c.Certificates, *newCert) |
|
|
|
issuedCertificate = &c.Certificates[len(c.Certificates)-1] |
|
|
|
issuedCertificate = &c.Certificates[len(c.Certificates)-1] |
|
|
|
|
|
|
|
access.Unlock() |
|
|
|
break |
|
|
|
break |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -120,7 +132,9 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli |
|
|
|
return nil, newError("failed to create a new certificate for ", domain) |
|
|
|
return nil, newError("failed to create a new certificate for ", domain) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
access.Lock() |
|
|
|
c.BuildNameToCertificate() |
|
|
|
c.BuildNameToCertificate() |
|
|
|
|
|
|
|
access.Unlock() |
|
|
|
|
|
|
|
|
|
|
|
return issuedCertificate, nil |
|
|
|
return issuedCertificate, nil |
|
|
|
} |
|
|
|
} |
|
|
|