diff --git a/api/api.go b/api/api.go index 5617293e44..0ff9a22c23 100644 --- a/api/api.go +++ b/api/api.go @@ -120,8 +120,8 @@ func DefaultConfig() *Config { HttpClient: http.DefaultClient, } - if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 { - config.Address = os.Getenv("CONSUL_HTTP_ADDR") + if addr := os.Getenv("CONSUL_HTTP_ADDR"); addr != "" { + config.Address = addr } return config @@ -137,11 +137,7 @@ func NewClient(config *Config) (*Client, error) { // bootstrap the config defConfig := DefaultConfig() - switch { - case len(config.Address) != 0: - case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: - config.Address = os.Getenv("CONSUL_HTTP_ADDR") - default: + if len(config.Address) == 0 { config.Address = defConfig.Address } @@ -153,14 +149,15 @@ func NewClient(config *Config) (*Client, error) { config.HttpClient = defConfig.HttpClient } - if strings.HasPrefix(config.Address, "unix://") { - shortStr := strings.TrimPrefix(config.Address, "unix://") - t := &http.Transport{} - t.Dial = func(_, _ string) (net.Conn, error) { - return net.Dial("unix", shortStr) + if parts := strings.SplitN(config.Address, "unix://", 2); len(parts) == 2 { + config.HttpClient = &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("unix", parts[1]) + }, + }, } - config.HttpClient.Transport = t - config.Address = shortStr + config.Address = parts[1] } client := &Client{ diff --git a/api/api_test.go b/api/api_test.go index 488fcb1ee0..a8c826ed20 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -42,6 +42,10 @@ type testServerConfig struct { Ports testPortConfig `json:"ports,omitempty"` } +// Callback functions for modifying config +type configCallback func(c *Config) +type serverConfigCallback func(c *testServerConfig) + func defaultConfig() *testServerConfig { return &testServerConfig{ Bootstrap: true, @@ -72,7 +76,7 @@ func newTestServer(t *testing.T) *testServer { return newTestServerWithConfig(t, func(c *testServerConfig) {}) } -func newTestServerWithConfig(t *testing.T, cb func(c *testServerConfig)) *testServer { +func newTestServerWithConfig(t *testing.T, cb serverConfigCallback) *testServer { if path, err := exec.LookPath("consul"); err != nil || path == "" { t.Log("consul not found on $PATH, skipping") t.SkipNow() @@ -131,15 +135,21 @@ func makeClient(t *testing.T) (*Client, *testServer) { }, func(c *testServerConfig) {}) } -func makeClientWithConfig(t *testing.T, clientConfig func(c *Config), serverConfig func(c *testServerConfig)) (*Client, *testServer) { - server := newTestServerWithConfig(t, serverConfig) +func makeClientWithConfig(t *testing.T, cb1 configCallback, cb2 serverConfigCallback) (*Client, *testServer) { + // Make client config conf := DefaultConfig() - clientConfig(conf) + cb1(conf) + fmt.Printf("%#v\n", conf.HttpClient.Transport) + + // Create client client, err := NewClient(conf) if err != nil { t.Fatalf("err: %v", err) } + // Create server + server := newTestServerWithConfig(t, cb2) + // Allow the server some time to start, and verify we have a leader. testutil.WaitForResult(func() (bool, error) { req := client.newRequest("GET", "/v1/catalog/nodes") diff --git a/api/status_test.go b/api/status_test.go index 5e7acd2740..61b0a435e9 100644 --- a/api/status_test.go +++ b/api/status_test.go @@ -1,8 +1,9 @@ package api import ( + "fmt" "io/ioutil" - "os/user" + "os" "runtime" "testing" ) @@ -29,25 +30,20 @@ func TestStatusLeaderUnix(t *testing.T) { tempdir, err := ioutil.TempDir("", "consul-test-") if err != nil { - t.Fatal("Could not create a working directory") + t.Fatalf("err: %s", err) } - - socket := "unix://" + tempdir + "/unix-http-test.sock" + defer os.RemoveAll(tempdir) + socket := fmt.Sprintf("unix://%s/test.sock", tempdir) clientConfig := func(c *Config) { c.Address = socket } serverConfig := func(c *testServerConfig) { - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user") - } - if c.Addresses == nil { c.Addresses = &testAddressConfig{} } - c.Addresses.HTTP = socket + ";" + user.Uid + ";" + user.Gid + ";640" + c.Addresses.HTTP = socket } c, s := makeClientWithConfig(t, clientConfig, serverConfig) diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 91cb5c1ef5..d7b7d91539 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" "os" - "os/user" "path/filepath" "reflect" "runtime" @@ -125,7 +124,7 @@ func TestAgentStartStop(t *testing.T) { } } -func TestAgent_RPCPingTCP(t *testing.T) { +func TestAgent_RPCPing(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) defer agent.Shutdown() @@ -136,35 +135,6 @@ func TestAgent_RPCPingTCP(t *testing.T) { } } -func TestAgent_RPCPingUnix(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - - nextConf := nextConfig() - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory") - } - - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user") - } - - nextConf.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" - - dir, agent := makeAgent(t, nextConf) - defer os.RemoveAll(dir) - defer agent.Shutdown() - - var out struct{} - if err := agent.RPC("Status.Ping", struct{}{}, &out); err != nil { - t.Fatalf("err: %v", err) - } -} - func TestAgent_AddService(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) diff --git a/command/agent/command.go b/command/agent/command.go index b9a82e19ae..a677b03b9c 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -295,9 +295,15 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - if _, ok := rpcAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(rpcAddr.String()) + if path, ok := unixSocketAddr(config.Addresses.RPC); ok { + // Remove the socket if it exists, or we'll get a bind error. This + // is necessary to avoid situations where Consul cannot start if the + // socket file exists in case of unexpected termination. + if _, err := os.Stat(path); err == nil { + if err := os.Remove(path); err != nil { + return err + } + } } rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) @@ -307,14 +313,6 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - if _, ok := rpcAddr.(*net.UnixAddr); ok { - if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil { - agent.Shutdown() - c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err)) - return err - } - } - // Start the IPC layer c.Ui.Output("Starting Consul agent RPC...") c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) diff --git a/command/agent/config.go b/command/agent/config.go index 74ed4c3ad9..92b9e64d37 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -7,11 +7,8 @@ import ( "io" "net" "os" - "os/user" "path/filepath" - "regexp" "sort" - "strconv" "strings" "time" @@ -348,89 +345,13 @@ type Config struct { WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` } -// UnixSocket contains the parameters for a Unix socket interface -type UnixSocket struct { - // Path to the socket on-disk - Path string - - // uid of the owner of the socket - Uid int - - // gid of the group of the socket - Gid int - - // Permissions for the socket file - Permissions os.FileMode -} - -func populateUnixSocket(addr string) (*UnixSocket, error) { +// unixSocketAddr tests if a given address describes a domain socket, +// and returns the relevant path part of the string if it is. +func unixSocketAddr(addr string) (string, bool) { if !strings.HasPrefix(addr, "unix://") { - return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) - } - - splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";") - if len(splitAddr) != 4 { - return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) - } - - ret := &UnixSocket{Path: splitAddr[0]} - - var userVal *user.User - var err error - - regex := regexp.MustCompile("[\\d]+") - if regex.MatchString(splitAddr[1]) { - userVal, err = user.LookupId(splitAddr[1]) - } else { - userVal, err = user.Lookup(splitAddr[1]) - } - if err != nil { - return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1]) - } - - if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil { - return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid) - } else { - ret.Uid = int(uid64) - } - - // Go doesn't currently have a way to look up gid from group name, - // so require a numeric gid; see - // https://codereview.appspot.com/101310044 - if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil { - return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2]) - } else { - ret.Gid = int(gid64) - } - - if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil { - return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3]) - } else { - if mode > 0777 { - return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777") - } else { - ret.Permissions = os.FileMode(mode) - } - } - - return ret, nil -} - -func adjustUnixSocketPermissions(addr string) error { - sock, err := populateUnixSocket(addr) - if err != nil { - return err + return "", false } - - if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil { - return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err) - } - - if err = os.Chmod(sock.Path, sock.Permissions); err != nil { - return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err) - } - - return nil + return strings.TrimPrefix(addr, "unix://"), true } type dirEnts []os.FileInfo @@ -485,31 +406,14 @@ func (c *Config) ClientListener(override string, port int) (net.Addr, error) { addr = c.ClientAddr } - switch { - case strings.HasPrefix(addr, "unix://"): - sock, err := populateUnixSocket(addr) - if err != nil { - return nil, err - } - - return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil - - default: - ip := net.ParseIP(addr) - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP: %v", addr) - } - - if ip.IsUnspecified() { - ip = net.ParseIP("127.0.0.1") - } - - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP 127.0.0.1") - } - - return &net.TCPAddr{IP: ip, Port: port}, nil + if path, ok := unixSocketAddr(addr); ok { + return &net.UnixAddr{Name: path, Net: "unix"}, nil + } + ip := net.ParseIP(addr) + if ip == nil { + return nil, fmt.Errorf("Failed to parse IP: %v", addr) } + return &net.TCPAddr{IP: ip, Port: port}, nil } // DecodeConfig reads the configuration from the given reader in JSON diff --git a/command/agent/config_test.go b/command/agent/config_test.go index f10c5b7242..fa7bf6f274 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -4,12 +4,9 @@ import ( "bytes" "encoding/base64" "io/ioutil" - "net" "os" - "os/user" "path/filepath" "reflect" - "runtime" "strings" "testing" "time" @@ -1073,107 +1070,13 @@ func TestReadConfigPaths_dir(t *testing.T) { } func TestUnixSockets(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() + path1, ok := unixSocketAddr("unix:///path/to/socket") + if !ok || path1 != "/path/to/socket" { + t.Fatalf("bad: %v %v", ok, path1) } - usr, err := user.Current() - if err != nil { - t.Fatal("Could not get current user: ", err) - } - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory: ", err) - } - - type SocketTestData struct { - Path string - Uid string - Gid string - Mode string - } - - testUnixSocketPopulation := func(s SocketTestData) (*UnixSocket, error) { - return populateUnixSocket("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) - } - - testUnixSocketPermissions := func(s SocketTestData) error { - return adjustUnixSocketPermissions("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) - } - - _, err = populateUnixSocket("tcp://abc123") - if err == nil { - t.Fatal("Should have rejected invalid scheme") - } - - _, err = populateUnixSocket("unix://x;y;z") - if err == nil { - t.Fatal("Should have rejected invalid number of parameters in Unix socket definition") - } - - std := SocketTestData{ - Path: tempdir + "/unix-config-test.sock", - Uid: usr.Uid, - Gid: usr.Gid, - Mode: "640", - } - - std.Uid = "orasdfdsnfoinweroiu" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid username") - } - - std.Uid = usr.Username - std.Gid = "foinfphawepofhewof" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid group (a name, must be gid)") - } - - std.Gid = usr.Gid - std.Mode = "999" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid socket mode") - } - - std.Uid = usr.Username - std.Mode = "640" - _, err = testUnixSocketPopulation(std) - if err != nil { - t.Fatal("Unix socket test failed (using username): ", err) - } - - std.Uid = usr.Uid - sock, err := testUnixSocketPopulation(std) - if err != nil { - t.Fatal("Unix socket test failed (using uid): ", err) - } - - addr := &net.UnixAddr{Name: sock.Path, Net: "unix"} - _, err = net.Listen(addr.Network(), addr.String()) - if err != nil { - t.Fatal("Error creating socket for futher tests: ", err) - } - - std.Uid = "-999999" - err = testUnixSocketPermissions(std) - if err == nil { - t.Fatal("Did not error on invalid uid") - } - - std.Uid = usr.Uid - std.Gid = "-999999" - err = testUnixSocketPermissions(std) - if err == nil { - t.Fatal("Did not error on invalid uid") - } - - std.Gid = usr.Gid - err = testUnixSocketPermissions(std) - if err != nil { - t.Fatal("Adjusting socket permissions failed: ", err) + path2, ok := unixSocketAddr("notunix://blah") + if ok || path2 != "" { + t.Fatalf("bad: %v %v", ok, path2) } } diff --git a/command/agent/http.go b/command/agent/http.go index d480de816f..2d26e60b1e 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -59,9 +59,13 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, err } - if _, ok := httpAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(httpAddr.String()) + if path, ok := unixSocketAddr(config.Addresses.HTTPS); ok { + // See command/agent/config.go + if _, err := os.Stat(path); err == nil { + if err := os.Remove(path); err != nil { + return nil, err + } + } } ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) @@ -69,18 +73,10 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - switch httpAddr.(type) { - case *net.UnixAddr: - if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil { - return nil, err - } + if _, ok := unixSocketAddr(config.Addresses.HTTPS); ok { list = tls.NewListener(ln, tlsConfig) - - case *net.TCPAddr: + } else { list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) - - default: - return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) } // Create the mux @@ -108,9 +104,13 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) } - if _, ok := httpAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(httpAddr.String()) + if path, ok := unixSocketAddr(config.Addresses.HTTP); ok { + // See command/agent/config.go + if _, err := os.Stat(path); err == nil { + if err := os.Remove(path); err != nil { + return nil, err + } + } } ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) @@ -118,18 +118,10 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - switch httpAddr.(type) { - case *net.UnixAddr: - if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil { - return nil, err - } + if _, ok := unixSocketAddr(config.Addresses.HTTP); ok { list = ln - - case *net.TCPAddr: + } else { list = tcpKeepAliveListener{ln.(*net.TCPListener)} - - default: - return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) } // Create the mux diff --git a/command/agent/rpc_client.go b/command/agent/rpc_client.go index 1490674f14..cbc9689cfb 100644 --- a/command/agent/rpc_client.go +++ b/command/agent/rpc_client.go @@ -81,24 +81,19 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error { // NewRPCClient is used to create a new RPC client given the address. // This will properly dial, handshake, and start listening func NewRPCClient(addr string) (*RPCClient, error) { - sanedAddr := os.Getenv("CONSUL_RPC_ADDR") - if len(sanedAddr) == 0 { - sanedAddr = addr - } - - mode := "tcp" + var conn net.Conn + var err error - if strings.HasPrefix(sanedAddr, "unix://") { - sanedAddr = strings.TrimPrefix(sanedAddr, "unix://") + if envAddr := os.Getenv("CONSUL_RPC_ADDR"); envAddr != "" { + addr = envAddr } - if strings.HasPrefix(sanedAddr, "/") { + // Try to dial to agent + mode := "tcp" + if strings.HasPrefix(addr, "/") { mode = "unix" } - - // Try to dial to agent - conn, err := net.Dial(mode, sanedAddr) - if err != nil { + if conn, err = net.Dial(mode, addr); err != nil { return nil, err } diff --git a/command/agent/rpc_client_test.go b/command/agent/rpc_client_test.go index 2d8dfc9c08..8516943a18 100644 --- a/command/agent/rpc_client_test.go +++ b/command/agent/rpc_client_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net" "os" - "os/user" "runtime" "strings" "testing" @@ -223,19 +222,13 @@ func TestRPCClientStatsUnix(t *testing.T) { tempdir, err := ioutil.TempDir("", "consul-test-") if err != nil { - t.Fatal("Could not create a working directory: ", err) - } - - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user: ", err) - } - - cb := func(c *Config) { - c.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" + t.Fatalf("err: %s", err) } - p1 := testRPCClientWithConfig(t, cb) + p1 := testRPCClientWithConfig(t, func(c *Config) { + c.Addresses.RPC = fmt.Sprintf("unix://%s/test.sock", tempdir) + }) + defer p1.Close() stats, err := p1.client.Stats() if err != nil {