Working on the agent

pull/19/head
Armon Dadgar 2013-12-20 16:39:32 -08:00
parent aeccadd217
commit 8caf0034db
10 changed files with 700 additions and 13 deletions

View File

@ -3,6 +3,10 @@ package agent
import (
"fmt"
"github.com/hashicorp/consul/consul"
"io"
"log"
"os"
"sync"
)
/*
@ -16,17 +20,35 @@ import (
type Agent struct {
config *Config
// Used for writing our logs
logger *log.Logger
// Output sink for logs
logOutput io.Writer
// We have one of a client or a server, depending
// on our configuration
server *consul.Server
client *consul.Client
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
}
// Create is used to create a new Agent. Returns
// the agent or potentially an error.
func Create(config *Config) (*Agent, error) {
func Create(config *Config, logOutput io.Writer) (*Agent, error) {
// Ensure we have a log sink
if logOutput == nil {
logOutput = os.Stderr
}
agent := &Agent{
config: config,
config: config,
logger: log.New(logOutput, "", log.LstdFlags),
logOutput: logOutput,
shutdownCh: make(chan struct{}),
}
// Setup either the client or the server
@ -60,6 +82,11 @@ func (a *Agent) consulConfig() *consul.Config {
if a.config.DataDir != "" {
base.DataDir = a.config.DataDir
}
if a.config.EncryptKey != "" {
key, _ := a.config.EncryptBytes()
base.SerfLANConfig.MemberlistConfig.SecretKey = key
base.SerfWANConfig.MemberlistConfig.SecretKey = key
}
if a.config.NodeName != "" {
base.NodeName = a.config.NodeName
}
@ -73,10 +100,12 @@ func (a *Agent) consulConfig() *consul.Config {
if a.config.SerfWanPort != 0 {
base.SerfWANConfig.MemberlistConfig.Port = a.config.SerfWanPort
}
if a.config.ServerRPCAddr != "" {
base.RPCAddr = a.config.ServerRPCAddr
if a.config.ServerAddr != "" {
base.RPCAddr = a.config.ServerAddr
}
// Setup the loggers
base.LogOutput = a.logOutput
return base
}
@ -121,9 +150,29 @@ func (a *Agent) Leave() error {
// Shutdown is used to hard stop the agent. Should be preceeded
// by a call to Leave to do it gracefully.
func (a *Agent) Shutdown() error {
if a.server != nil {
return a.server.Shutdown()
} else {
return a.client.Shutdown()
a.shutdownLock.Lock()
defer a.shutdownLock.Unlock()
if a.shutdown {
return nil
}
a.logger.Println("[INFO] agent: requesting shutdown")
var err error
if a.server != nil {
err = a.server.Shutdown()
} else {
err = a.client.Shutdown()
}
a.logger.Println("[INFO] agent: shutdown complete")
a.shutdown = true
close(a.shutdownCh)
return err
}
// ShutdownCh returns a channel that can be selected to wait
// for the agent to perform a shutdown.
func (a *Agent) ShutdownCh() <-chan struct{} {
return a.shutdownCh
}

View File

@ -1,11 +1,21 @@
package agent
import (
"flag"
"fmt"
"github.com/hashicorp/logutils"
"github.com/mitchellh/cli"
"io"
"os"
"os/signal"
"strings"
"syscall"
"time"
)
// gracefulTimeout controls how long we wait before forcefully terminating
var gracefulTimeout = 5 * time.Second
// Command is a Command implementation that runs a Serf agent.
// The command will not end unless a shutdown message is sent on the
// ShutdownCh. If two messages are sent on the ShutdownCh it will forcibly
@ -13,6 +23,90 @@ import (
type Command struct {
Ui cli.Ui
ShutdownCh <-chan struct{}
args []string
logFilter *logutils.LevelFilter
}
// readConfig is responsible for setup of our configuration using
// the command line and any file configs
func (c *Command) readConfig() *Config {
var cmdConfig Config
var configFiles []string
cmdFlags := flag.NewFlagSet("agent", flag.ContinueOnError)
cmdFlags.Usage = func() { c.Ui.Output(c.Help()) }
cmdFlags.StringVar(&cmdConfig.SerfBindAddr, "serf-bind", "", "address to bind serf listeners to")
cmdFlags.StringVar(&cmdConfig.ServerAddr, "server-addr", "", "address to bind server listeners to")
cmdFlags.Var((*AppendSliceValue)(&configFiles), "config-file",
"json file to read config from")
cmdFlags.Var((*AppendSliceValue)(&configFiles), "config-dir",
"directory of json files to read")
cmdFlags.StringVar(&cmdConfig.EncryptKey, "encrypt", "", "encryption key")
cmdFlags.StringVar(&cmdConfig.LogLevel, "log-level", "", "log level")
cmdFlags.StringVar(&cmdConfig.NodeName, "node", "", "node name")
cmdFlags.StringVar(&cmdConfig.RPCAddr, "rpc-addr", "",
"address to bind RPC listener to")
cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory")
cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter")
cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "enable server mode")
if err := cmdFlags.Parse(c.args); err != nil {
return nil
}
config := DefaultConfig()
if len(configFiles) > 0 {
fileConfig, err := ReadConfigPaths(configFiles)
if err != nil {
c.Ui.Error(err.Error())
return nil
}
config = MergeConfig(config, fileConfig)
}
config = MergeConfig(config, &cmdConfig)
if config.NodeName == "" {
hostname, err := os.Hostname()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error determining hostname: %s", err))
return nil
}
config.NodeName = hostname
}
if config.EncryptKey != "" {
if _, err := config.EncryptBytes(); err != nil {
c.Ui.Error(fmt.Sprintf("Invalid encryption key: %s", err))
return nil
}
}
return config
}
// setupLoggers is used to setup the logGate, logWriter, and our logOutput
func (c *Command) setupLoggers(config *Config) (*GatedWriter, *logWriter, io.Writer) {
// Setup logging. First create the gated log writer, which will
// store logs until we're ready to show them. Then create the level
// filter, filtering logs of the specified level.
logGate := &GatedWriter{
Writer: &cli.UiWriter{Ui: c.Ui},
}
c.logFilter = LevelFilter()
c.logFilter.MinLevel = logutils.LogLevel(strings.ToUpper(config.LogLevel))
c.logFilter.Writer = logGate
if !ValidateLevelFilter(c.logFilter.MinLevel, c.logFilter) {
c.Ui.Error(fmt.Sprintf(
"Invalid log level: %s. Valid log levels are: %v",
c.logFilter.MinLevel, c.logFilter.Levels))
return nil, nil, nil
}
// Create a log writer, and wrap a logOutput around it
logWriter := NewLogWriter(512)
logOutput := io.MultiWriter(c.logFilter, logWriter)
return logGate, logWriter, logOutput
}
func (c *Command) Run(args []string) int {
@ -23,19 +117,109 @@ func (c *Command) Run(args []string) int {
Ui: c.Ui,
}
conf := DefaultConfig()
agent, err := Create(conf)
// Parse our configs
c.args = args
config := c.readConfig()
if config == nil {
return 1
}
c.args = args
// Setup the log outputs
logGate, logWriter, logOutput := c.setupLoggers(config)
if logWriter == nil {
return 1
}
// Create the agent
c.Ui.Output("Starting Consul agent...")
agent, err := Create(config, logOutput)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error starting agent: %s", err))
return 1
}
defer agent.Shutdown()
c.Ui.Output("Consul agent running!")
c.Ui.Info(fmt.Sprintf("Node name: '%s'", config.NodeName))
c.Ui.Info(fmt.Sprintf(" RPC addr: '%s'", config.RPCAddr))
c.Ui.Info(fmt.Sprintf("Encrypted: %#v", config.EncryptKey != ""))
c.Ui.Info(fmt.Sprintf(" Server: %v", config.Server))
// Enable log streaming
c.Ui.Info("")
c.Ui.Output("Log data will now stream in as it occurs:\n")
logGate.Flush()
// Wait for exit
return c.handleSignals(config, agent)
}
// handleSignals blocks until we get an exit-causing signal
func (c *Command) handleSignals(config *Config, agent *Agent) int {
signalCh := make(chan os.Signal, 4)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
// Wait for a signal
WAIT:
var sig os.Signal
select {
case s := <-signalCh:
sig = s
case <-c.ShutdownCh:
sig = os.Interrupt
case <-agent.ShutdownCh():
// Agent is already shutdown!
return 0
}
return 0
c.Ui.Output(fmt.Sprintf("Caught signal: %v", sig))
// Check if this is a SIGHUP
if sig == syscall.SIGHUP {
config = c.handleReload(config, agent)
goto WAIT
}
// Check if we should do a graceful leave
graceful := false
if sig == os.Interrupt && !config.SkipLeaveOnInt {
graceful = true
} else if sig == syscall.SIGTERM && config.LeaveOnTerm {
graceful = true
}
// Bail fast if not doing a graceful leave
if !graceful {
return 1
}
// Attempt a graceful leave
gracefulCh := make(chan struct{})
c.Ui.Output("Gracefully shutting down agent...")
go func() {
if err := agent.Leave(); err != nil {
c.Ui.Error(fmt.Sprintf("Error: %s", err))
return
}
close(gracefulCh)
}()
// Wait for leave or another signal
select {
case <-signalCh:
return 1
case <-time.After(gracefulTimeout):
return 1
case <-gracefulCh:
return 0
}
}
// handleReload is invoked when we should reload our configs, e.g. SIGHUP
func (c *Command) handleReload(config *Config, agent *Agent) *Config {
c.Ui.Output("Reloading configuration...")
// TODO
return config
}
func (c *Command) Synopsis() string {

View File

@ -1,7 +1,15 @@
package agent
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/hashicorp/consul/consul"
"github.com/mitchellh/mapstructure"
"io"
"os"
"path/filepath"
"strings"
)
// This is the default port we use for co
@ -17,6 +25,9 @@ type Config struct {
// DataDir is the directory to store our state in
DataDir string
// Encryption key to use for the Serf communication
EncryptKey string
// LogLevel is the level of the logs to putout
LogLevel string
@ -39,15 +50,23 @@ type Config struct {
// This is only for the Consul servers
SerfWanPort int
// ServerRPCAddr is the address we use for Consul server communication.
// ServerAddr is the address we use for Consul server communication.
// Defaults to 0.0.0.0:8300
ServerRPCAddr string
ServerAddr string
// Server controls if this agent acts like a Consul server,
// or merely as a client. Servers have more state, take part
// in leader election, etc.
Server bool
// LeaveOnTerm controls if Serf does a graceful leave when receiving
// the TERM signal. Defaults false. This can be changed on reload.
LeaveOnTerm bool `mapstructure:"leave_on_terminate"`
// SkipLeaveOnInt controls if Serf skips a graceful leave when receiving
// the INT signal. Defaults false. This can be changed on reload.
SkipLeaveOnInt bool `mapstructure:"skip_leave_on_interrupt"`
// ConsulConfig can either be provided or a default one created
ConsulConfig *consul.Config
}
@ -60,3 +79,147 @@ func DefaultConfig() *Config {
Server: false,
}
}
// EncryptBytes returns the encryption key configured.
func (c *Config) EncryptBytes() ([]byte, error) {
return base64.StdEncoding.DecodeString(c.EncryptKey)
}
// DecodeConfig reads the configuration from the given reader in JSON
// format and decodes it into a proper Config structure.
func DecodeConfig(r io.Reader) (*Config, error) {
var raw interface{}
dec := json.NewDecoder(r)
if err := dec.Decode(&raw); err != nil {
return nil, err
}
// Decode
var md mapstructure.Metadata
var result Config
msdec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Metadata: &md,
Result: &result,
})
if err != nil {
return nil, err
}
if err := msdec.Decode(raw); err != nil {
return nil, err
}
return &result, nil
}
// MergeConfig merges two configurations together to make a single new
// configuration.
func MergeConfig(a, b *Config) *Config {
var result Config = *a
// Copy the strings if they're set
if b.Datacenter != "" {
result.Datacenter = b.Datacenter
}
if b.DataDir != "" {
result.DataDir = b.DataDir
}
if b.EncryptKey != "" {
result.EncryptKey = b.EncryptKey
}
if b.LogLevel != "" {
result.LogLevel = b.LogLevel
}
if b.RPCAddr != "" {
result.RPCAddr = b.RPCAddr
}
if b.SerfBindAddr != "" {
result.SerfBindAddr = b.SerfBindAddr
}
if b.SerfLanPort > 0 {
result.SerfLanPort = b.SerfLanPort
}
if b.SerfWanPort > 0 {
result.SerfWanPort = b.SerfWanPort
}
if b.ServerAddr != "" {
result.ServerAddr = b.ServerAddr
}
if b.Server == true {
result.Server = b.Server
}
if b.LeaveOnTerm == true {
result.LeaveOnTerm = true
}
if b.SkipLeaveOnInt == true {
result.SkipLeaveOnInt = true
}
return &result
}
// ReadConfigPaths reads the paths in the given order to load configurations.
// The paths can be to files or directories. If the path is a directory,
// we read one directory deep and read any files ending in ".json" as
// configuration files.
func ReadConfigPaths(paths []string) (*Config, error) {
result := new(Config)
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("Error reading '%s': %s", path, err)
}
fi, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("Error reading '%s': %s", path, err)
}
if !fi.IsDir() {
config, err := DecodeConfig(f)
f.Close()
if err != nil {
return nil, fmt.Errorf("Error decoding '%s': %s", path, err)
}
result = MergeConfig(result, config)
continue
}
contents, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, fmt.Errorf("Error reading '%s': %s", path, err)
}
for _, fi := range contents {
// Don't recursively read contents
if fi.IsDir() {
continue
}
// If it isn't a JSON file, ignore it
if !strings.HasSuffix(fi.Name(), ".json") {
continue
}
subpath := filepath.Join(path, fi.Name())
f, err := os.Open(subpath)
if err != nil {
return nil, fmt.Errorf("Error reading '%s': %s", subpath, err)
}
config, err := DecodeConfig(f)
f.Close()
if err != nil {
return nil, fmt.Errorf("Error decoding '%s': %s", subpath, err)
}
result = MergeConfig(result, config)
}
}
return result, nil
}

View File

@ -0,0 +1,20 @@
package agent
import "strings"
// AppendSliceValue implements the flag.Value interface and allows multiple
// calls to the same variable to append a list.
type AppendSliceValue []string
func (s *AppendSliceValue) String() string {
return strings.Join(*s, ",")
}
func (s *AppendSliceValue) Set(value string) error {
if *s == nil {
*s = make([]string, 0, 1)
}
*s = append(*s, value)
return nil
}

View File

@ -0,0 +1,33 @@
package agent
import (
"flag"
"reflect"
"testing"
)
func TestAppendSliceValue_implements(t *testing.T) {
var raw interface{}
raw = new(AppendSliceValue)
if _, ok := raw.(flag.Value); !ok {
t.Fatalf("AppendSliceValue should be a Value")
}
}
func TestAppendSliceValueSet(t *testing.T) {
sv := new(AppendSliceValue)
err := sv.Set("foo")
if err != nil {
t.Fatalf("err: %s", err)
}
err = sv.Set("bar")
if err != nil {
t.Fatalf("err: %s", err)
}
expected := []string{"foo", "bar"}
if !reflect.DeepEqual([]string(*sv), expected) {
t.Fatalf("Bad: %#v", sv)
}
}

View File

@ -0,0 +1,43 @@
package agent
import (
"io"
"sync"
)
// GatedWriter is an io.Writer implementation that buffers all of its
// data into an internal buffer until it is told to let data through.
type GatedWriter struct {
Writer io.Writer
buf [][]byte
flush bool
lock sync.RWMutex
}
// Flush tells the GatedWriter to flush any buffered data and to stop
// buffering.
func (w *GatedWriter) Flush() {
w.lock.Lock()
w.flush = true
w.lock.Unlock()
for _, p := range w.buf {
w.Write(p)
}
w.buf = nil
}
func (w *GatedWriter) Write(p []byte) (n int, err error) {
w.lock.RLock()
defer w.lock.RUnlock()
if w.flush {
return w.Writer.Write(p)
}
p2 := make([]byte, len(p))
copy(p2, p)
w.buf = append(w.buf, p2)
return len(p), nil
}

View File

@ -0,0 +1,34 @@
package agent
import (
"bytes"
"io"
"testing"
)
func TestGatedWriter_impl(t *testing.T) {
var _ io.Writer = new(GatedWriter)
}
func TestGatedWriter(t *testing.T) {
buf := new(bytes.Buffer)
w := &GatedWriter{Writer: buf}
w.Write([]byte("foo\n"))
w.Write([]byte("bar\n"))
if buf.String() != "" {
t.Fatalf("bad: %s", buf.String())
}
w.Flush()
if buf.String() != "foo\nbar\n" {
t.Fatalf("bad: %s", buf.String())
}
w.Write([]byte("baz\n"))
if buf.String() != "foo\nbar\nbaz\n" {
t.Fatalf("bad: %s", buf.String())
}
}

View File

@ -0,0 +1,27 @@
package agent
import (
"github.com/hashicorp/logutils"
"io/ioutil"
)
// LevelFilter returns a LevelFilter that is configured with the log
// levels that we use.
func LevelFilter() *logutils.LevelFilter {
return &logutils.LevelFilter{
Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "WARN", "ERR"},
MinLevel: "INFO",
Writer: ioutil.Discard,
}
}
// ValidateLevelFilter verifies that the log levels within the filter
// are valid.
func ValidateLevelFilter(minLevel logutils.LogLevel, filter *logutils.LevelFilter) bool {
for _, level := range filter.Levels {
if level == minLevel {
return true
}
}
return false
}

View File

@ -0,0 +1,83 @@
package agent
import (
"sync"
)
// LogHandler interface is used for clients that want to subscribe
// to logs, for example to stream them over an IPC mechanism
type LogHandler interface {
HandleLog(string)
}
// logWriter implements io.Writer so it can be used as a log sink.
// It maintains a circular buffer of logs, and a set of handlers to
// which it can stream the logs to.
type logWriter struct {
sync.Mutex
logs []string
index int
handlers map[LogHandler]struct{}
}
// NewLogWriter creates a logWriter with the given buffer capacity
func NewLogWriter(buf int) *logWriter {
return &logWriter{
logs: make([]string, buf),
index: 0,
handlers: make(map[LogHandler]struct{}),
}
}
// RegisterHandler adds a log handler to recieve logs, and sends
// the last buffered logs to the handler
func (l *logWriter) RegisterHandler(lh LogHandler) {
l.Lock()
defer l.Unlock()
// Do nothing if already registered
if _, ok := l.handlers[lh]; ok {
return
}
// Register
l.handlers[lh] = struct{}{}
// Send the old logs
if l.logs[l.index] != "" {
for i := l.index; i < len(l.logs); i++ {
lh.HandleLog(l.logs[i])
}
}
for i := 0; i < l.index; i++ {
lh.HandleLog(l.logs[i])
}
}
// DeregisterHandler removes a LogHandler and prevents more invocations
func (l *logWriter) DeregisterHandler(lh LogHandler) {
l.Lock()
defer l.Unlock()
delete(l.handlers, lh)
}
// Write is used to accumulate new logs
func (l *logWriter) Write(p []byte) (n int, err error) {
l.Lock()
defer l.Unlock()
// Strip off newlines at the end if there are any since we store
// individual log lines in the agent.
n = len(p)
if p[n-1] == '\n' {
p = p[:n-1]
}
l.logs[l.index] = string(p)
l.index = (l.index + 1) % len(l.logs)
for lh, _ := range l.handlers {
lh.HandleLog(string(p))
}
return
}

View File

@ -0,0 +1,51 @@
package agent
import (
"testing"
)
type MockLogHandler struct {
logs []string
}
func (m *MockLogHandler) HandleLog(l string) {
m.logs = append(m.logs, l)
}
func TestLogWriter(t *testing.T) {
h := &MockLogHandler{}
w := NewLogWriter(4)
// Write some logs
w.Write([]byte("one")) // Gets dropped!
w.Write([]byte("two"))
w.Write([]byte("three"))
w.Write([]byte("four"))
w.Write([]byte("five"))
// Register a handler, sends old!
w.RegisterHandler(h)
w.Write([]byte("six"))
w.Write([]byte("seven"))
// Deregister
w.DeregisterHandler(h)
w.Write([]byte("eight"))
w.Write([]byte("nine"))
out := []string{
"two",
"three",
"four",
"five",
"six",
"seven",
}
for idx := range out {
if out[idx] != h.logs[idx] {
t.Fatalf("mismatch %v", h.logs)
}
}
}