diff --git a/consul/raft_rpc.go b/consul/raft_rpc.go index 4ed7437e9e..1221ce06f2 100644 --- a/consul/raft_rpc.go +++ b/consul/raft_rpc.go @@ -1,6 +1,7 @@ package consul import ( + "crypto/tls" "fmt" "net" "sync" @@ -16,6 +17,9 @@ type RaftLayer struct { // connCh is used to accept connections connCh chan net.Conn + // TLS configuration + tlsConfig *tls.Config + // Tracks if we are closed closed bool closeCh chan struct{} @@ -23,12 +27,14 @@ type RaftLayer struct { } // NewRaftLayer is used to initialize a new RaftLayer which can -// be used as a StreamLayer for Raft -func NewRaftLayer(addr net.Addr) *RaftLayer { +// be used as a StreamLayer for Raft. If a tlsConfig is provided, +// then the connection will use TLS. +func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer { layer := &RaftLayer{ - addr: addr, - connCh: make(chan net.Conn), - closeCh: make(chan struct{}), + addr: addr, + connCh: make(chan net.Conn), + tlsConfig: tlsConfig, + closeCh: make(chan struct{}), } return layer } @@ -79,6 +85,18 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error return nil, err } + // Check for tls mode + if l.tlsConfig != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + conn = tls.Client(conn, l.tlsConfig) + } + // Write the Raft byte to set the mode _, err = conn.Write([]byte{byte(rpcRaft)}) if err != nil { diff --git a/consul/server.go b/consul/server.go index 3b97d0980a..640140ed1f 100644 --- a/consul/server.go +++ b/consul/server.go @@ -160,7 +160,7 @@ func NewServer(config *Config) (*Server, error) { } // Initialize the RPC layer - if err := s.setupRPC(); err != nil { + if err := s.setupRPC(tlsConfig); err != nil { s.Shutdown() return nil, fmt.Errorf("Failed to start RPC layer: %v", err) } @@ -290,7 +290,7 @@ func (s *Server) setupRaft() error { } // setupRPC is used to setup the RPC listener -func (s *Server) setupRPC() error { +func (s *Server) setupRPC(tlsConfig *tls.Config) error { // Create endpoints s.endpoints.Status = &Status{s} s.endpoints.Raft = &Raft{s} @@ -329,7 +329,7 @@ func (s *Server) setupRPC() error { return fmt.Errorf("RPC advertise address is not advertisable: %v", addr) } - s.raftLayer = NewRaftLayer(advertise) + s.raftLayer = NewRaftLayer(advertise, tlsConfig) go s.listen() return nil }