mirror of https://github.com/hashicorp/consul
Daniel Nephin
5 years ago
committed by
GitHub
3 changed files with 198 additions and 81 deletions
@ -0,0 +1,95 @@
|
||||
package envoy |
||||
|
||||
import ( |
||||
"flag" |
||||
"fmt" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/hashicorp/consul/api" |
||||
"github.com/hashicorp/go-sockaddr/template" |
||||
) |
||||
|
||||
const defaultMeshGatewayPort int = 443 |
||||
|
||||
// ServiceAddressValue implements a flag.Value that may be used to parse an
|
||||
// addr:port string into an api.ServiceAddress.
|
||||
type ServiceAddressValue struct { |
||||
value api.ServiceAddress |
||||
} |
||||
|
||||
func (s *ServiceAddressValue) String() string { |
||||
if s == nil { |
||||
return fmt.Sprintf(":%d", defaultMeshGatewayPort) |
||||
} |
||||
return fmt.Sprintf("%v:%d", s.value.Address, s.value.Port) |
||||
} |
||||
|
||||
func (s *ServiceAddressValue) Value() api.ServiceAddress { |
||||
if s == nil || s.value.Port == 0 && s.value.Address == "" { |
||||
return api.ServiceAddress{Port: defaultMeshGatewayPort} |
||||
} |
||||
return s.value |
||||
} |
||||
|
||||
func (s *ServiceAddressValue) Set(raw string) error { |
||||
var err error |
||||
s.value, err = parseAddress(raw) |
||||
return err |
||||
} |
||||
|
||||
func parseAddress(raw string) (api.ServiceAddress, error) { |
||||
result := api.ServiceAddress{} |
||||
x, err := template.Parse(raw) |
||||
if err != nil { |
||||
return result, fmt.Errorf("Error parsing address %q: %v", raw, err) |
||||
} |
||||
|
||||
addr, portStr, err := net.SplitHostPort(x) |
||||
if err != nil { |
||||
return result, fmt.Errorf("Error parsing address %q: %v", x, err) |
||||
} |
||||
|
||||
port := defaultMeshGatewayPort |
||||
if portStr != "" { |
||||
port, err = strconv.Atoi(portStr) |
||||
if err != nil { |
||||
return result, fmt.Errorf("Error parsing port %q: %v", portStr, err) |
||||
} |
||||
} |
||||
|
||||
result.Address = addr |
||||
result.Port = port |
||||
return result, nil |
||||
} |
||||
|
||||
var _ flag.Value = (*ServiceAddressValue)(nil) |
||||
|
||||
type ServiceAddressMapValue struct { |
||||
value map[string]api.ServiceAddress |
||||
} |
||||
|
||||
func (s *ServiceAddressMapValue) String() string { |
||||
buf := new(strings.Builder) |
||||
for k, v := range s.value { |
||||
buf.WriteString(fmt.Sprintf("%v=%v:%d,", k, v.Address, v.Port)) |
||||
} |
||||
return buf.String() |
||||
} |
||||
|
||||
func (s *ServiceAddressMapValue) Set(raw string) error { |
||||
if s.value == nil { |
||||
s.value = make(map[string]api.ServiceAddress) |
||||
} |
||||
idx := strings.Index(raw, "=") |
||||
if idx == -1 { |
||||
return fmt.Errorf(`Missing "=" in argument: %s`, raw) |
||||
} |
||||
key, value := raw[0:idx], raw[idx+1:] |
||||
var err error |
||||
s.value[key], err = parseAddress(value) |
||||
return err |
||||
} |
||||
|
||||
var _ flag.Value = (*ServiceAddressMapValue)(nil) |
@ -0,0 +1,80 @@
|
||||
package envoy |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/hashicorp/consul/api" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestServiceAddressValue_Value(t *testing.T) { |
||||
t.Run("nil receiver", func(t *testing.T) { |
||||
var addr *ServiceAddressValue |
||||
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort}) |
||||
}) |
||||
|
||||
t.Run("default value", func(t *testing.T) { |
||||
addr := &ServiceAddressValue{} |
||||
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort}) |
||||
}) |
||||
|
||||
t.Run("set value", func(t *testing.T) { |
||||
addr := &ServiceAddressValue{} |
||||
require.NoError(t, addr.Set("localhost:3333")) |
||||
require.Equal(t, addr.Value(), api.ServiceAddress{ |
||||
Address: "localhost", |
||||
Port: 3333, |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
func TestServiceAddressValue_Set(t *testing.T) { |
||||
var testcases = []struct { |
||||
name string |
||||
input string |
||||
expectedErr string |
||||
expectedValue api.ServiceAddress |
||||
}{ |
||||
{ |
||||
name: "default port", |
||||
input: "8.8.8.8:", |
||||
expectedValue: api.ServiceAddress{ |
||||
Address: "8.8.8.8", |
||||
Port: defaultMeshGatewayPort, |
||||
}, |
||||
}, |
||||
{ |
||||
name: "valid address", |
||||
input: "8.8.8.8:1234", |
||||
expectedValue: api.ServiceAddress{Address: "8.8.8.8", Port: 1234}, |
||||
}, |
||||
{ |
||||
name: "invalid address", |
||||
input: "not-an-address", |
||||
expectedErr: "missing port in address", |
||||
}, |
||||
{ |
||||
name: "invalid port", |
||||
input: "localhost:notaport", |
||||
expectedErr: `Error parsing port "notaport"`, |
||||
}, |
||||
{ |
||||
name: "invalid address format", |
||||
input: "too:many:colons", |
||||
expectedErr: "address too:many:colons: too many colons", |
||||
}, |
||||
} |
||||
for _, tc := range testcases { |
||||
t.Run(tc.name, func(t *testing.T) { |
||||
addr := &ServiceAddressValue{} |
||||
err := addr.Set(tc.input) |
||||
if tc.expectedErr != "" { |
||||
require.Error(t, err) |
||||
require.Contains(t, err.Error(), tc.expectedErr) |
||||
return |
||||
} |
||||
|
||||
require.Equal(t, addr.Value(), tc.expectedValue) |
||||
}) |
||||
} |
||||
} |
Loading…
Reference in new issue