diff --git a/agent/consul/auto_encrypt.go b/agent/consul/auto_encrypt.go index 3acf15a614..21c5c149cc 100644 --- a/agent/consul/auto_encrypt.go +++ b/agent/consul/auto_encrypt.go @@ -4,7 +4,6 @@ import ( "fmt" "log" "net" - "strconv" "strings" "time" @@ -19,7 +18,7 @@ const ( retryJitterWindow = 30 * time.Second ) -func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) { +func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) { errFn := func(err error) (*structs.SignedResponse, string, error) { return nil, "", err } @@ -82,7 +81,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke // Translate host to net.TCPAddr to make life easier for // RPCInsecure. for _, s := range servers { - ips, port, err := resolveAddr(s, defaultPort, c.logger) + ips, err := resolveAddr(s, c.logger) if err != nil { c.logger.Printf("[WARN] agent: AutoEncrypt resolveAddr failed: %v", err) continue @@ -114,29 +113,26 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke } } -// resolveAddr is used to resolve the host into IPs, port, and error. -// If no port is given, use the default -func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, int, error) { - host, splitPort, err := net.SplitHostPort(rawHost) - if err != nil && err.Error() != fmt.Sprintf("address %s: missing port in address", rawHost) { - return nil, defaultPort, err - } +func missingPortError(host string, err error) bool { + return err != nil && err.Error() == fmt.Sprintf("address %s: missing port in address", host) +} - // SplitHostPort returns empty host and splitPort on missingPort err, - // so those are set to defaults - var port int +// resolveAddr is used to resolve the host into IPs and error. +func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) { + host, _, err := net.SplitHostPort(rawHost) if err != nil { - host = rawHost - port = defaultPort - } else { - port, err = strconv.Atoi(splitPort) - if err != nil { - port = defaultPort + // In case we encounter this error, we proceed with the + // rawHost. This is fine since -start-join and -retry-join + // take only hosts anyways and this is an expected case. + if missingPortError(rawHost, err) { + host = rawHost + } else { + return nil, err } } if ip := net.ParseIP(host); ip != nil { - return []net.IP{ip}, port, nil + return []net.IP{ip}, nil } // First try TCP so we have the best chance for the largest list of @@ -145,7 +141,7 @@ func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, if ips, err := tcpLookupIP(host, logger); err != nil { logger.Printf("[DEBUG] agent: TCP-first lookup failed for '%s', falling back to UDP: %s", host, err) } else if len(ips) > 0 { - return ips, port, nil + return ips, nil } // If TCP didn't yield anything then use the normal Go resolver which @@ -153,9 +149,9 @@ func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, // indicates it was truncated. ips, err := net.LookupIP(host) if err != nil { - return nil, port, err + return nil, err } - return ips, port, nil + return ips, nil } // tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host. diff --git a/agent/consul/auto_encrypt_test.go b/agent/consul/auto_encrypt_test.go index 2a4daa012b..d27b2f9489 100644 --- a/agent/consul/auto_encrypt_test.go +++ b/agent/consul/auto_encrypt_test.go @@ -10,71 +10,70 @@ import ( func TestAutoEncrypt_resolveAddr(t *testing.T) { type args struct { - rawHost string - defaultPort int - logger *log.Logger + rawHost string + logger *log.Logger } tests := []struct { name string args args ips []net.IP - port int wantErr bool }{ { name: "host without port", args: args{ "127.0.0.1", - 8300, log.New(os.Stderr, "", log.LstdFlags), }, ips: []net.IP{net.IPv4(127, 0, 0, 1)}, - port: 8300, wantErr: false, }, { name: "host with port", args: args{ "127.0.0.1:1234", - 8300, log.New(os.Stderr, "", log.LstdFlags), }, ips: []net.IP{net.IPv4(127, 0, 0, 1)}, - port: 1234, wantErr: false, }, { name: "host with broken port", args: args{ "127.0.0.1:xyz", - 8300, log.New(os.Stderr, "", log.LstdFlags), }, ips: []net.IP{net.IPv4(127, 0, 0, 1)}, - port: 8300, wantErr: false, }, { name: "not an address", args: args{ "abc", - 8300, log.New(os.Stderr, "", log.LstdFlags), }, ips: nil, - port: 8300, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ips, port, err := resolveAddr(tt.args.rawHost, tt.args.defaultPort, tt.args.logger) + ips, err := resolveAddr(tt.args.rawHost, tt.args.logger) if (err != nil) != tt.wantErr { t.Errorf("resolveAddr error: %v, wantErr: %v", err, tt.wantErr) return } require.Equal(t, tt.ips, ips) - require.Equal(t, tt.port, port) }) } } + +func TestAutoEncrypt_missingPortError(t *testing.T) { + host := "127.0.0.1" + _, _, err := net.SplitHostPort(host) + require.True(t, missingPortError(host, err)) + + host = "127.0.0.1:1234" + _, _, err = net.SplitHostPort(host) + require.False(t, missingPortError(host, err)) +}