diff --git a/command/agent/agent.go b/command/agent/agent.go index 1621149bc8..e701c37ee0 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -3,6 +3,7 @@ package agent import ( "fmt" "github.com/hashicorp/consul/consul" + "github.com/hashicorp/serf/serf" "io" "log" "os" @@ -187,3 +188,59 @@ func (a *Agent) Shutdown() error { func (a *Agent) ShutdownCh() <-chan struct{} { return a.shutdownCh } + +// JoinLAN is used to have the agent join a LAN cluster +func (a *Agent) JoinLAN(addrs []string) (n int, err error) { + a.logger.Printf("[INFO] agent: (LAN) joining: %v", addrs) + if a.server != nil { + n, err = a.server.JoinLAN(addrs) + } else { + n, err = a.client.JoinLAN(addrs) + } + a.logger.Printf("[INFO] agent: (LAN) joined: %d Err: %v", n, err) + return +} + +// JoinWAN is used to have the agent join a WAN cluster +func (a *Agent) JoinWAN(addrs []string) (n int, err error) { + a.logger.Printf("[INFO] agent: (WAN) joining: %v", addrs) + if a.server != nil { + n, err = a.server.JoinWAN(addrs) + } else { + err = fmt.Errorf("Must be a server to join WAN cluster") + } + a.logger.Printf("[INFO] agent: (WAN) joined: %d Err: %v", n, err) + return +} + +// ForceLeave is used to remove a failed node from the cluster +func (a *Agent) ForceLeave(node string) (err error) { + a.logger.Printf("[INFO] Force leaving node: %v", node) + if a.server != nil { + err = a.server.RemoveFailedNode(node) + } else { + err = a.client.RemoveFailedNode(node) + } + if err != nil { + a.logger.Printf("[WARN] Failed to remove node: %v", err) + } + return err +} + +// Used to retrieve the LAN members +func (a *Agent) LANMembers() []serf.Member { + if a.server != nil { + return a.server.LANMembers() + } else { + return a.client.LANMembers() + } +} + +// Used to retrieve the WAN members +func (a *Agent) WANMembers() []serf.Member { + if a.server != nil { + return a.server.WANMembers() + } else { + return nil + } +} diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 849c61201c..151b43c42e 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -3,6 +3,7 @@ package agent import ( "fmt" "github.com/hashicorp/consul/consul" + "io" "io/ioutil" "os" "sync/atomic" @@ -18,22 +19,24 @@ func nextConfig() *Config { conf.Bootstrap = true conf.Datacenter = "dc1" + conf.NodeName = fmt.Sprintf("Node %d", idx) conf.HTTPAddr = fmt.Sprintf("127.0.0.1:%d", 8500+10*idx) conf.RPCAddr = fmt.Sprintf("127.0.0.1:%d", 8400+10*idx) conf.SerfBindAddr = "127.0.0.1" conf.SerfLanPort = int(8301 + 10*idx) conf.SerfWanPort = int(8302 + 10*idx) conf.Server = true + conf.ServerAddr = fmt.Sprintf("127.0.0.1:%d", 8100+10*idx) cons := consul.DefaultConfig() conf.ConsulConfig = cons - cons.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond - cons.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second + cons.SerfLANConfig.MemberlistConfig.ProbeTimeout = 100 * time.Millisecond + cons.SerfLANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond cons.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond - cons.SerfWANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond - cons.SerfWANConfig.MemberlistConfig.ProbeInterval = time.Second + cons.SerfWANConfig.MemberlistConfig.ProbeTimeout = 100 * time.Millisecond + cons.SerfWANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond cons.SerfWANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond cons.RaftConfig.HeartbeatTimeout = 40 * time.Millisecond @@ -42,14 +45,14 @@ func nextConfig() *Config { return conf } -func makeAgent(t *testing.T, conf *Config) (string, *Agent) { +func makeAgentLog(t *testing.T, conf *Config, l io.Writer) (string, *Agent) { dir, err := ioutil.TempDir("", "agent") if err != nil { t.Fatalf(fmt.Sprintf("err: %v", err)) } conf.DataDir = dir - agent, err := Create(conf, nil) + agent, err := Create(conf, l) if err != nil { os.RemoveAll(dir) t.Fatalf(fmt.Sprintf("err: %v", err)) @@ -58,6 +61,10 @@ func makeAgent(t *testing.T, conf *Config) (string, *Agent) { return dir, agent } +func makeAgent(t *testing.T, conf *Config) (string, *Agent) { + return makeAgentLog(t, conf, nil) +} + func TestAgentStartStop(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) diff --git a/command/agent/rpc.go b/command/agent/rpc.go new file mode 100644 index 0000000000..5057165597 --- /dev/null +++ b/command/agent/rpc.go @@ -0,0 +1,544 @@ +package agent + +/* + The agent exposes an RPC mechanism that is used for both controlling + Consul as well as providing a fast streaming mechanism for events. This + allows other applications to easily leverage Consul without embedding. + + We additionally make use of the RPC layer to also handle calls from + the CLI to unify the code paths. This results in a split Request/Response + as well as streaming mode of operation. + + The system is fairly simple, each client opens a TCP connection to the + agent. The connection is initialized with a handshake which establishes + the protocol version being used. This is to allow for future changes to + the protocol. + + Once initialized, clients send commands and wait for responses. Certain + commands will cause the client to subscribe to events, and those will be + pushed down the socket as they are received. This provides a low-latency + mechanism for applications to send and receive events, while also providing + a flexible control mechanism for Consul. +*/ + +import ( + "bufio" + "fmt" + "github.com/hashicorp/logutils" + "github.com/hashicorp/serf/serf" + "github.com/ugorji/go/codec" + "io" + "log" + "net" + "os" + "strings" + "sync" +) + +const ( + MinRPCVersion = 1 + MaxRPCVersion = 1 +) + +const ( + handshakeCommand = "handshake" + forceLeaveCommand = "force-leave" + joinCommand = "join" + membersLANCommand = "members-lan" + membersWANCommand = "members-wan" + stopCommand = "stop" + monitorCommand = "monitor" + leaveCommand = "leave" +) + +const ( + unsupportedCommand = "Unsupported command" + unsupportedRPCVersion = "Unsupported RPC version" + duplicateHandshake = "Handshake already performed" + handshakeRequired = "Handshake required" + monitorExists = "Monitor already exists" +) + +// Request header is sent before each request +type requestHeader struct { + Command string + Seq uint64 +} + +// Response header is sent before each response +type responseHeader struct { + Seq uint64 + Error string +} + +type handshakeRequest struct { + Version int32 +} + +type eventRequest struct { + Name string + Payload []byte + Coalesce bool +} + +type forceLeaveRequest struct { + Node string +} + +type joinRequest struct { + Existing []string + WAN bool +} + +type joinResponse struct { + Num int32 +} + +type membersResponse struct { + Members []Member +} + +type monitorRequest struct { + LogLevel string +} + +type streamRequest struct { + Type string +} + +type stopRequest struct { + Stop uint64 +} + +type logRecord struct { + Log string +} + +type userEventRecord struct { + Event string + LTime serf.LamportTime + Name string + Payload []byte + Coalesce bool +} + +type Member struct { + Name string + Addr net.IP + Port uint16 + Role string + Status string + ProtocolMin uint8 + ProtocolMax uint8 + ProtocolCur uint8 + DelegateMin uint8 + DelegateMax uint8 + DelegateCur uint8 +} + +type memberEventRecord struct { + Event string + Members []Member +} + +type AgentRPC struct { + sync.Mutex + agent *Agent + clients map[string]*rpcClient + listener net.Listener + logger *log.Logger + logWriter *logWriter + stop bool + stopCh chan struct{} +} + +type rpcClient struct { + name string + conn net.Conn + reader *bufio.Reader + writer *bufio.Writer + dec *codec.Decoder + enc *codec.Encoder + writeLock sync.Mutex + version int32 // From the handshake, 0 before + logStreamer *logStream +} + +// send is used to send an object using the MsgPack encoding. send +// is serialized to prevent write overlaps, while properly buffering. +func (c *rpcClient) Send(header *responseHeader, obj interface{}) error { + c.writeLock.Lock() + defer c.writeLock.Unlock() + + if err := c.enc.Encode(header); err != nil { + return err + } + + if obj != nil { + if err := c.enc.Encode(obj); err != nil { + return err + } + } + + if err := c.writer.Flush(); err != nil { + return err + } + + return nil +} + +func (c *rpcClient) String() string { + return fmt.Sprintf("rpc.client: %v", c.conn) +} + +// NewAgentRPC is used to create a new Agent RPC handler +func NewAgentRPC(agent *Agent, listener net.Listener, + logOutput io.Writer, logWriter *logWriter) *AgentRPC { + if logOutput == nil { + logOutput = os.Stderr + } + rpc := &AgentRPC{ + agent: agent, + clients: make(map[string]*rpcClient), + listener: listener, + logger: log.New(logOutput, "", log.LstdFlags), + logWriter: logWriter, + stopCh: make(chan struct{}), + } + go rpc.listen() + return rpc +} + +// Shutdown is used to shutdown the RPC layer +func (i *AgentRPC) Shutdown() { + i.Lock() + defer i.Unlock() + + if i.stop { + return + } + + i.stop = true + close(i.stopCh) + i.listener.Close() + + // Close the existing connections + for _, client := range i.clients { + client.conn.Close() + } +} + +// listen is a long running routine that listens for new clients +func (i *AgentRPC) listen() { + for { + conn, err := i.listener.Accept() + if err != nil { + if i.stop { + return + } + i.logger.Printf("[ERR] agent.rpc: Failed to accept client: %v", err) + continue + } + i.logger.Printf("[INFO] agent.rpc: Accepted client: %v", conn.RemoteAddr()) + + // Wrap the connection in a client + client := &rpcClient{ + name: conn.RemoteAddr().String(), + conn: conn, + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + } + client.dec = codec.NewDecoder(client.reader, + &codec.MsgpackHandle{RawToString: true, WriteExt: true}) + client.enc = codec.NewEncoder(client.writer, + &codec.MsgpackHandle{RawToString: true, WriteExt: true}) + if err != nil { + i.logger.Printf("[ERR] agent.rpc: Failed to create decoder: %v", err) + conn.Close() + continue + } + + // Register the client + i.Lock() + if !i.stop { + i.clients[client.name] = client + go i.handleClient(client) + } else { + conn.Close() + } + i.Unlock() + } +} + +// deregisterClient is called to cleanup after a client disconnects +func (i *AgentRPC) deregisterClient(client *rpcClient) { + // Close the socket + client.conn.Close() + + // Remove from the clients list + i.Lock() + delete(i.clients, client.name) + i.Unlock() + + // Remove from the log writer + if client.logStreamer != nil { + i.logWriter.DeregisterHandler(client.logStreamer) + client.logStreamer.Stop() + } +} + +// handleClient is a long running routine that handles a single client +func (i *AgentRPC) handleClient(client *rpcClient) { + defer i.deregisterClient(client) + var reqHeader requestHeader + for { + // Decode the header + if err := client.dec.Decode(&reqHeader); err != nil { + if err != io.EOF && !i.stop { + i.logger.Printf("[ERR] agent.rpc: failed to decode request header: %v", err) + } + return + } + + // Evaluate the command + if err := i.handleRequest(client, &reqHeader); err != nil { + i.logger.Printf("[ERR] agent.rpc: Failed to evaluate request: %v", err) + return + } + } +} + +// handleRequest is used to evaluate a single client command +func (i *AgentRPC) handleRequest(client *rpcClient, reqHeader *requestHeader) error { + // Look for a command field + command := reqHeader.Command + seq := reqHeader.Seq + + // Ensure the handshake is performed before other commands + if command != handshakeCommand && client.version == 0 { + respHeader := responseHeader{Seq: seq, Error: handshakeRequired} + client.Send(&respHeader, nil) + return fmt.Errorf(handshakeRequired) + } + + // Dispatch command specific handlers + switch command { + case handshakeCommand: + return i.handleHandshake(client, seq) + + case membersLANCommand: + return i.handleMembersLAN(client, seq) + + case membersWANCommand: + return i.handleMembersWAN(client, seq) + + case monitorCommand: + return i.handleMonitor(client, seq) + + case stopCommand: + return i.handleStop(client, seq) + + case forceLeaveCommand: + return i.handleForceLeave(client, seq) + + case joinCommand: + return i.handleJoin(client, seq) + + case leaveCommand: + return i.handleLeave(client, seq) + + default: + respHeader := responseHeader{Seq: seq, Error: unsupportedCommand} + client.Send(&respHeader, nil) + return fmt.Errorf("command '%s' not recognized", command) + } +} + +func (i *AgentRPC) handleHandshake(client *rpcClient, seq uint64) error { + var req handshakeRequest + if err := client.dec.Decode(&req); err != nil { + return fmt.Errorf("decode failed: %v", err) + } + + resp := responseHeader{ + Seq: seq, + Error: "", + } + + // Check the version + if req.Version < MinRPCVersion || req.Version > MaxRPCVersion { + resp.Error = unsupportedRPCVersion + } else if client.version != 0 { + resp.Error = duplicateHandshake + } else { + client.version = req.Version + } + return client.Send(&resp, nil) +} + +func (i *AgentRPC) handleForceLeave(client *rpcClient, seq uint64) error { + var req forceLeaveRequest + if err := client.dec.Decode(&req); err != nil { + return fmt.Errorf("decode failed: %v", err) + } + + // Attempt leave + err := i.agent.ForceLeave(req.Node) + + // Respond + resp := responseHeader{ + Seq: seq, + Error: errToString(err), + } + return client.Send(&resp, nil) +} + +func (i *AgentRPC) handleJoin(client *rpcClient, seq uint64) error { + var req joinRequest + if err := client.dec.Decode(&req); err != nil { + return fmt.Errorf("decode failed: %v", err) + } + + // Attempt the join + var num int + var err error + if req.WAN { + num, err = i.agent.JoinWAN(req.Existing) + } else { + num, err = i.agent.JoinLAN(req.Existing) + } + + // Respond + header := responseHeader{ + Seq: seq, + Error: errToString(err), + } + resp := joinResponse{ + Num: int32(num), + } + return client.Send(&header, &resp) +} + +func (i *AgentRPC) handleMembersLAN(client *rpcClient, seq uint64) error { + raw := i.agent.LANMembers() + return formatMembers(raw, client, seq) +} + +func (i *AgentRPC) handleMembersWAN(client *rpcClient, seq uint64) error { + raw := i.agent.WANMembers() + return formatMembers(raw, client, seq) +} + +func formatMembers(raw []serf.Member, client *rpcClient, seq uint64) error { + members := make([]Member, 0, len(raw)) + for _, m := range raw { + sm := Member{ + Name: m.Name, + Addr: m.Addr, + Port: m.Port, + Role: m.Role, + Status: m.Status.String(), + ProtocolMin: m.ProtocolMin, + ProtocolMax: m.ProtocolMax, + ProtocolCur: m.ProtocolCur, + DelegateMin: m.DelegateMin, + DelegateMax: m.DelegateMax, + DelegateCur: m.DelegateCur, + } + members = append(members, sm) + } + + header := responseHeader{ + Seq: seq, + Error: "", + } + resp := membersResponse{ + Members: members, + } + return client.Send(&header, &resp) +} + +func (i *AgentRPC) handleMonitor(client *rpcClient, seq uint64) error { + var req monitorRequest + if err := client.dec.Decode(&req); err != nil { + return fmt.Errorf("decode failed: %v", err) + } + + resp := responseHeader{ + Seq: seq, + Error: "", + } + + // Upper case the log level + req.LogLevel = strings.ToUpper(req.LogLevel) + + // Create a level filter + filter := LevelFilter() + filter.MinLevel = logutils.LogLevel(req.LogLevel) + if !ValidateLevelFilter(filter.MinLevel, filter) { + resp.Error = fmt.Sprintf("Unknown log level: %s", filter.MinLevel) + goto SEND + } + + // Check if there is an existing monitor + if client.logStreamer != nil { + resp.Error = monitorExists + goto SEND + } + + // Create a log streamer + client.logStreamer = newLogStream(client, filter, seq, i.logger) + + // Register with the log writer. Defer so that we can respond before + // registration, avoids any possible race condition + defer i.logWriter.RegisterHandler(client.logStreamer) + +SEND: + return client.Send(&resp, nil) +} + +func (i *AgentRPC) handleStop(client *rpcClient, seq uint64) error { + var req stopRequest + if err := client.dec.Decode(&req); err != nil { + return fmt.Errorf("decode failed: %v", err) + } + + // Remove a log monitor if any + if client.logStreamer != nil && client.logStreamer.seq == req.Stop { + i.logWriter.DeregisterHandler(client.logStreamer) + client.logStreamer.Stop() + client.logStreamer = nil + } + + // Always succeed + resp := responseHeader{Seq: seq, Error: ""} + return client.Send(&resp, nil) +} + +func (i *AgentRPC) handleLeave(client *rpcClient, seq uint64) error { + i.logger.Printf("[INFO] agent.rpc: Graceful leave triggered") + + // Do the leave + err := i.agent.Leave() + if err != nil { + i.logger.Printf("[ERR] agent.rpc: leave failed: %v", err) + } + resp := responseHeader{Seq: seq, Error: errToString(err)} + + // Send and wait + err = client.Send(&resp, nil) + + // Trigger a shutdown! + if err := i.agent.Shutdown(); err != nil { + i.logger.Printf("[ERR] agent.rpc: shutdown failed: %v", err) + } + return err +} + +// Used to convert an error to a string representation +func errToString(err error) string { + if err == nil { + return "" + } + return err.Error() +} diff --git a/command/agent/rpc_client.go b/command/agent/rpc_client.go new file mode 100644 index 0000000000..c1f486f9b6 --- /dev/null +++ b/command/agent/rpc_client.go @@ -0,0 +1,399 @@ +package agent + +import ( + "bufio" + "fmt" + "github.com/hashicorp/logutils" + "github.com/ugorji/go/codec" + "log" + "net" + "sync" + "sync/atomic" +) + +var ( + clientClosed = fmt.Errorf("client closed") +) + +type seqCallback struct { + handler func(*responseHeader) +} + +func (sc *seqCallback) Handle(resp *responseHeader) { + sc.handler(resp) +} +func (sc *seqCallback) Cleanup() {} + +// seqHandler interface is used to handle responses +type seqHandler interface { + Handle(*responseHeader) + Cleanup() +} + +// RPCClient is the RPC client to make requests to the agent RPC. +type RPCClient struct { + seq uint64 + + conn *net.TCPConn + reader *bufio.Reader + writer *bufio.Writer + dec *codec.Decoder + enc *codec.Encoder + writeLock sync.Mutex + + dispatch map[uint64]seqHandler + dispatchLock sync.Mutex + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex +} + +// send is used to send an object using the MsgPack encoding. send +// is serialized to prevent write overlaps, while properly buffering. +func (c *RPCClient) send(header *requestHeader, obj interface{}) error { + c.writeLock.Lock() + defer c.writeLock.Unlock() + + if c.shutdown { + return clientClosed + } + + if err := c.enc.Encode(header); err != nil { + return err + } + + if obj != nil { + if err := c.enc.Encode(obj); err != nil { + return err + } + } + + if err := c.writer.Flush(); err != nil { + return err + } + + return nil +} + +// NewRPCClient is used to create a new RPC client given the address. +// This will properly dial, handshake, and start listening +func NewRPCClient(addr string) (*RPCClient, error) { + // Try to dial to agent + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + // Create the client + client := &RPCClient{ + seq: 0, + conn: conn.(*net.TCPConn), + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + dispatch: make(map[uint64]seqHandler), + shutdownCh: make(chan struct{}), + } + client.dec = codec.NewDecoder(client.reader, + &codec.MsgpackHandle{RawToString: true, WriteExt: true}) + client.enc = codec.NewEncoder(client.writer, + &codec.MsgpackHandle{RawToString: true, WriteExt: true}) + go client.listen() + + // Do the initial handshake + if err := client.handshake(); err != nil { + client.Close() + return nil, err + } + return client, err +} + +// StreamHandle is an opaque handle passed to stop to stop streaming +type StreamHandle uint64 + +// Close is used to free any resources associated with the client +func (c *RPCClient) Close() error { + c.shutdownLock.Lock() + defer c.shutdownLock.Unlock() + + if !c.shutdown { + c.shutdown = true + close(c.shutdownCh) + c.deregisterAll() + return c.conn.Close() + } + return nil +} + +// ForceLeave is used to ask the agent to issue a leave command for +// a given node +func (c *RPCClient) ForceLeave(node string) error { + header := requestHeader{ + Command: forceLeaveCommand, + Seq: c.getSeq(), + } + req := forceLeaveRequest{ + Node: node, + } + return c.genericRPC(&header, &req, nil) +} + +// Join is used to instruct the agent to attempt a join +func (c *RPCClient) Join(addrs []string, wan bool) (int, error) { + header := requestHeader{ + Command: joinCommand, + Seq: c.getSeq(), + } + req := joinRequest{ + Existing: addrs, + WAN: wan, + } + var resp joinResponse + + err := c.genericRPC(&header, &req, &resp) + return int(resp.Num), err +} + +// LANMembers is used to fetch a list of known members +func (c *RPCClient) LANMembers() ([]Member, error) { + header := requestHeader{ + Command: membersLANCommand, + Seq: c.getSeq(), + } + var resp membersResponse + + err := c.genericRPC(&header, nil, &resp) + return resp.Members, err +} + +// WANMembers is used to fetch a list of known members +func (c *RPCClient) WANMembers() ([]Member, error) { + header := requestHeader{ + Command: membersWANCommand, + Seq: c.getSeq(), + } + var resp membersResponse + + err := c.genericRPC(&header, nil, &resp) + return resp.Members, err +} + +// Leave is used to trigger a graceful leave and shutdown +func (c *RPCClient) Leave() error { + header := requestHeader{ + Command: leaveCommand, + Seq: c.getSeq(), + } + return c.genericRPC(&header, nil, nil) +} + +type monitorHandler struct { + client *RPCClient + closed bool + init bool + initCh chan<- error + logCh chan<- string + seq uint64 +} + +func (mh *monitorHandler) Handle(resp *responseHeader) { + // Initialize on the first response + if !mh.init { + mh.init = true + mh.initCh <- strToError(resp.Error) + return + } + + // Decode logs for all other responses + var rec logRecord + if err := mh.client.dec.Decode(&rec); err != nil { + log.Printf("[ERR] Failed to decode log: %v", err) + mh.client.deregisterHandler(mh.seq) + return + } + select { + case mh.logCh <- rec.Log: + default: + log.Printf("[ERR] Dropping log! Monitor channel full") + } +} + +func (mh *monitorHandler) Cleanup() { + if !mh.closed { + if !mh.init { + mh.init = true + mh.initCh <- fmt.Errorf("Stream closed") + } + close(mh.logCh) + mh.closed = true + } +} + +// Monitor is used to subscribe to the logs of the agent +func (c *RPCClient) Monitor(level logutils.LogLevel, ch chan<- string) (StreamHandle, error) { + // Setup the request + seq := c.getSeq() + header := requestHeader{ + Command: monitorCommand, + Seq: seq, + } + req := monitorRequest{ + LogLevel: string(level), + } + + // Create a monitor handler + initCh := make(chan error, 1) + handler := &monitorHandler{ + client: c, + initCh: initCh, + logCh: ch, + seq: seq, + } + c.handleSeq(seq, handler) + + // Send the request + if err := c.send(&header, &req); err != nil { + c.deregisterHandler(seq) + return 0, err + } + + // Wait for a response + select { + case err := <-initCh: + return StreamHandle(seq), err + case <-c.shutdownCh: + c.deregisterHandler(seq) + return 0, clientClosed + } +} + +// Stop is used to unsubscribe from logs or event streams +func (c *RPCClient) Stop(handle StreamHandle) error { + // Deregister locally first to stop delivery + c.deregisterHandler(uint64(handle)) + + header := requestHeader{ + Command: stopCommand, + Seq: c.getSeq(), + } + req := stopRequest{ + Stop: uint64(handle), + } + return c.genericRPC(&header, &req, nil) +} + +// handshake is used to perform the initial handshake on connect +func (c *RPCClient) handshake() error { + header := requestHeader{ + Command: handshakeCommand, + Seq: c.getSeq(), + } + req := handshakeRequest{ + Version: MaxRPCVersion, + } + return c.genericRPC(&header, &req, nil) +} + +// genericRPC is used to send a request and wait for an +// errorSequenceResponse, potentially returning an error +func (c *RPCClient) genericRPC(header *requestHeader, req interface{}, resp interface{}) error { + // Setup a response handler + errCh := make(chan error, 1) + handler := func(respHeader *responseHeader) { + if resp != nil { + err := c.dec.Decode(resp) + if err != nil { + errCh <- err + return + } + } + errCh <- strToError(respHeader.Error) + } + c.handleSeq(header.Seq, &seqCallback{handler: handler}) + defer c.deregisterHandler(header.Seq) + + // Send the request + if err := c.send(header, req); err != nil { + return err + } + + // Wait for a response + select { + case err := <-errCh: + return err + case <-c.shutdownCh: + return clientClosed + } +} + +// strToError converts a string to an error if not blank +func strToError(s string) error { + if s != "" { + return fmt.Errorf(s) + } + return nil +} + +// getSeq returns the next sequence number in a safe manner +func (c *RPCClient) getSeq() uint64 { + return atomic.AddUint64(&c.seq, 1) +} + +// deregisterAll is used to deregister all handlers +func (c *RPCClient) deregisterAll() { + c.dispatchLock.Lock() + defer c.dispatchLock.Unlock() + + for _, seqH := range c.dispatch { + seqH.Cleanup() + } + c.dispatch = make(map[uint64]seqHandler) +} + +// deregisterHandler is used to deregister a handler +func (c *RPCClient) deregisterHandler(seq uint64) { + c.dispatchLock.Lock() + seqH, ok := c.dispatch[seq] + delete(c.dispatch, seq) + c.dispatchLock.Unlock() + + if ok { + seqH.Cleanup() + } +} + +// handleSeq is used to setup a handlerto wait on a response for +// a given sequence number. +func (c *RPCClient) handleSeq(seq uint64, handler seqHandler) { + c.dispatchLock.Lock() + defer c.dispatchLock.Unlock() + c.dispatch[seq] = handler +} + +// respondSeq is used to respond to a given sequence number +func (c *RPCClient) respondSeq(seq uint64, respHeader *responseHeader) { + c.dispatchLock.Lock() + seqL, ok := c.dispatch[seq] + c.dispatchLock.Unlock() + + // Get a registered listener, ignore if none + if ok { + seqL.Handle(respHeader) + } +} + +// listen is used to processes data coming over the RPC channel, +// and wrote it to the correct destination based on seq no +func (c *RPCClient) listen() { + defer c.Close() + var respHeader responseHeader + for { + if err := c.dec.Decode(&respHeader); err != nil { + if !c.shutdown { + log.Printf("[ERR] agent.client: Failed to decode response header: %v", err) + } + break + } + c.respondSeq(respHeader.Seq, &respHeader) + } +} diff --git a/command/agent/rpc_client_test.go b/command/agent/rpc_client_test.go new file mode 100644 index 0000000000..a0ca5c641a --- /dev/null +++ b/command/agent/rpc_client_test.go @@ -0,0 +1,264 @@ +package agent + +import ( + "fmt" + "github.com/hashicorp/serf/serf" + "github.com/hashicorp/serf/testutil" + "io" + "net" + "os" + "strings" + "testing" + "time" +) + +type rpcParts struct { + dir string + client *RPCClient + agent *Agent + rpc *AgentRPC +} + +func (r *rpcParts) Close() { + r.client.Close() + r.rpc.Shutdown() + r.agent.Shutdown() + os.RemoveAll(r.dir) +} + +// testRPCClient returns an RPCClient connected to an RPC server that +// serves only this connection. +func testRPCClient(t *testing.T) *rpcParts { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + lw := NewLogWriter(512) + mult := io.MultiWriter(os.Stderr, lw) + + conf := nextConfig() + dir, agent := makeAgentLog(t, conf, mult) + rpc := NewAgentRPC(agent, l, mult, lw) + + rpcClient, err := NewRPCClient(l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + + return &rpcParts{ + dir: dir, + client: rpcClient, + agent: agent, + rpc: rpc, + } +} + +func TestRPCClientForceLeave(t *testing.T) { + p1 := testRPCClient(t) + p2 := testRPCClient(t) + defer p1.Close() + defer p2.Close() + testutil.Yield() + + s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort) + if _, err := p1.agent.JoinLAN([]string{s2Addr}); err != nil { + t.Fatalf("err: %s", err) + } + + testutil.Yield() + + if err := p2.agent.Shutdown(); err != nil { + t.Fatalf("err: %s", err) + } + + time.Sleep(time.Second) + + if err := p1.client.ForceLeave(p2.agent.config.NodeName); err != nil { + t.Fatalf("err: %s", err) + } + + testutil.Yield() + + m := p1.agent.LANMembers() + if len(m) != 2 { + t.Fatalf("should have 2 members: %#v", m) + } + + if m[1].Status != serf.StatusLeft { + t.Fatalf("should be left: %#v %v", m[1], m[1].Status == serf.StatusLeft) + } +} + +func TestRPCClientJoinLAN(t *testing.T) { + p1 := testRPCClient(t) + p2 := testRPCClient(t) + defer p1.Close() + defer p2.Close() + testutil.Yield() + + s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort) + n, err := p1.client.Join([]string{s2Addr}, false) + if err != nil { + t.Fatalf("err: %s", err) + } + + if n != 1 { + t.Fatalf("n != 1: %d", n) + } +} + +func TestRPCClientJoinWAN(t *testing.T) { + p1 := testRPCClient(t) + p2 := testRPCClient(t) + defer p1.Close() + defer p2.Close() + testutil.Yield() + + s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfWanPort) + n, err := p1.client.Join([]string{s2Addr}, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if n != 1 { + t.Fatalf("n != 1: %d", n) + } +} + +func TestRPCClientLANMembers(t *testing.T) { + p1 := testRPCClient(t) + p2 := testRPCClient(t) + defer p1.Close() + defer p2.Close() + testutil.Yield() + + mem, err := p1.client.LANMembers() + if err != nil { + t.Fatalf("err: %s", err) + } + + if len(mem) != 1 { + t.Fatalf("bad: %#v", mem) + } + + s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort) + _, err = p1.client.Join([]string{s2Addr}, false) + if err != nil { + t.Fatalf("err: %s", err) + } + + testutil.Yield() + + mem, err = p1.client.LANMembers() + if err != nil { + t.Fatalf("err: %s", err) + } + + if len(mem) != 2 { + t.Fatalf("bad: %#v", mem) + } +} + +func TestRPCClientWANMembers(t *testing.T) { + p1 := testRPCClient(t) + p2 := testRPCClient(t) + defer p1.Close() + defer p2.Close() + testutil.Yield() + + mem, err := p1.client.WANMembers() + if err != nil { + t.Fatalf("err: %s", err) + } + + if len(mem) != 1 { + t.Fatalf("bad: %#v", mem) + } + + s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfWanPort) + _, err = p1.client.Join([]string{s2Addr}, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + testutil.Yield() + + mem, err = p1.client.WANMembers() + if err != nil { + t.Fatalf("err: %s", err) + } + + if len(mem) != 2 { + t.Fatalf("bad: %#v", mem) + } +} + +func TestRPCClientLeave(t *testing.T) { + p1 := testRPCClient(t) + defer p1.Close() + testutil.Yield() + + if err := p1.client.Leave(); err != nil { + t.Fatalf("err: %s", err) + } + + testutil.Yield() + + select { + case <-p1.agent.ShutdownCh(): + default: + t.Fatalf("agent should be shutdown!") + } +} + +func TestRPCClientMonitor(t *testing.T) { + p1 := testRPCClient(t) + defer p1.Close() + testutil.Yield() + + eventCh := make(chan string, 64) + if handle, err := p1.client.Monitor("debug", eventCh); err != nil { + t.Fatalf("err: %s", err) + } else { + defer p1.client.Stop(handle) + } + + testutil.Yield() + + found := false +OUTER1: + for { + select { + case e := <-eventCh: + if strings.Contains(e, "Accepted client") { + found = true + } + default: + break OUTER1 + } + } + if !found { + t.Fatalf("should log client accept") + } + + // Join a bad thing to generate more events + p1.agent.JoinLAN(nil) + testutil.Yield() + + found = false +OUTER2: + for { + select { + case e := <-eventCh: + if strings.Contains(e, "joining") { + found = true + } + default: + break OUTER2 + } + } + if !found { + t.Fatalf("should log joining") + } +} diff --git a/command/agent/rpc_log_stream.go b/command/agent/rpc_log_stream.go new file mode 100644 index 0000000000..a561e082f0 --- /dev/null +++ b/command/agent/rpc_log_stream.go @@ -0,0 +1,68 @@ +package agent + +import ( + "github.com/hashicorp/logutils" + "log" +) + +type streamClient interface { + Send(*responseHeader, interface{}) error +} + +// logStream is used to stream logs to a client over RPC +type logStream struct { + client streamClient + filter *logutils.LevelFilter + logCh chan string + logger *log.Logger + seq uint64 +} + +func newLogStream(client streamClient, filter *logutils.LevelFilter, + seq uint64, logger *log.Logger) *logStream { + ls := &logStream{ + client: client, + filter: filter, + logCh: make(chan string, 512), + logger: logger, + seq: seq, + } + go ls.stream() + return ls +} + +func (ls *logStream) HandleLog(l string) { + // Check the log level + if !ls.filter.Check([]byte(l)) { + return + } + + // Do a non-blocking send + select { + case ls.logCh <- l: + default: + // We can't log syncronously, since we are already being invoked + // from the logWriter, and a log will need to invoke Write() which + // already holds the lock. We must therefor do the log async, so + // as to not deadlock + go ls.logger.Printf("[WARN] Dropping logs to %v", ls.client) + } +} + +func (ls *logStream) Stop() { + close(ls.logCh) +} + +func (ls *logStream) stream() { + header := responseHeader{Seq: ls.seq, Error: ""} + rec := logRecord{Log: ""} + + for line := range ls.logCh { + rec.Log = line + if err := ls.client.Send(&header, &rec); err != nil { + ls.logger.Printf("[ERR] Failed to stream log to %v: %v", + ls.client, err) + return + } + } +} diff --git a/command/agent/rpc_log_stream_test.go b/command/agent/rpc_log_stream_test.go new file mode 100644 index 0000000000..ea0412cf3f --- /dev/null +++ b/command/agent/rpc_log_stream_test.go @@ -0,0 +1,54 @@ +package agent + +import ( + "github.com/hashicorp/logutils" + "log" + "os" + "testing" + "time" +) + +type MockStreamClient struct { + headers []*responseHeader + objs []interface{} + err error +} + +func (m *MockStreamClient) Send(h *responseHeader, o interface{}) error { + m.headers = append(m.headers, h) + m.objs = append(m.objs, o) + return m.err +} + +func TestRPCLogStream(t *testing.T) { + sc := &MockStreamClient{} + filter := LevelFilter() + filter.MinLevel = logutils.LogLevel("INFO") + + ls := newLogStream(sc, filter, 42, log.New(os.Stderr, "", log.LstdFlags)) + defer ls.Stop() + + log := "[DEBUG] this is a test log" + log2 := "[INFO] This should pass" + ls.HandleLog(log) + ls.HandleLog(log2) + + time.Sleep(5 * time.Millisecond) + + if len(sc.headers) != 1 { + t.Fatalf("expected 1 messages!") + } + for _, h := range sc.headers { + if h.Seq != 42 { + t.Fatalf("bad seq") + } + if h.Error != "" { + t.Fatalf("bad err") + } + } + + obj1 := sc.objs[0].(*logRecord) + if obj1.Log != log2 { + t.Fatalf("bad event %#v", obj1) + } +}