diff --git a/.gitignore b/.gitignore index 70c75111be..f373cf40e9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *.swp *.test .DS_Store +.fseventsd .vagrant/ /pkg Thumbs.db diff --git a/agent/agent.go b/agent/agent.go index f11b717824..4410ff2935 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -73,6 +73,7 @@ type delegate interface { SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error Shutdown() error Stats() map[string]map[string]string + enterpriseDelegate } // notifier is called after a successful JoinLAN. diff --git a/agent/consul/client.go b/agent/consul/client.go index f3d5fc6bbb..a69e76160e 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -72,6 +72,9 @@ type Client struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // embedded struct to hold all the enterprise specific data + EnterpriseClient } // NewClient is used to construct a new Consul client from the @@ -131,6 +134,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) { shutdownCh: make(chan struct{}), } + if err := c.initEnterprise(); err != nil { + c.Shutdown() + return nil, err + } + // Initialize the LAN Serf c.serf, err = c.setupSerf(config.SerfLANConfig, c.eventCh, serfLANSnapshot) @@ -147,6 +155,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) { // handlers depend on the router and the router depends on Serf. go c.lanEventHandler() + if err := c.startEnterprise(); err != nil { + c.Shutdown() + return nil, err + } + return c, nil } @@ -342,6 +355,17 @@ func (c *Client) Stats() map[string]map[string]string { "serf_lan": c.serf.Stats(), "runtime": runtimeStats(), } + + for outerKey, outerValue := range c.enterpriseStats() { + if _, ok := stats[outerKey]; ok { + for innerKey, innerValue := range outerValue { + stats[outerKey][innerKey] = innerValue + } + } else { + stats[outerKey] = outerValue + } + } + return stats } diff --git a/agent/consul/client_serf.go b/agent/consul/client_serf.go index c47811bb5c..a133626d39 100644 --- a/agent/consul/client_serf.go +++ b/agent/consul/client_serf.go @@ -135,6 +135,8 @@ func (c *Client) localEvent(event serf.UserEvent) { c.config.UserEventHandler(event) } default: - c.logger.Printf("[WARN] consul: Unhandled local event: %v", event) + if !c.handleEnterpriseUserEvents(event) { + c.logger.Printf("[WARN] consul: Unhandled local event: %v", event) + } } } diff --git a/agent/consul/enterprise_client_oss.go b/agent/consul/enterprise_client_oss.go new file mode 100644 index 0000000000..9f3a4cfe6a --- /dev/null +++ b/agent/consul/enterprise_client_oss.go @@ -0,0 +1,25 @@ +// +build !ent + +package consul + +import ( + "github.com/hashicorp/serf/serf" +) + +type EnterpriseClient struct{} + +func (c *Client) initEnterprise() error { + return nil +} + +func (c *Client) startEnterprise() error { + return nil +} + +func (c *Client) handleEnterpriseUserEvents(event serf.UserEvent) bool { + return false +} + +func (c *Client) enterpriseStats() map[string]map[string]string { + return nil +} diff --git a/agent/consul/enterprise_server_oss.go b/agent/consul/enterprise_server_oss.go new file mode 100644 index 0000000000..84b49403b6 --- /dev/null +++ b/agent/consul/enterprise_server_oss.go @@ -0,0 +1,32 @@ +// +build !ent + +package consul + +import ( + "net" + + "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/serf/serf" +) + +type EnterpriseServer struct{} + +func (s *Server) initEnterprise() error { + return nil +} + +func (s *Server) startEnterprise() error { + return nil +} + +func (s *Server) handleEnterpriseUserEvents(event serf.UserEvent) bool { + return false +} + +func (s *Server) handleEnterpriseRPCConn(rtype pool.RPCType, conn net.Conn, isTLS bool) bool { + return false +} + +func (s *Server) enterpriseStats() map[string]map[string]string { + return nil +} diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index fcde830b0c..b983a868c1 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -115,9 +115,10 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { s.handleSnapshotConn(conn) default: - s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn)) - conn.Close() - return + if !s.handleEnterpriseRPCConn(typ, conn, isTLS) { + s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn)) + conn.Close() + } } } diff --git a/agent/consul/server.go b/agent/consul/server.go index 128f67081c..23fbf337c3 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -208,6 +208,9 @@ type Server struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // embedded struct to hold all the enterprise specific data + EnterpriseServer } func NewServer(config *Config) (*Server, error) { @@ -297,6 +300,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (* shutdownCh: shutdownCh, } + // Initialize enterprise specific server functionality + if err := s.initEnterprise(); err != nil { + s.Shutdown() + return nil, err + } + // Initialize the stats fetcher that autopilot will use. s.statsFetcher = NewStatsFetcher(logger, s.connPool, s.config.Datacenter) @@ -338,6 +347,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (* return nil, fmt.Errorf("Failed to start Raft: %v", err) } + // Start enterprise specific functionality + if err := s.startEnterprise(); err != nil { + s.Shutdown() + return nil, err + } + // Serf and dynamic bind ports // // The LAN serf cluster announces the port of the WAN serf cluster @@ -1019,6 +1034,17 @@ func (s *Server) Stats() map[string]map[string]string { if s.serfWAN != nil { stats["serf_wan"] = s.serfWAN.Stats() } + + for outerKey, outerValue := range s.enterpriseStats() { + if _, ok := stats[outerKey]; ok { + for innerKey, innerValue := range outerValue { + stats[outerKey][innerKey] = innerValue + } + } else { + stats[outerKey] = outerValue + } + } + return stats } diff --git a/agent/consul/server_serf.go b/agent/consul/server_serf.go index 5d46b74ee4..e0781d912f 100644 --- a/agent/consul/server_serf.go +++ b/agent/consul/server_serf.go @@ -198,7 +198,9 @@ func (s *Server) localEvent(event serf.UserEvent) { s.config.UserEventHandler(event) } default: - s.logger.Printf("[WARN] consul: Unhandled local event: %v", event) + if !s.handleEnterpriseUserEvents(event) { + s.logger.Printf("[WARN] consul: Unhandled local event: %v", event) + } } } diff --git a/agent/enterprise_delegate_oss.go b/agent/enterprise_delegate_oss.go new file mode 100644 index 0000000000..3fdf3fce33 --- /dev/null +++ b/agent/enterprise_delegate_oss.go @@ -0,0 +1,6 @@ +// +build !ent + +package agent + +// enterpriseDelegate has no functions in OSS +type enterpriseDelegate interface{} diff --git a/agent/http.go b/agent/http.go index e639d7d7b6..b9791232dd 100644 --- a/agent/http.go +++ b/agent/http.go @@ -31,6 +31,15 @@ func (e MethodNotAllowedError) Error() string { return fmt.Sprintf("method %s not allowed", e.Method) } +// BadRequestError should be returned by a handler when parameters or the payload are not valid +type BadRequestError struct { + Reason string +} + +func (e BadRequestError) Error() string { + return fmt.Sprintf("Bad request: %s", e.Reason) +} + // HTTPServer provides an HTTP api for an agent. type HTTPServer struct { *http.Server @@ -249,6 +258,11 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc { return ok } + isBadRequest := func(err error) bool { + _, ok := err.(BadRequestError) + return ok + } + addAllowHeader := func(methods []string) { resp.Header().Add("Allow", strings.Join(methods, ",")) } @@ -269,6 +283,9 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc { addAllowHeader(err.(MethodNotAllowedError).Allow) resp.WriteHeader(http.StatusMethodNotAllowed) // 405 fmt.Fprint(resp, err.Error()) + case isBadRequest(err): + resp.WriteHeader(http.StatusBadRequest) + fmt.Fprint(resp, err.Error()) default: resp.WriteHeader(http.StatusInternalServerError) fmt.Fprint(resp, err.Error()) diff --git a/command/helpers/helpers.go b/command/helpers/helpers.go new file mode 100644 index 0000000000..6ad7ed2b7d --- /dev/null +++ b/command/helpers/helpers.go @@ -0,0 +1,42 @@ +package helpers + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" +) + +func LoadDataSource(data string, testStdin io.Reader) (string, error) { + var stdin io.Reader = os.Stdin + if testStdin != nil { + stdin = testStdin + } + + // Handle empty quoted shell parameters + if len(data) == 0 { + return "", nil + } + + switch data[0] { + case '@': + data, err := ioutil.ReadFile(data[1:]) + if err != nil { + return "", fmt.Errorf("Failed to read file: %s", err) + } else { + return string(data), nil + } + case '-': + if len(data) > 1 { + return data, nil + } + var b bytes.Buffer + if _, err := io.Copy(&b, stdin); err != nil { + return "", fmt.Errorf("Failed to read stdin: %s", err) + } + return b.String(), nil + default: + return data, nil + } +} diff --git a/command/kv/put/kv_put.go b/command/kv/put/kv_put.go index 98be3c001e..abe51a5382 100644 --- a/command/kv/put/kv_put.go +++ b/command/kv/put/kv_put.go @@ -1,16 +1,14 @@ package put import ( - "bytes" "encoding/base64" "flag" "fmt" "io" - "io/ioutil" - "os" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" "github.com/mitchellh/cli" ) @@ -173,11 +171,6 @@ func (c *cmd) Run(args []string) int { } func (c *cmd) dataFromArgs(args []string) (string, string, error) { - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - switch len(args) { case 0: return "", "", fmt.Errorf("Missing KEY argument") @@ -189,30 +182,11 @@ func (c *cmd) dataFromArgs(args []string) (string, string, error) { } key := args[0] - data := args[1] + data, err := helpers.LoadDataSource(args[1], c.testStdin) - // Handle empty quoted shell parameters - if len(data) == 0 { - return key, "", nil - } - - switch data[0] { - case '@': - data, err := ioutil.ReadFile(data[1:]) - if err != nil { - return "", "", fmt.Errorf("Failed to read file: %s", err) - } - return key, string(data), nil - case '-': - if len(data) > 1 { - return key, data, nil - } - var b bytes.Buffer - if _, err := io.Copy(&b, stdin); err != nil { - return "", "", fmt.Errorf("Failed to read stdin: %s", err) - } - return key, b.String(), nil - default: + if err != nil { + return "", "", err + } else { return key, data, nil } }