simplify tls config

pull/786/head
Darien Raymond 2017-12-17 00:53:17 +01:00
parent 9561301fea
commit 048ffbc7dc
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
7 changed files with 44 additions and 48 deletions

View File

@ -77,16 +77,9 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
var iConn internet.Connection = session
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
switch securitySettings := securitySettings.(type) {
case *v2tls.Config:
if dest.Address.Family().IsDomain() {
securitySettings.OverrideServerNameIfEmpty(dest.Address.Domain())
}
config := securitySettings.GetTLSConfig()
tlsConn := tls.Client(iConn, config)
iConn = tlsConn
}
if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil {
tlsConn := tls.Client(iConn, config.GetTLSConfig())
iConn = tlsConn
}
return iConn, nil

View File

@ -59,13 +59,11 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
config: kcpSettings,
addConn: addConn,
}
securitySettings := internet.SecuritySettingsFromContext(ctx)
if securitySettings != nil {
switch securitySettings := securitySettings.(type) {
case *v2tls.Config:
l.tlsConfig = securitySettings.GetTLSConfig()
}
if config := v2tls.ConfigFromContext(ctx); config != nil {
l.tlsConfig = config.GetTLSConfig()
}
hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2})
if err != nil {
return nil, err

View File

@ -19,22 +19,16 @@ func getTCPSettingsFromContext(ctx context.Context) *Config {
}
func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
log.Trace(newError("dailing TCP to ", dest))
log.Trace(newError("dialing TCP to ", dest))
src := internet.DialerSourceFromContext(ctx)
conn, err := internet.DialSystem(ctx, src, dest)
if err != nil {
return nil, err
}
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*tls.Config)
if ok {
if dest.Address.Family().IsDomain() {
tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
}
config := tlsConfig.GetTLSConfig()
conn = tls.Client(conn, config)
}
if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil {
conn = tls.Client(conn, config.GetTLSConfig())
}
tcpSettings := getTCPSettingsFromContext(ctx)

View File

@ -37,12 +37,11 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn
config: tcpSettings,
addConn: addConn,
}
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*tls.Config)
if ok {
l.tlsConfig = tlsConfig.GetTLSConfig()
}
if config := tls.ConfigFromContext(ctx); config != nil {
l.tlsConfig = config.GetTLSConfig()
}
if tcpSettings.HeaderSettings != nil {
headerConfig, err := tcpSettings.HeaderSettings.GetInstance()
if err != nil {

View File

@ -1,9 +1,12 @@
package tls
import (
"context"
"crypto/tls"
"v2ray.com/core/app/log"
"v2ray.com/core/common/net"
"v2ray.com/core/transport/internet"
)
var (
@ -42,8 +45,26 @@ func (c *Config) GetTLSConfig() *tls.Config {
return config
}
func (c *Config) OverrideServerNameIfEmpty(serverName string) {
if len(c.ServerName) == 0 {
c.ServerName = serverName
type Option func(*Config)
func WithDestination(dest net.Destination) Option {
return func(config *Config) {
if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 {
config.ServerName = dest.Address.Domain()
}
}
}
func ConfigFromContext(ctx context.Context, opts ...Option) *Config {
securitySettings := internet.SecuritySettingsFromContext(ctx)
if securitySettings == nil {
return nil
}
if config, ok := securitySettings.(*Config); ok {
for _, opt := range opts {
opt(config)
}
return config
}
return nil
}

View File

@ -42,15 +42,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
protocol := "ws"
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*tls.Config)
if ok {
protocol = "wss"
if dest.Address.Family().IsDomain() {
tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
}
dialer.TLSClientConfig = tlsConfig.GetTLSConfig()
}
if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil {
protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig()
}
host := dest.NetAddr()

View File

@ -59,11 +59,8 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i
config: wsSettings,
addConn: addConn,
}
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*v2tls.Config)
if ok {
l.tlsConfig = tlsConfig.GetTLSConfig()
}
if config := v2tls.ConfigFromContext(ctx); config != nil {
l.tlsConfig = config.GetTLSConfig()
}
err := l.listenws(address, port)