mirror of https://github.com/hashicorp/consul
command/envoy: Refactor flag parsing/validation (#7504)
parent
33c7894123
commit
e5d6273a48
@ -0,0 +1,95 @@
|
|||||||
|
package envoy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/api"
|
||||||
|
"github.com/hashicorp/go-sockaddr/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultMeshGatewayPort int = 443
|
||||||
|
|
||||||
|
// ServiceAddressValue implements a flag.Value that may be used to parse an
|
||||||
|
// addr:port string into an api.ServiceAddress.
|
||||||
|
type ServiceAddressValue struct {
|
||||||
|
value api.ServiceAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceAddressValue) String() string {
|
||||||
|
if s == nil {
|
||||||
|
return fmt.Sprintf(":%d", defaultMeshGatewayPort)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v:%d", s.value.Address, s.value.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceAddressValue) Value() api.ServiceAddress {
|
||||||
|
if s == nil || s.value.Port == 0 && s.value.Address == "" {
|
||||||
|
return api.ServiceAddress{Port: defaultMeshGatewayPort}
|
||||||
|
}
|
||||||
|
return s.value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceAddressValue) Set(raw string) error {
|
||||||
|
var err error
|
||||||
|
s.value, err = parseAddress(raw)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAddress(raw string) (api.ServiceAddress, error) {
|
||||||
|
result := api.ServiceAddress{}
|
||||||
|
x, err := template.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("Error parsing address %q: %v", raw, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, portStr, err := net.SplitHostPort(x)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("Error parsing address %q: %v", x, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := defaultMeshGatewayPort
|
||||||
|
if portStr != "" {
|
||||||
|
port, err = strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("Error parsing port %q: %v", portStr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Address = addr
|
||||||
|
result.Port = port
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ flag.Value = (*ServiceAddressValue)(nil)
|
||||||
|
|
||||||
|
type ServiceAddressMapValue struct {
|
||||||
|
value map[string]api.ServiceAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceAddressMapValue) String() string {
|
||||||
|
buf := new(strings.Builder)
|
||||||
|
for k, v := range s.value {
|
||||||
|
buf.WriteString(fmt.Sprintf("%v=%v:%d,", k, v.Address, v.Port))
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceAddressMapValue) Set(raw string) error {
|
||||||
|
if s.value == nil {
|
||||||
|
s.value = make(map[string]api.ServiceAddress)
|
||||||
|
}
|
||||||
|
idx := strings.Index(raw, "=")
|
||||||
|
if idx == -1 {
|
||||||
|
return fmt.Errorf(`Missing "=" in argument: %s`, raw)
|
||||||
|
}
|
||||||
|
key, value := raw[0:idx], raw[idx+1:]
|
||||||
|
var err error
|
||||||
|
s.value[key], err = parseAddress(value)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ flag.Value = (*ServiceAddressMapValue)(nil)
|
@ -0,0 +1,80 @@
|
|||||||
|
package envoy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceAddressValue_Value(t *testing.T) {
|
||||||
|
t.Run("nil receiver", func(t *testing.T) {
|
||||||
|
var addr *ServiceAddressValue
|
||||||
|
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default value", func(t *testing.T) {
|
||||||
|
addr := &ServiceAddressValue{}
|
||||||
|
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set value", func(t *testing.T) {
|
||||||
|
addr := &ServiceAddressValue{}
|
||||||
|
require.NoError(t, addr.Set("localhost:3333"))
|
||||||
|
require.Equal(t, addr.Value(), api.ServiceAddress{
|
||||||
|
Address: "localhost",
|
||||||
|
Port: 3333,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServiceAddressValue_Set(t *testing.T) {
|
||||||
|
var testcases = []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedErr string
|
||||||
|
expectedValue api.ServiceAddress
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default port",
|
||||||
|
input: "8.8.8.8:",
|
||||||
|
expectedValue: api.ServiceAddress{
|
||||||
|
Address: "8.8.8.8",
|
||||||
|
Port: defaultMeshGatewayPort,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid address",
|
||||||
|
input: "8.8.8.8:1234",
|
||||||
|
expectedValue: api.ServiceAddress{Address: "8.8.8.8", Port: 1234},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid address",
|
||||||
|
input: "not-an-address",
|
||||||
|
expectedErr: "missing port in address",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid port",
|
||||||
|
input: "localhost:notaport",
|
||||||
|
expectedErr: `Error parsing port "notaport"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid address format",
|
||||||
|
input: "too:many:colons",
|
||||||
|
expectedErr: "address too:many:colons: too many colons",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
addr := &ServiceAddressValue{}
|
||||||
|
err := addr.Set(tc.input)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.expectedErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, addr.Value(), tc.expectedValue)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue