diff --git a/command/rpc.go b/command/rpc.go index 84d41e822e..6f1562d216 100644 --- a/command/rpc.go +++ b/command/rpc.go @@ -8,9 +8,15 @@ import ( "github.com/hashicorp/consul/command/agent" ) -// RPCAddrEnvName defines an environment variable name which sets -// an RPC address if there is no -rpc-addr specified. -const RPCAddrEnvName = "CONSUL_RPC_ADDR" +const ( + // RPCAddrEnvName defines an environment variable name which sets + // an RPC address if there is no -rpc-addr specified. + RPCAddrEnvName = "CONSUL_RPC_ADDR" + + // HTTPAddrEnvName defines an environment variable name which sets + // the HTTP address if there is no -http-addr specified. + HTTPAddrEnvName = "CONSUL_HTTP_ADDR" +) // RPCAddrFlag returns a pointer to a string that will be populated // when the given flagset is parsed with the RPC address of the Consul. @@ -31,7 +37,11 @@ func RPCClient(addr string) (*agent.RPCClient, error) { // HTTPAddrFlag returns a pointer to a string that will be populated // when the given flagset is parsed with the HTTP address of the Consul. func HTTPAddrFlag(f *flag.FlagSet) *string { - return f.String("http-addr", "127.0.0.1:8500", + defaultHTTPAddr := os.Getenv(HTTPAddrEnvName) + if defaultHTTPAddr == "" { + defaultHTTPAddr = "127.0.0.1:8500" + } + return f.String("http-addr", defaultHTTPAddr, "HTTP address of the Consul agent") } @@ -43,7 +53,7 @@ func HTTPClient(addr string) (*consulapi.Client, error) { // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { conf := consulapi.DefaultConfig() - if envAddr := os.Getenv("CONSUL_HTTP_ADDR"); envAddr != "" { + if envAddr := os.Getenv(HTTPAddrEnvName); addr == "" && envAddr != "" { addr = envAddr } conf.Address = addr diff --git a/command/rpc_test.go b/command/rpc_test.go index 01d6259487..3b5d64c16e 100644 --- a/command/rpc_test.go +++ b/command/rpc_test.go @@ -6,54 +6,82 @@ import ( "testing" ) -const defaultRPC = "127.0.0.1:8400" +const ( + defaultRPC = "127.0.0.1:8400" + defaultHTTP = "127.0.0.1:8500" +) + +type flagFunc func(f *flag.FlagSet) *string -func getParsedRPC(t *testing.T, cliRPC, envRPC string) string { +func getParsedAddr(t *testing.T, addrType, cliVal, envVal string) string { + var cliFlag, envVar string + var fn flagFunc args := []string{} - if cliRPC != "" { - args = append(args, "-rpc-addr="+cliRPC) + switch addrType { + case "rpc": + fn = RPCAddrFlag + envVar = RPCAddrEnvName + cliFlag = "-rpc-addr" + case "http": + fn = HTTPAddrFlag + envVar = HTTPAddrEnvName + cliFlag = "-http-addr" + default: + t.Fatalf("unknown address type %s", addrType) + } + + if cliVal != "" { + args = append(args, cliFlag+"="+cliVal) } os.Clearenv() - if envRPC != "" { - os.Setenv(RPCAddrEnvName, envRPC) + if envVal != "" { + os.Setenv(envVar, envVal) } - cmdFlags := flag.NewFlagSet("rpc", flag.ContinueOnError) - rpc := RPCAddrFlag(cmdFlags) + cmdFlags := flag.NewFlagSet(addrType, flag.ContinueOnError) + result := fn(cmdFlags) if err := cmdFlags.Parse(args); err != nil { t.Fatal("Parse error", err) } - return *rpc + return *result } -func TestRPCAddrFlag_default(t *testing.T) { - rpc := getParsedRPC(t, "", "") +func TestAddrFlag_default(t *testing.T) { + for a, def := range map[string]string{ + "rpc": defaultRPC, + "http": defaultHTTP, + } { + res := getParsedAddr(t, a, "", "") - if rpc != defaultRPC { - t.Fatalf("Expected rpc addr: %s, got: %s", defaultRPC, rpc) + if res != def { + t.Fatalf("Expected %s addr: %s, got: %s", def, res) + } } } -func TestRPCAddrFlag_onlyEnv(t *testing.T) { - envRPC := "4.4.4.4:8400" - rpc := getParsedRPC(t, "", envRPC) +func TestAddrFlag_onlyEnv(t *testing.T) { + envAddr := "4.4.4.4:1234" + for _, a := range []string{"rpc", "http"} { + res := getParsedAddr(t, a, "", envAddr) - if rpc != envRPC { - t.Fatalf("Expected rpc addr: %s, got: %s", envRPC, rpc) + if res != envAddr { + t.Fatalf("Expected %s addr: %s, got: %s", a, envAddr, res) + } } } -func TestRPCAddrFlag_precedence(t *testing.T) { - cliRPC := "8.8.8.8:8400" - envRPC := "4.4.4.4:8400" - - rpc := getParsedRPC(t, cliRPC, envRPC) +func TestAddrFlag_precedence(t *testing.T) { + cliAddr := "8.8.8.8:8400" + envAddr := "4.4.4.4:8400" + for _, a := range []string{"rpc", "http"} { + res := getParsedAddr(t, a, cliAddr, envAddr) - if rpc != cliRPC { - t.Fatalf("Expected rpc addr: %s, got: %s", cliRPC, rpc) + if res != cliAddr { + t.Fatalf("Expected %s addr: %s, got: %s", a, cliAddr, res) + } } }