diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 1c63e5cc..197586f1 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -135,10 +135,10 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { switch securitySettings := securitySettings.(type) { case *v2tls.Config: - config := securitySettings.GetTLSConfig() if dest.Address.Family().IsDomain() { - config.ServerName = dest.Address.Domain() + securitySettings.OverrideServerNameIfEmpty(dest.Address.Domain()) } + config := securitySettings.GetTLSConfig() tlsConn := tls.Client(iConn, config) iConn = tlsConn } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 17beeb90..37921d9d 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -29,10 +29,10 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*tls.Config) if ok { - config := tlsConfig.GetTLSConfig() if dest.Address.Family().IsDomain() { - config.ServerName = dest.Address.Domain() + tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain()) } + config := tlsConfig.GetTLSConfig() conn = tls.Client(conn, config) } } diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 468e6273..919a91a3 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -10,9 +10,9 @@ var ( globalSessionCache = tls.NewLRUClientSessionCache(128) ) -func (v *Config) BuildCertificates() []tls.Certificate { - certs := make([]tls.Certificate, 0, len(v.Certificate)) - for _, entry := range v.Certificate { +func (c *Config) BuildCertificates() []tls.Certificate { + certs := make([]tls.Certificate, 0, len(c.Certificate)) + for _, entry := range c.Certificate { keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) if err != nil { log.Trace(newError("ignoring invalid X509 key pair").Base(err).AtWarning()) @@ -23,21 +23,27 @@ func (v *Config) BuildCertificates() []tls.Certificate { return certs } -func (v *Config) GetTLSConfig() *tls.Config { +func (c *Config) GetTLSConfig() *tls.Config { config := &tls.Config{ ClientSessionCache: globalSessionCache, NextProtos: []string{"http/1.1"}, } - if v == nil { + if c == nil { return config } - config.InsecureSkipVerify = v.AllowInsecure - config.Certificates = v.BuildCertificates() + config.InsecureSkipVerify = c.AllowInsecure + config.Certificates = c.BuildCertificates() config.BuildNameToCertificate() - if len(v.ServerName) > 0 { - config.ServerName = v.ServerName + if len(c.ServerName) > 0 { + config.ServerName = c.ServerName } return config } + +func (c *Config) OverrideServerNameIfEmpty(serverName string) { + if len(c.ServerName) == 0 { + c.ServerName = serverName + } +} diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index c42c4c9a..3386af82 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -46,10 +46,10 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) tlsConfig, ok := securitySettings.(*tls.Config) if ok { protocol = "wss" - dialer.TLSClientConfig = tlsConfig.GetTLSConfig() if dest.Address.Family().IsDomain() { - dialer.TLSClientConfig.ServerName = dest.Address.Domain() + tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain()) } + dialer.TLSClientConfig = tlsConfig.GetTLSConfig() } }