pool: remove useTLS and ForceTLS

In the past TLS usage was enforced with these variables, but these days
this decision is made by TLSConfigurator and there is no reason to keep
using the variables.
pull/7966/head
Hans Hasselberg 2020-05-28 10:18:30 +02:00
parent c45432014b
commit ad03f863ff
6 changed files with 45 additions and 81 deletions

View File

@ -137,7 +137,6 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat
MaxTime: clientRPCConnMaxIdle,
MaxStreams: clientMaxStreams,
TLSConfigurator: tlsConfigurator,
ForceTLS: config.VerifyOutgoing,
Datacenter: config.Datacenter,
}
@ -356,7 +355,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
// Request the operation.
var reply structs.SnapshotResponse
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, &reply)
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply)
if err != nil {
return err
}

View File

@ -374,7 +374,6 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token
MaxTime: serverRPCCache,
MaxStreams: serverMaxStreams,
TLSConfigurator: tlsConfigurator,
ForceTLS: config.VerifyOutgoing,
Datacenter: config.Datacenter,
}

View File

@ -37,7 +37,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re
return nil, structs.ErrNoDCPath
}
snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, server.UseTLS, args, in, reply)
snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, args, in, reply)
if err != nil {
manager.NotifyFailedServer(server)
return nil, err
@ -52,7 +52,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re
if server == nil {
return nil, structs.ErrNoLeader
}
return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, reply)
return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, args, in, reply)
}
}
@ -194,14 +194,13 @@ func SnapshotRPC(
dc string,
nodeName string,
addr net.Addr,
useTLS bool,
args *structs.SnapshotRequest,
in io.Reader,
reply *structs.SnapshotResponse,
) (io.ReadCloser, error) {
// Write the snapshot RPC byte to set the mode, then perform the
// request.
conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, useTLS, pool.RPCSnapshot)
conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot)
if err != nil {
return nil, err
}

View File

@ -46,7 +46,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) {
Op: structs.SnapshotSave,
}
var reply structs.SnapshotResponse
snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false,
snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if err != nil {
t.Fatalf("err: %v", err)
@ -121,7 +121,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) {
// Restore the snapshot.
args.Op = structs.SnapshotRestore
restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false,
restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, snap, &reply)
if err != nil {
t.Fatalf("err: %v", err)
@ -196,7 +196,7 @@ func TestSnapshot_LeaderState(t *testing.T) {
Op: structs.SnapshotSave,
}
var reply structs.SnapshotResponse
snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false,
snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if err != nil {
t.Fatalf("err: %v", err)
@ -229,7 +229,7 @@ func TestSnapshot_LeaderState(t *testing.T) {
// Restore the snapshot.
args.Op = structs.SnapshotRestore
restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false,
restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, snap, &reply)
if err != nil {
t.Fatalf("err: %v", err)
@ -268,7 +268,7 @@ func TestSnapshot_ACLDeny(t *testing.T) {
Op: structs.SnapshotSave,
}
var reply structs.SnapshotResponse
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false,
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if !acl.IsErrPermissionDenied(err) {
t.Fatalf("err: %v", err)
@ -282,7 +282,7 @@ func TestSnapshot_ACLDeny(t *testing.T) {
Op: structs.SnapshotRestore,
}
var reply structs.SnapshotResponse
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false,
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if !acl.IsErrPermissionDenied(err) {
t.Fatalf("err: %v", err)
@ -391,7 +391,7 @@ func TestSnapshot_AllowStale(t *testing.T) {
Op: structs.SnapshotSave,
}
var reply structs.SnapshotResponse
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false,
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) {
t.Fatalf("err: %v", err)
@ -408,7 +408,7 @@ func TestSnapshot_AllowStale(t *testing.T) {
Op: structs.SnapshotSave,
}
var reply structs.SnapshotResponse
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false,
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply)
if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") {
t.Fatalf("err: %v", err)

View File

@ -37,20 +37,25 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) {
if wrapper == nil {
return nil, err
}
conn, _, err := pool.DialTimeoutWithRPCTypeDirectly(
s.config.Datacenter,
s.config.NodeName,
addr,
nil,
time.Second,
true,
wrapper,
pool.RPCTLSInsecure,
pool.RPCTLSInsecure,
)
d := &net.Dialer{Timeout: time.Second}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, err
}
// Switch the connection into TLS mode
if _, err = conn.Write([]byte{byte(pool.RPCTLSInsecure)}); err != nil {
conn.Close()
return nil, err
}
// Wrap the connection in a TLS client
tlsConn, err := wrapper(s.config.Datacenter, conn)
if err != nil {
conn.Close()
return nil, err
}
conn = tlsConn
return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil
}

View File

@ -146,9 +146,6 @@ type ConnPool struct {
// Datacenter is the datacenter of the current agent.
Datacenter string
// ForceTLS is used to enforce outgoing TLS verification
ForceTLS bool
// Server should be set to true if this connection pool is configured in a
// server instead of a client.
Server bool
@ -208,7 +205,7 @@ func (p *ConnPool) Shutdown() error {
// wait for an existing connection attempt to finish, if one if in progress,
// and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) {
func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr) (*Conn, error) {
if nodeName == "" {
return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name")
}
@ -243,7 +240,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS boo
// If we are the lead thread, make the new connection and then wake
// everybody else up to see if we got it.
if isLeadThread {
c, err := p.getNewConn(dc, nodeName, addr, useTLS)
c, err := p.getNewConn(dc, nodeName, addr)
p.Lock()
delete(p.limiter, addrStr)
close(wait)
@ -290,7 +287,6 @@ func (p *ConnPool) DialTimeout(
nodeName string,
addr net.Addr,
timeout time.Duration,
useTLS bool,
actualRPCType RPCType,
) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)
@ -314,64 +310,26 @@ func (p *ConnPool) DialTimeout(
)
}
return DialTimeoutWithRPCTypeDirectly(
return p.dial(
dc,
nodeName,
addr,
p.SrcAddr,
timeout,
useTLS || p.ForceTLS,
p.TLSConfigurator.OutgoingRPCWrapper(),
actualRPCType,
RPCTLS,
)
}
// DialTimeoutInsecure is used to establish a raw connection to the given
// server, with given connection timeout. It also writes RPCTLSInsecure as the
// first byte to indicate that the client cannot provide a certificate. This is
// so far only used for AutoEncrypt.Sign.
func (p *ConnPool) DialTimeoutInsecure(
func (p *ConnPool) dial(
dc string,
nodeName string,
addr net.Addr,
timeout time.Duration,
wrapper tlsutil.DCWrapper,
) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)
if wrapper == nil {
return nil, nil, fmt.Errorf("wrapper cannot be nil")
} else if dc != p.Datacenter {
return nil, nil, fmt.Errorf("insecure dialing prohibited between datacenters")
}
return DialTimeoutWithRPCTypeDirectly(
dc,
nodeName,
addr,
p.SrcAddr,
timeout,
true,
wrapper,
RPCTLSInsecure,
RPCTLSInsecure,
)
}
func DialTimeoutWithRPCTypeDirectly(
dc string,
nodeName string,
addr net.Addr,
src *net.TCPAddr,
timeout time.Duration,
useTLS bool,
wrapper tlsutil.DCWrapper,
actualRPCType RPCType,
tlsRPCType RPCType,
) (net.Conn, HalfCloser, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: src, Timeout: timeout}
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, nil, err
@ -388,7 +346,8 @@ func DialTimeoutWithRPCTypeDirectly(
}
// Check if TLS is enabled
if useTLS && wrapper != nil {
if p.TLSConfigurator.UseTLS(dc) {
wrapper := p.TLSConfigurator.OutgoingRPCWrapper()
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil {
conn.Close()
@ -496,13 +455,13 @@ func DialTimeoutWithRPCTypeViaMeshGateway(
}
// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) {
func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, error) {
if nodeName == "" {
return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name")
}
// Get a new, raw connection and write the Consul multiplex byte to set the mode
conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2)
conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2)
if err != nil {
return nil, err
}
@ -560,11 +519,11 @@ func (p *ConnPool) releaseConn(conn *Conn) {
}
// getClient is used to get a usable client for an address
func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) {
func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr) (*Conn, *StreamClient, error) {
retries := 0
START:
// Try to get a conn first
conn, err := p.acquire(dc, nodeName, addr, useTLS)
conn, err := p.acquire(dc, nodeName, addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to get conn: %v", err)
}
@ -611,8 +570,12 @@ func (p *ConnPool) RPC(
// AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that
// connection if it is not being reused.
func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error {
if dc != p.Datacenter {
return fmt.Errorf("insecure dialing prohibited between datacenters")
}
var codec rpc.ClientCodec
conn, _, err := p.DialTimeoutInsecure(dc, nodeName, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper())
conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure)
if err != nil {
return fmt.Errorf("rpcinsecure error establishing connection: %v", err)
}
@ -631,8 +594,7 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string,
p.once.Do(p.init)
// Get a usable client
useTLS := p.TLSConfigurator.UseTLS(dc)
conn, sc, err := p.getClient(dc, nodeName, addr, useTLS)
conn, sc, err := p.getClient(dc, nodeName, addr)
if err != nil {
return fmt.Errorf("rpc error getting client: %v", err)
}