diff --git a/command/connect/envoy/envoy.go b/command/connect/envoy/envoy.go index 000fa713f1..0ca6429774 100644 --- a/command/connect/envoy/envoy.go +++ b/command/connect/envoy/envoy.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/mitchellh/cli" "github.com/mitchellh/mapstructure" "github.com/hashicorp/consul/agent/structs" @@ -19,9 +20,6 @@ import ( proxyCmd "github.com/hashicorp/consul/command/connect/proxy" "github.com/hashicorp/consul/command/flags" "github.com/hashicorp/consul/ipaddr" - "github.com/hashicorp/go-sockaddr/template" - - "github.com/mitchellh/cli" ) func New(ui cli.Ui) *cmd { @@ -60,10 +58,10 @@ type cmd struct { // mesh gateway registration information register bool - address string - wanAddress string + lanAddress ServiceAddressValue + wanAddress ServiceAddressValue deregAfterCritical string - bindAddresses map[string]string + bindAddresses ServiceAddressMapValue exposeServers bool meshGatewaySvcName string @@ -120,13 +118,13 @@ func (c *cmd) init() { c.flags.BoolVar(&c.register, "register", false, "Register a new Mesh Gateway service before configuring and starting Envoy") - c.flags.StringVar(&c.address, "address", "", + c.flags.Var(&c.lanAddress, "address", "LAN address to advertise in the Mesh Gateway service registration") - c.flags.StringVar(&c.wanAddress, "wan-address", "", + c.flags.Var(&c.wanAddress, "wan-address", "WAN address to advertise in the Mesh Gateway service registration") - c.flags.Var((*flags.FlagMapValue)(&c.bindAddresses), "bind-address", "Bind "+ + c.flags.Var(&c.bindAddresses, "bind-address", "Bind "+ "address to use instead of the default binding rules given as `=:` "+ "pairs. This flag may be specified multiple times to add multiple bind addresses.") @@ -145,38 +143,6 @@ func (c *cmd) init() { c.help = flags.Usage(help, c.flags) } -const ( - DefaultMeshGatewayPort int = 443 -) - -func parseAddress(addrStr string) (string, int, error) { - if addrStr == "" { - // defaulting the port to 443 - return "", DefaultMeshGatewayPort, nil - } - - x, err := template.Parse(addrStr) - if err != nil { - return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing address %q: %v", addrStr, err) - } - - addr, portStr, err := net.SplitHostPort(x) - if err != nil { - return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing address %q: %v", x, err) - } - - port := DefaultMeshGatewayPort - - if portStr != "" { - port, err = strconv.Atoi(portStr) - if err != nil { - return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing port %q: %v", portStr, err) - } - } - - return addr, port, nil -} - // canBindInternal is here mainly so we can unit test this with a constant net.Addr list func canBindInternal(addr string, ifAddrs []net.Addr) bool { if addr == "" { @@ -206,13 +172,13 @@ func canBindInternal(addr string, ifAddrs []net.Addr) bool { return false } -func canBind(addr string) bool { +func canBind(addr api.ServiceAddress) bool { ifAddrs, err := net.InterfaceAddrs() if err != nil { return false } - return canBindInternal(addr, ifAddrs) + return canBindInternal(addr.Address, ifAddrs) } func (c *cmd) Run(args []string) int { @@ -246,30 +212,18 @@ func (c *cmd) Run(args []string) int { return 1 } - lanAddr, lanPort, err := parseAddress(c.address) - if err != nil { - c.UI.Error(fmt.Sprintf("Failed to parse the -address parameter: %v", err)) - return 1 - } - taggedAddrs := make(map[string]api.ServiceAddress) - - if lanAddr != "" { - taggedAddrs[structs.TaggedAddressLAN] = api.ServiceAddress{Address: lanAddr, Port: lanPort} + lanAddr := c.lanAddress.Value() + if lanAddr.Address != "" { + taggedAddrs[structs.TaggedAddressLAN] = lanAddr } - wanAddr := "" - wanPort := lanPort - if c.wanAddress != "" { - wanAddr, wanPort, err = parseAddress(c.wanAddress) - if err != nil { - c.UI.Error(fmt.Sprintf("Failed to parse the -wan-address parameter: %v", err)) - return 1 - } - taggedAddrs[structs.TaggedAddressWAN] = api.ServiceAddress{Address: wanAddr, Port: wanPort} + wanAddr := c.wanAddress.Value() + if wanAddr.Address != "" { + taggedAddrs[structs.TaggedAddressWAN] = wanAddr } - tcpCheckAddr := lanAddr + tcpCheckAddr := lanAddr.Address if tcpCheckAddr == "" { // fallback to localhost as the gateway has to reside in the same network namespace // as the agent @@ -278,24 +232,12 @@ func (c *cmd) Run(args []string) int { var proxyConf *api.AgentServiceConnectProxyConfig - if len(c.bindAddresses) > 0 { + if len(c.bindAddresses.value) > 0 { // override all default binding rules and just bind to the user-supplied addresses - bindAddresses := make(map[string]api.ServiceAddress) - - for addrName, addrStr := range c.bindAddresses { - addr, port, err := parseAddress(addrStr) - if err != nil { - c.UI.Error(fmt.Sprintf("Failed to parse the bind address: %s=%s: %v", addrName, addrStr, err)) - return 1 - } - - bindAddresses[addrName] = api.ServiceAddress{Address: addr, Port: port} - } - proxyConf = &api.AgentServiceConnectProxyConfig{ Config: map[string]interface{}{ "envoy_mesh_gateway_no_default_bind": true, - "envoy_mesh_gateway_bind_addresses": bindAddresses, + "envoy_mesh_gateway_bind_addresses": c.bindAddresses.value, }, } } else if canBind(lanAddr) && canBind(wanAddr) { @@ -307,8 +249,8 @@ func (c *cmd) Run(args []string) int { "envoy_mesh_gateway_bind_tagged_addresses": true, }, } - } else if !canBind(lanAddr) && lanAddr != "" { - c.UI.Error(fmt.Sprintf("The LAN address %q will not be bindable. Either set a bindable address or override the bind addresses with -bind-address", lanAddr)) + } else if !canBind(lanAddr) && lanAddr.Address != "" { + c.UI.Error(fmt.Sprintf("The LAN address %q will not be bindable. Either set a bindable address or override the bind addresses with -bind-address", lanAddr.Address)) return 1 } @@ -320,14 +262,14 @@ func (c *cmd) Run(args []string) int { svc := api.AgentServiceRegistration{ Kind: api.ServiceKindMeshGateway, Name: c.meshGatewaySvcName, - Address: lanAddr, - Port: lanPort, + Address: lanAddr.Address, + Port: lanAddr.Port, Meta: meta, TaggedAddresses: taggedAddrs, Proxy: proxyConf, Check: &api.AgentServiceCheck{ Name: "Mesh Gateway Listening", - TCP: ipaddr.FormatAddressPort(tcpCheckAddr, lanPort), + TCP: ipaddr.FormatAddressPort(tcpCheckAddr, lanAddr.Port), Interval: "10s", DeregisterCriticalServiceAfter: c.deregAfterCritical, }, diff --git a/command/connect/envoy/flags.go b/command/connect/envoy/flags.go new file mode 100644 index 0000000000..5f6e2f4653 --- /dev/null +++ b/command/connect/envoy/flags.go @@ -0,0 +1,95 @@ +package envoy + +import ( + "flag" + "fmt" + "net" + "strconv" + "strings" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/go-sockaddr/template" +) + +const defaultMeshGatewayPort int = 443 + +// ServiceAddressValue implements a flag.Value that may be used to parse an +// addr:port string into an api.ServiceAddress. +type ServiceAddressValue struct { + value api.ServiceAddress +} + +func (s *ServiceAddressValue) String() string { + if s == nil { + return fmt.Sprintf(":%d", defaultMeshGatewayPort) + } + return fmt.Sprintf("%v:%d", s.value.Address, s.value.Port) +} + +func (s *ServiceAddressValue) Value() api.ServiceAddress { + if s == nil || s.value.Port == 0 && s.value.Address == "" { + return api.ServiceAddress{Port: defaultMeshGatewayPort} + } + return s.value +} + +func (s *ServiceAddressValue) Set(raw string) error { + var err error + s.value, err = parseAddress(raw) + return err +} + +func parseAddress(raw string) (api.ServiceAddress, error) { + result := api.ServiceAddress{} + x, err := template.Parse(raw) + if err != nil { + return result, fmt.Errorf("Error parsing address %q: %v", raw, err) + } + + addr, portStr, err := net.SplitHostPort(x) + if err != nil { + return result, fmt.Errorf("Error parsing address %q: %v", x, err) + } + + port := defaultMeshGatewayPort + if portStr != "" { + port, err = strconv.Atoi(portStr) + if err != nil { + return result, fmt.Errorf("Error parsing port %q: %v", portStr, err) + } + } + + result.Address = addr + result.Port = port + return result, nil +} + +var _ flag.Value = (*ServiceAddressValue)(nil) + +type ServiceAddressMapValue struct { + value map[string]api.ServiceAddress +} + +func (s *ServiceAddressMapValue) String() string { + buf := new(strings.Builder) + for k, v := range s.value { + buf.WriteString(fmt.Sprintf("%v=%v:%d,", k, v.Address, v.Port)) + } + return buf.String() +} + +func (s *ServiceAddressMapValue) Set(raw string) error { + if s.value == nil { + s.value = make(map[string]api.ServiceAddress) + } + idx := strings.Index(raw, "=") + if idx == -1 { + return fmt.Errorf(`Missing "=" in argument: %s`, raw) + } + key, value := raw[0:idx], raw[idx+1:] + var err error + s.value[key], err = parseAddress(value) + return err +} + +var _ flag.Value = (*ServiceAddressMapValue)(nil) diff --git a/command/connect/envoy/flags_test.go b/command/connect/envoy/flags_test.go new file mode 100644 index 0000000000..77dcf1c6c5 --- /dev/null +++ b/command/connect/envoy/flags_test.go @@ -0,0 +1,80 @@ +package envoy + +import ( + "testing" + + "github.com/hashicorp/consul/api" + "github.com/stretchr/testify/require" +) + +func TestServiceAddressValue_Value(t *testing.T) { + t.Run("nil receiver", func(t *testing.T) { + var addr *ServiceAddressValue + require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort}) + }) + + t.Run("default value", func(t *testing.T) { + addr := &ServiceAddressValue{} + require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort}) + }) + + t.Run("set value", func(t *testing.T) { + addr := &ServiceAddressValue{} + require.NoError(t, addr.Set("localhost:3333")) + require.Equal(t, addr.Value(), api.ServiceAddress{ + Address: "localhost", + Port: 3333, + }) + }) +} + +func TestServiceAddressValue_Set(t *testing.T) { + var testcases = []struct { + name string + input string + expectedErr string + expectedValue api.ServiceAddress + }{ + { + name: "default port", + input: "8.8.8.8:", + expectedValue: api.ServiceAddress{ + Address: "8.8.8.8", + Port: defaultMeshGatewayPort, + }, + }, + { + name: "valid address", + input: "8.8.8.8:1234", + expectedValue: api.ServiceAddress{Address: "8.8.8.8", Port: 1234}, + }, + { + name: "invalid address", + input: "not-an-address", + expectedErr: "missing port in address", + }, + { + name: "invalid port", + input: "localhost:notaport", + expectedErr: `Error parsing port "notaport"`, + }, + { + name: "invalid address format", + input: "too:many:colons", + expectedErr: "address too:many:colons: too many colons", + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + addr := &ServiceAddressValue{} + err := addr.Set(tc.input) + if tc.expectedErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedErr) + return + } + + require.Equal(t, addr.Value(), tc.expectedValue) + }) + } +}