rpc: oss changes for network area connection pooling (#7735)

pull/7753/head
Hans Hasselberg 2020-04-30 22:12:17 +02:00 committed by GitHub
parent 27eb12ec51
commit 51549bd232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 209 additions and 86 deletions

View File

@ -109,7 +109,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
for _, ip := range ips { for _, ip := range ips {
addr := net.TCPAddr{IP: ip, Port: port} addr := net.TCPAddr{IP: ip, Port: port}
if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", true, &args, &reply); err == nil { if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", &args, &reply); err == nil {
return &reply, pkPEM, nil return &reply, pkPEM, nil
} else { } else {
c.logger.Warn("AutoEncrypt failed", "error", err) c.logger.Warn("AutoEncrypt failed", "error", err)

View File

@ -186,7 +186,7 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat
} }
// Start maintenance task for servers // Start maintenance task for servers
c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool) c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool, "")
go c.routers.Start() go c.routers.Start()
// Start LAN event handlers after the router is complete since the event // Start LAN event handlers after the router is complete since the event
@ -308,7 +308,7 @@ TRY:
} }
// Make the request. // Make the request.
rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply) rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, args, reply)
if rpcErr == nil { if rpcErr == nil {
return nil return nil
} }

View File

@ -418,7 +418,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) {
for range servers { for range servers {
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
s := c.routers.FindServer() s := c.routers.FindServer()
ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version, s.UseTLS) ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version)
if !ok { if !ok {
t.Errorf("Unable to ping server %v: %s", s.String(), err) t.Errorf("Unable to ping server %v: %s", s.String(), err)
} }

View File

@ -307,7 +307,42 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
} }
return return
} }
go s.handleConsulConn(sub)
// In the beginning only RPC was supposed to be multiplexed
// with yamux. In order to add the ability to multiplex network
// area connections, this workaround was added.
// This code peeks the first byte and checks if it is
// RPCGossip, in which case this is handled by enterprise code.
// Otherwise this connection is handled like before by the RPC
// handler.
// This wouldn't work if a normal RPC could start with
// RPCGossip(6). In messagepack a 6 encodes a positive fixint:
// https://github.com/msgpack/msgpack/blob/master/spec.md.
// None of the RPCs we are doing starts with that, usually it is
// a string for datacenter.
peeked, first, err := pool.PeekFirstByte(sub)
if err != nil {
s.rpcLogger().Error("Problem peeking connection", "conn", logConn(sub), "err", err)
sub.Close()
return
}
sub = peeked
switch first {
case pool.RPCGossip:
buf := make([]byte, 1)
sub.Read(buf)
go func() {
if !s.handleEnterpriseRPCConn(pool.RPCGossip, sub, false) {
s.rpcLogger().Error("unrecognized RPC byte",
"byte", pool.RPCGossip,
"conn", logConn(conn),
)
sub.Close()
}
}()
default:
go s.handleConsulConn(sub)
}
} }
} }
@ -517,7 +552,7 @@ CHECK_LEADER:
rpcErr := structs.ErrNoLeader rpcErr := structs.ErrNoLeader
if leader != nil { if leader != nil {
rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr, rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr,
leader.Version, method, leader.UseTLS, args, reply) leader.Version, method, args, reply)
if rpcErr != nil && canRetry(info, rpcErr) { if rpcErr != nil && canRetry(info, rpcErr) {
goto RETRY goto RETRY
} }
@ -582,7 +617,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{
metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1, metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1,
[]metrics.Label{{Name: "datacenter", Value: dc}}) []metrics.Label{{Name: "datacenter", Value: dc}})
if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil { if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, args, reply); err != nil {
manager.NotifyFailedServer(server) manager.NotifyFailedServer(server)
s.rpcLogger().Error("RPC failed to server in DC", s.rpcLogger().Error("RPC failed to server in DC",
"server", server.Addr, "server", server.Addr,

View File

@ -391,7 +391,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token
loggers: loggers, loggers: loggers,
leaveCh: make(chan struct{}), leaveCh: make(chan struct{}),
reconcileCh: make(chan serf.Member, reconcileChSize), reconcileCh: make(chan serf.Member, reconcileChSize),
router: router.NewRouter(serverLogger, config.Datacenter), router: router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)),
rpcServer: rpc.NewServer(), rpcServer: rpc.NewServer(),
insecureRPCServer: rpc.NewServer(), insecureRPCServer: rpc.NewServer(),
tlsConfigurator: tlsConfigurator, tlsConfigurator: tlsConfigurator,
@ -551,7 +551,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token
// Add a "static route" to the WAN Serf and hook it up to Serf events. // Add a "static route" to the WAN Serf and hook it up to Serf events.
if s.serfWAN != nil { if s.serfWAN != nil {
if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool, s.config.VerifyOutgoing); err != nil { if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool); err != nil {
s.Shutdown() s.Shutdown()
return nil, fmt.Errorf("Failed to add WAN serf route: %v", err) return nil, fmt.Errorf("Failed to add WAN serf route: %v", err)
} }
@ -839,23 +839,16 @@ func (s *Server) setupRPC() error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", s.config.RPCAdvertise) return fmt.Errorf("RPC advertise address is not advertisable: %v", s.config.RPCAdvertise)
} }
// TODO (hans) switch NewRaftLayer to tlsConfigurator
// Provide a DC specific wrapper. Raft replication is only // Provide a DC specific wrapper. Raft replication is only
// ever done in the same datacenter, so we can provide it as a constant. // ever done in the same datacenter, so we can provide it as a constant.
wrapper := tlsutil.SpecificDC(s.config.Datacenter, s.tlsConfigurator.OutgoingRPCWrapper()) wrapper := tlsutil.SpecificDC(s.config.Datacenter, s.tlsConfigurator.OutgoingRPCWrapper())
// Define a callback for determining whether to wrap a connection with TLS // Define a callback for determining whether to wrap a connection with TLS
tlsFunc := func(address raft.ServerAddress) bool { tlsFunc := func(address raft.ServerAddress) bool {
if s.config.VerifyOutgoing { // raft only talks to its own datacenter
return true return s.tlsConfigurator.UseTLS(s.config.Datacenter)
}
server := s.serverLookup.Server(address)
if server == nil {
return false
}
return server.UseTLS
} }
s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper, tlsFunc) s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper, tlsFunc)
return nil return nil
@ -1361,6 +1354,7 @@ func (s *Server) ReloadConfig(config *Config) error {
// this will error if we lose leadership while bootstrapping here. // this will error if we lose leadership while bootstrapping here.
return s.bootstrapConfigEntries(config.ConfigEntryBootstrap) return s.bootstrapConfigEntries(config.ConfigEntryBootstrap)
} }
return nil return nil
} }

View File

@ -364,7 +364,7 @@ func (s *Server) maybeBootstrap() {
// Retry with exponential backoff to get peer status from this server // Retry with exponential backoff to get peer status from this server
for attempt := uint(0); attempt < maxPeerRetries; attempt++ { for attempt := uint(0); attempt < maxPeerRetries; attempt++ {
if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, server.Version, if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, server.Version,
"Status.Peers", server.UseTLS, &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil { "Status.Peers", &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil {
nextRetry := time.Duration((1 << attempt) * peerRetryBase) nextRetry := time.Duration((1 << attempt) * peerRetryBase)
s.logger.Error("Failed to confirm peer status for server (will retry).", s.logger.Error("Failed to confirm peer status for server (will retry).",
"server", server.Name, "server", server.Name,

View File

@ -1213,7 +1213,7 @@ func testVerifyRPC(s1, s2 *Server, t *testing.T) (bool, error) {
if leader == nil { if leader == nil {
t.Fatal("no leader") t.Fatal("no leader")
} }
return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version, leader.UseTLS) return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version)
} }
func TestServer_TLSToNoTLS(t *testing.T) { func TestServer_TLSToNoTLS(t *testing.T) {
@ -1277,7 +1277,6 @@ func TestServer_TLSToFullVerify(t *testing.T) {
c.CAFile = "../../test/client_certs/rootca.crt" c.CAFile = "../../test/client_certs/rootca.crt"
c.CertFile = "../../test/client_certs/server.crt" c.CertFile = "../../test/client_certs/server.crt"
c.KeyFile = "../../test/client_certs/server.key" c.KeyFile = "../../test/client_certs/server.key"
c.VerifyIncoming = true
c.VerifyOutgoing = true c.VerifyOutgoing = true
}) })
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)

View File

@ -43,7 +43,7 @@ func NewStatsFetcher(logger hclog.Logger, pool *pool.ConnPool, datacenter string
func (f *StatsFetcher) fetch(server *metadata.Server, replyCh chan *autopilot.ServerStats) { func (f *StatsFetcher) fetch(server *metadata.Server, replyCh chan *autopilot.ServerStats) {
var args struct{} var args struct{}
var reply autopilot.ServerStats var reply autopilot.ServerStats
err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", server.UseTLS, &args, &reply) err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", &args, &reply)
if err != nil { if err != nil {
f.logger.Warn("error getting server health from server", f.logger.Warn("error getting server health from server",
"server", server.Name, "server", server.Name,

View File

@ -2,6 +2,7 @@ package pool
import ( import (
"bufio" "bufio"
"fmt"
"net" "net"
) )
@ -47,3 +48,32 @@ func PeekForTLS(conn net.Conn) (net.Conn, bool, error) {
Conn: conn, Conn: conn,
}, isTLS, nil }, isTLS, nil
} }
// PeekFirstByte will read the first byte on the conn.
//
// This function does not close the conn on an error.
//
// The returned conn has the initial read buffered internally for the purposes
// of not consuming the first byte. After that buffer is drained the conn is a
// pass through to the original conn.
func PeekFirstByte(conn net.Conn) (net.Conn, byte, error) {
br := bufio.NewReader(conn)
// Grab enough to read the first byte. Then drain the buffer so future
// reads can be direct.
peeked, err := br.Peek(1)
if err != nil {
return nil, 0, err
} else if len(peeked) == 0 {
return conn, 0, fmt.Errorf("nothing to read")
}
peeked, err = br.Peek(br.Buffered())
if err != nil {
return nil, 0, err
}
return &peekedConn{
Peeked: peeked,
Conn: conn,
}, peeked[0], nil
}

View File

@ -389,7 +389,7 @@ func DialTimeoutWithRPCTypeDirectly(
} }
// Check if TLS is enabled // Check if TLS is enabled
if (useTLS) && wrapper != nil { if useTLS && wrapper != nil {
// Switch the connection into TLS mode // Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil {
conn.Close() conn.Close()
@ -600,7 +600,6 @@ func (p *ConnPool) RPC(
addr net.Addr, addr net.Addr,
version int, version int,
method string, method string,
useTLS bool,
args interface{}, args interface{},
reply interface{}, reply interface{},
) error { ) error {
@ -611,7 +610,7 @@ func (p *ConnPool) RPC(
if method == "AutoEncrypt.Sign" { if method == "AutoEncrypt.Sign" {
return p.rpcInsecure(dc, nodeName, addr, method, args, reply) return p.rpcInsecure(dc, nodeName, addr, method, args, reply)
} else { } else {
return p.rpc(dc, nodeName, addr, version, method, useTLS, args, reply) return p.rpc(dc, nodeName, addr, version, method, args, reply)
} }
} }
@ -637,10 +636,11 @@ func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method
return nil return nil
} }
func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error { func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
p.once.Do(p.init) p.once.Do(p.init)
// Get a usable client // Get a usable client
useTLS := p.TLSConfigurator.UseTLS(dc)
conn, sc, err := p.getClient(dc, nodeName, addr, version, useTLS) conn, sc, err := p.getClient(dc, nodeName, addr, version, useTLS)
if err != nil { if err != nil {
return fmt.Errorf("rpc error getting client: %v", err) return fmt.Errorf("rpc error getting client: %v", err)
@ -671,9 +671,9 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, m
// Ping sends a Status.Ping message to the specified server and // Ping sends a Status.Ping message to the specified server and
// returns true if healthy, false if an error occurred // returns true if healthy, false if an error occurred
func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) { func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) {
var out struct{} var out struct{}
err := p.RPC(dc, nodeName, addr, version, "Status.Ping", useTLS, struct{}{}, &out) err := p.RPC(dc, nodeName, addr, version, "Status.Ping", struct{}{}, &out)
return err == nil, err return err == nil, err
} }

View File

@ -61,7 +61,7 @@ type ManagerSerfCluster interface {
// Pinger is an interface wrapping client.ConnPool to prevent a cyclic import // Pinger is an interface wrapping client.ConnPool to prevent a cyclic import
// dependency. // dependency.
type Pinger interface { type Pinger interface {
Ping(dc, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) Ping(dc, nodeName string, addr net.Addr, version int) (bool, error)
} }
// serverList is a local copy of the struct used to maintain the list of // serverList is a local copy of the struct used to maintain the list of
@ -98,6 +98,10 @@ type Manager struct {
// client.ConnPool. // client.ConnPool.
connPoolPinger Pinger connPoolPinger Pinger
// serverName has the name of the managers's server. This is used to
// short-circuit pinging to itself.
serverName string
// notifyFailedBarrier is acts as a barrier to prevent queuing behind // notifyFailedBarrier is acts as a barrier to prevent queuing behind
// serverListLog and acts as a TryLock(). // serverListLog and acts as a TryLock().
notifyFailedBarrier int32 notifyFailedBarrier int32
@ -256,7 +260,7 @@ func (m *Manager) saveServerList(l serverList) {
} }
// New is the only way to safely create a new Manager struct. // New is the only way to safely create a new Manager struct.
func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger) (m *Manager) { func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) {
if logger == nil { if logger == nil {
logger = hclog.New(&hclog.LoggerOptions{}) logger = hclog.New(&hclog.LoggerOptions{})
} }
@ -267,6 +271,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC
m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle
m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration)
m.shutdownCh = shutdownCh m.shutdownCh = shutdownCh
m.serverName = serverName
atomic.StoreInt32(&m.offline, 1) atomic.StoreInt32(&m.offline, 1)
l := serverList{} l := serverList{}
@ -340,7 +345,12 @@ func (m *Manager) RebalanceServers() {
// while Serf detects the node has failed. // while Serf detects the node has failed.
srv := l.servers[0] srv := l.servers[0]
ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version, srv.UseTLS) // check to see if the manager is trying to ping itself,
// continue if that is the case.
if m.serverName != "" && srv.Name == m.serverName {
continue
}
ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version)
if ok { if ok {
foundHealthyServer = true foundHealthyServer = true
break break

View File

@ -33,7 +33,7 @@ type fauxConnPool struct {
failPct float64 failPct float64
} }
func (cp *fauxConnPool) Ping(string, string, net.Addr, int, bool) (bool, error) { func (cp *fauxConnPool) Ping(string, string, net.Addr, int) (bool, error) {
var success bool var success bool
successProb := rand.Float64() successProb := rand.Float64()
if successProb > cp.failPct { if successProb > cp.failPct {
@ -53,14 +53,14 @@ func (s *fauxSerf) NumNodes() int {
func testManager() (m *Manager) { func testManager() (m *Manager) {
logger := GetBufferedLogger() logger := GetBufferedLogger()
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}) m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "")
return m return m
} }
func testManagerFailProb(failPct float64) (m *Manager) { func testManagerFailProb(failPct float64) (m *Manager) {
logger := GetBufferedLogger() logger := GetBufferedLogger()
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "")
return m return m
} }
@ -179,7 +179,7 @@ func test_reconcileServerList(maxServers int) (bool, error) {
// failPct of the servers for the reconcile. This // failPct of the servers for the reconcile. This
// allows for the selected server to no longer be // allows for the selected server to no longer be
// healthy for the reconcile below. // healthy for the reconcile below.
if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version, node.UseTLS); ok { if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version); ok {
// Will still be present // Will still be present
healthyServers = append(healthyServers, node) healthyServers = append(healthyServers, node)
} else { } else {
@ -299,7 +299,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) {
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
for _, s := range clusters { for _, s := range clusters {
m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}) m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "")
for i := 0; i < s.numServers; i++ { for i := 0; i < s.numServers; i++ {
nodeName := fmt.Sprintf("s%02d", i) nodeName := fmt.Sprintf("s%02d", i)
m.AddServer(&metadata.Server{Name: nodeName}) m.AddServer(&metadata.Server{Name: nodeName})

View File

@ -32,7 +32,7 @@ type fauxConnPool struct {
failAddr net.Addr failAddr net.Addr
} }
func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) { func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) {
var success bool var success bool
successProb := rand.Float64() successProb := rand.Float64()
@ -57,21 +57,21 @@ func (s *fauxSerf) NumNodes() int {
func testManager(t testing.TB) (m *router.Manager) { func testManager(t testing.TB) (m *router.Manager) {
logger := testutil.Logger(t) logger := testutil.Logger(t)
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}) m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "")
return m return m
} }
func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) {
logger := testutil.Logger(t) logger := testutil.Logger(t)
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "")
return m return m
} }
func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) {
logger := testutil.Logger(t) logger := testutil.Logger(t)
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}) m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "")
return m return m
} }
@ -195,7 +195,7 @@ func TestServers_FindServer(t *testing.T) {
func TestServers_New(t *testing.T) { func TestServers_New(t *testing.T) {
logger := testutil.Logger(t) logger := testutil.Logger(t)
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}) m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "")
if m == nil { if m == nil {
t.Fatalf("Manager nil") t.Fatalf("Manager nil")
} }

View File

@ -26,6 +26,10 @@ type Router struct {
// used to short-circuit RTT calculations for local servers. // used to short-circuit RTT calculations for local servers.
localDatacenter string localDatacenter string
// serverName has the name of the router's server. This is used to
// short-circuit pinging to itself.
serverName string
// areas maps area IDs to structures holding information about that // areas maps area IDs to structures holding information about that
// area. // area.
areas map[types.AreaID]*areaInfo areas map[types.AreaID]*areaInfo
@ -83,7 +87,7 @@ type areaInfo struct {
} }
// NewRouter returns a new Router with the given configuration. // NewRouter returns a new Router with the given configuration.
func NewRouter(logger hclog.Logger, localDatacenter string) *Router { func NewRouter(logger hclog.Logger, localDatacenter, serverName string) *Router {
if logger == nil { if logger == nil {
logger = hclog.New(&hclog.LoggerOptions{}) logger = hclog.New(&hclog.LoggerOptions{})
} }
@ -91,6 +95,7 @@ func NewRouter(logger hclog.Logger, localDatacenter string) *Router {
router := &Router{ router := &Router{
logger: logger.Named(logging.Router), logger: logger.Named(logging.Router),
localDatacenter: localDatacenter, localDatacenter: localDatacenter,
serverName: serverName,
areas: make(map[types.AreaID]*areaInfo), areas: make(map[types.AreaID]*areaInfo),
managers: make(map[string][]*Manager), managers: make(map[string][]*Manager),
} }
@ -120,7 +125,7 @@ func (r *Router) Shutdown() {
} }
// AddArea registers a new network area with the router. // AddArea registers a new network area with the router.
func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger Pinger, useTLS bool) error { func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger Pinger) error {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
@ -136,7 +141,6 @@ func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger
cluster: cluster, cluster: cluster,
pinger: pinger, pinger: pinger,
managers: make(map[string]*managerInfo), managers: make(map[string]*managerInfo),
useTLS: useTLS,
} }
r.areas[areaID] = area r.areas[areaID] = area
@ -162,6 +166,23 @@ func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger
return nil return nil
} }
// GetServerMetadataByAddr returns server metadata by dc and address. If it
// didn't find anything, nil is returned.
func (r *Router) GetServerMetadataByAddr(dc, addr string) *metadata.Server {
r.RLock()
defer r.RUnlock()
if ms, ok := r.managers[dc]; ok {
for _, m := range ms {
for _, s := range m.getServerList().servers {
if s.Addr.String() == addr {
return s
}
}
}
}
return nil
}
// removeManagerFromIndex does cleanup to take a manager out of the index of // removeManagerFromIndex does cleanup to take a manager out of the index of
// datacenters. This assumes the lock is already held for writing, and will // datacenters. This assumes the lock is already held for writing, and will
// panic if the given manager isn't found. // panic if the given manager isn't found.
@ -219,7 +240,7 @@ func (r *Router) addServer(area *areaInfo, s *metadata.Server) error {
info, ok := area.managers[s.Datacenter] info, ok := area.managers[s.Datacenter]
if !ok { if !ok {
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
manager := New(r.logger, shutdownCh, area.cluster, area.pinger) manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName)
info = &managerInfo{ info = &managerInfo{
manager: manager, manager: manager,
shutdownCh: shutdownCh, shutdownCh: shutdownCh,

View File

@ -95,7 +95,7 @@ func testCluster(self string) *mockCluster {
func testRouter(t testing.TB, dc string) *Router { func testRouter(t testing.TB, dc string) *Router {
logger := testutil.Logger(t) logger := testutil.Logger(t)
return NewRouter(logger, dc) return NewRouter(logger, dc, "")
} }
func TestRouter_Shutdown(t *testing.T) { func TestRouter_Shutdown(t *testing.T) {
@ -104,7 +104,7 @@ func TestRouter_Shutdown(t *testing.T) {
// Create a WAN-looking area. // Create a WAN-looking area.
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -112,7 +112,7 @@ func TestRouter_Shutdown(t *testing.T) {
otherID := types.AreaID("other") otherID := types.AreaID("other")
other := newMockCluster(self) other := newMockCluster(self)
other.AddMember("dcY", "node1", nil) other.AddMember("dcY", "node1", nil)
if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
_, _, ok := r.FindRoute("dcY") _, _, ok := r.FindRoute("dcY")
@ -128,7 +128,7 @@ func TestRouter_Shutdown(t *testing.T) {
} }
// You can't add areas once the router is shut down. // You can't add areas once the router is shut down.
err := r.AddArea(otherID, other, &fauxConnPool{}, false) err := r.AddArea(otherID, other, &fauxConnPool{})
if err == nil || !strings.Contains(err.Error(), "router is shut down") { if err == nil || !strings.Contains(err.Error(), "router is shut down") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -140,7 +140,7 @@ func TestRouter_Routing(t *testing.T) {
// Create a WAN-looking area. // Create a WAN-looking area.
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -169,7 +169,7 @@ func TestRouter_Routing(t *testing.T) {
other.AddMember("dc0", "node0", nil) other.AddMember("dc0", "node0", nil)
other.AddMember("dc1", "node1", nil) other.AddMember("dc1", "node1", nil)
other.AddMember("dcY", "node1", nil) other.AddMember("dcY", "node1", nil)
if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -274,7 +274,7 @@ func TestRouter_Routing_Offline(t *testing.T) {
// Create a WAN-looking area. // Create a WAN-looking area.
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{1.0}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{1.0}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -328,7 +328,7 @@ func TestRouter_Routing_Offline(t *testing.T) {
other := newMockCluster(self) other := newMockCluster(self)
other.AddMember("dc0", "node0", nil) other.AddMember("dc0", "node0", nil)
other.AddMember("dc1", "node1", nil) other.AddMember("dc1", "node1", nil)
if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -353,7 +353,7 @@ func TestRouter_GetDatacenters(t *testing.T) {
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -385,7 +385,7 @@ func TestRouter_GetDatacentersByDistance(t *testing.T) {
// Start with just the WAN area described in the diagram above. // Start with just the WAN area described in the diagram above.
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -403,7 +403,7 @@ func TestRouter_GetDatacentersByDistance(t *testing.T) {
other := newMockCluster(self) other := newMockCluster(self)
other.AddMember("dc0", "node0", lib.GenerateCoordinate(20*time.Millisecond)) other.AddMember("dc0", "node0", lib.GenerateCoordinate(20*time.Millisecond))
other.AddMember("dc1", "node1", lib.GenerateCoordinate(21*time.Millisecond)) other.AddMember("dc1", "node1", lib.GenerateCoordinate(21*time.Millisecond))
if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -422,7 +422,7 @@ func TestRouter_GetDatacenterMaps(t *testing.T) {
self := "node0.dc0" self := "node0.dc0"
wan := testCluster(self) wan := testCluster(self)
if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -33,6 +33,7 @@ const (
Memberlist string = "memberlist" Memberlist string = "memberlist"
MeshGateway string = "mesh_gateway" MeshGateway string = "mesh_gateway"
Namespace string = "namespace" Namespace string = "namespace"
NetworkAreas string = "network_areas"
Operator string = "operator" Operator string = "operator"
PreparedQuery string = "prepared_query" PreparedQuery string = "prepared_query"
Proxy string = "proxy" Proxy string = "proxy"

View File

@ -179,9 +179,10 @@ type manual struct {
// *tls.Config necessary for Consul. Except the one in the api package. // *tls.Config necessary for Consul. Except the one in the api package.
type Configurator struct { type Configurator struct {
sync.RWMutex sync.RWMutex
base *Config base *Config
autoEncrypt *autoEncrypt autoEncrypt *autoEncrypt
manual *manual manual *manual
peerDatacenterUseTLS map[string]bool
caPool *x509.CertPool caPool *x509.CertPool
logger hclog.Logger logger hclog.Logger
@ -198,9 +199,10 @@ func NewConfigurator(config Config, logger hclog.Logger) (*Configurator, error)
} }
c := &Configurator{ c := &Configurator{
logger: logger.Named(logging.TLSUtil), logger: logger.Named(logging.TLSUtil),
manual: &manual{}, manual: &manual{},
autoEncrypt: &autoEncrypt{}, autoEncrypt: &autoEncrypt{},
peerDatacenterUseTLS: map[string]bool{},
} }
err := c.Update(config) err := c.Update(config)
if err != nil { if err != nil {
@ -323,6 +325,22 @@ func (c *Configurator) UpdateAutoEncrypt(manualCAPems, connectCAPems []string, p
return nil return nil
} }
func (c *Configurator) UpdateAreaPeerDatacenterUseTLS(peerDatacenter string, useTLS bool) {
c.Lock()
defer c.Unlock()
c.version++
c.peerDatacenterUseTLS[peerDatacenter] = useTLS
}
func (c *Configurator) getAreaForPeerDatacenterUseTLS(peerDatacenter string) bool {
c.RLock()
defer c.RUnlock()
if v, ok := c.peerDatacenterUseTLS[peerDatacenter]; ok {
return v
}
return true
}
func (c *Configurator) Base() Config { func (c *Configurator) Base() Config {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
@ -535,7 +553,7 @@ func (c *Configurator) outgoingRPCTLSDisabled() bool {
} }
// if CAs are provided or VerifyOutgoing is set, use TLS // if CAs are provided or VerifyOutgoing is set, use TLS
if c.caPool != nil || c.base.VerifyOutgoing { if c.base.VerifyOutgoing {
return false return false
} }
@ -742,16 +760,20 @@ func (c *Configurator) OutgoingALPNRPCConfig() *tls.Config {
// decides if verify server hostname should be used. // decides if verify server hostname should be used.
func (c *Configurator) OutgoingRPCWrapper() DCWrapper { func (c *Configurator) OutgoingRPCWrapper() DCWrapper {
c.log("OutgoingRPCWrapper") c.log("OutgoingRPCWrapper")
if c.outgoingRPCTLSDisabled() {
return nil
}
// Generate the wrapper based on dc // Generate the wrapper based on dc
return func(dc string, conn net.Conn) (net.Conn, error) { return func(dc string, conn net.Conn) (net.Conn, error) {
return c.wrapTLSClient(dc, conn) if c.UseTLS(dc) {
return c.wrapTLSClient(dc, conn)
}
return conn, nil
} }
} }
func (c *Configurator) UseTLS(dc string) bool {
return !c.outgoingRPCTLSDisabled() && c.getAreaForPeerDatacenterUseTLS(dc)
}
// OutgoingALPNRPCWrapper wraps the result of OutgoingALPNRPCConfig in an // OutgoingALPNRPCWrapper wraps the result of OutgoingALPNRPCConfig in an
// ALPNWrapper. It configures all of the negotiation plumbing. // ALPNWrapper. It configures all of the negotiation plumbing.
func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper { func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper {

View File

@ -99,10 +99,11 @@ func TestConfigurator_outgoingWrapper_OK(t *testing.T) {
func TestConfigurator_outgoingWrapper_noverify_OK(t *testing.T) { func TestConfigurator_outgoingWrapper_noverify_OK(t *testing.T) {
config := Config{ config := Config{
CAFile: "../test/hostname/CertAuth.crt", VerifyOutgoing: true,
CertFile: "../test/hostname/Alice.crt", CAFile: "../test/hostname/CertAuth.crt",
KeyFile: "../test/hostname/Alice.key", CertFile: "../test/hostname/Alice.crt",
Domain: "consul", KeyFile: "../test/hostname/Alice.key",
Domain: "consul",
} }
client, errc := startRPCTLSServer(&config) client, errc := startRPCTLSServer(&config)
@ -744,7 +745,7 @@ func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) {
{false, true, nil, false}, {false, true, nil, false},
{true, true, nil, false}, {true, true, nil, false},
{false, false, &x509.CertPool{}, false}, // {false, false, &x509.CertPool{}, false},
{true, false, &x509.CertPool{}, false}, {true, false, &x509.CertPool{}, false},
{false, true, &x509.CertPool{}, false}, {false, true, &x509.CertPool{}, false},
{true, true, &x509.CertPool{}, false}, {true, true, &x509.CertPool{}, false},
@ -959,32 +960,42 @@ func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) {
func TestConfigurator_OutgoingRPCWrapper(t *testing.T) { func TestConfigurator_OutgoingRPCWrapper(t *testing.T) {
c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}} c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}}
require.Nil(t, c.OutgoingRPCWrapper()) wrapper := c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
conn := &net.TCPConn{}
cWrap, err := wrapper("", conn)
require.Equal(t, conn, cWrap)
c, err := NewConfigurator(Config{ c, err = NewConfigurator(Config{
VerifyOutgoing: true, VerifyOutgoing: true,
CAFile: "../test/ca/root.cer", CAFile: "../test/ca/root.cer",
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
wrap := c.OutgoingRPCWrapper() wrapper = c.OutgoingRPCWrapper()
require.NotNil(t, wrap) require.NotNil(t, wrapper)
t.Log("TODO: actually call wrap here eventually") cWrap, err = wrapper("", conn)
require.NotEqual(t, conn, cWrap)
} }
func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) { func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) {
c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}} c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}}
require.Nil(t, c.OutgoingRPCWrapper()) wrapper := c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
conn := &net.TCPConn{}
cWrap, err := wrapper("", conn)
require.Equal(t, conn, cWrap)
c, err := NewConfigurator(Config{ c, err = NewConfigurator(Config{
VerifyOutgoing: false, // ignored, assumed true VerifyOutgoing: true,
CAFile: "../test/ca/root.cer", CAFile: "../test/ca/root.cer",
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
wrap := c.OutgoingRPCWrapper() wrapper = c.OutgoingRPCWrapper()
require.NotNil(t, wrap) require.NotNil(t, wrapper)
t.Log("TODO: actually call wrap here eventually") cWrap, err = wrapper("", conn)
require.NotEqual(t, conn, cWrap)
} }
func TestConfigurator_UpdateChecks(t *testing.T) { func TestConfigurator_UpdateChecks(t *testing.T) {