update tls config generation

pull/931/head
Darien Raymond 2018-02-28 15:15:22 +01:00
parent b7d48fe7c5
commit bdab1af29a
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
5 changed files with 23 additions and 21 deletions

View File

@ -86,8 +86,8 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
var iConn internet.Connection = session var iConn internet.Connection = session
if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil { if config := v2tls.ConfigFromContext(ctx); config != nil {
tlsConn := tls.Client(iConn, config.GetTLSConfig()) tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest)))
iConn = tlsConn iConn = tlsConn
} }

View File

@ -27,8 +27,8 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error
return nil, err return nil, err
} }
if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest), tls.WithNextProto("h2")); config != nil { if config := tls.ConfigFromContext(ctx); config != nil {
conn = tls.Client(conn, config.GetTLSConfig()) conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")))
} }
tcpSettings := getTCPSettingsFromContext(ctx) tcpSettings := getTCPSettingsFromContext(ctx)

View File

@ -39,8 +39,8 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler
addConn: handler, addConn: handler,
} }
if config := tls.ConfigFromContext(ctx, tls.WithNextProto("h2")); config != nil { if config := tls.ConfigFromContext(ctx); config != nil {
l.tlsConfig = config.GetTLSConfig() l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2"))
} }
if tcpSettings.HeaderSettings != nil { if tcpSettings.HeaderSettings != nil {

View File

@ -25,7 +25,7 @@ func (c *Config) BuildCertificates() []tls.Certificate {
return certs return certs
} }
func (c *Config) GetTLSConfig() *tls.Config { func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
config := &tls.Config{ config := &tls.Config{
ClientSessionCache: globalSessionCache, ClientSessionCache: globalSessionCache,
NextProtos: []string{"http/1.1"}, NextProtos: []string{"http/1.1"},
@ -34,6 +34,10 @@ func (c *Config) GetTLSConfig() *tls.Config {
return config return config
} }
for _, opt := range opts {
opt(config)
}
config.InsecureSkipVerify = c.AllowInsecure config.InsecureSkipVerify = c.AllowInsecure
config.Certificates = c.BuildCertificates() config.Certificates = c.BuildCertificates()
config.BuildNameToCertificate() config.BuildNameToCertificate()
@ -47,10 +51,10 @@ func (c *Config) GetTLSConfig() *tls.Config {
return config return config
} }
type Option func(*Config) type Option func(*tls.Config)
func WithDestination(dest net.Destination) Option { func WithDestination(dest net.Destination) Option {
return func(config *Config) { return func(config *tls.Config) {
if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 { if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 {
config.ServerName = dest.Address.Domain() config.ServerName = dest.Address.Domain()
} }
@ -58,23 +62,21 @@ func WithDestination(dest net.Destination) Option {
} }
func WithNextProto(protocol ...string) Option { func WithNextProto(protocol ...string) Option {
return func(config *Config) { return func(config *tls.Config) {
if len(config.NextProtocol) == 0 { if len(config.NextProtos) == 0 {
config.NextProtocol = protocol config.NextProtos = protocol
} }
} }
} }
func ConfigFromContext(ctx context.Context, opts ...Option) *Config { func ConfigFromContext(ctx context.Context) *Config {
securitySettings := internet.SecuritySettingsFromContext(ctx) securitySettings := internet.SecuritySettingsFromContext(ctx)
if securitySettings == nil { if securitySettings == nil {
return nil return nil
} }
if config, ok := securitySettings.(*Config); ok { config, ok := securitySettings.(*Config)
for _, opt := range opts { if !ok {
opt(config) return nil
}
return config
} }
return nil return config
} }

View File

@ -41,9 +41,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
protocol := "ws" protocol := "ws"
if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil { if config := tls.ConfigFromContext(ctx); config != nil {
protocol = "wss" protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig() dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest))
} }
host := dest.NetAddr() host := dest.NetAddr()