diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index 9c02611c3a..5a25f891e9 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -454,7 +454,7 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re // Check the service address here and in the catalog RPC endpoint // since service registration isn't sychronous. - if args.Address == "0.0.0.0" { + if args.Address == "0.0.0.0" || args.Address == "::" || args.Address == "[::]" { resp.WriteHeader(400) fmt.Fprintf(resp, "Invalid service address") return nil, nil diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go index dc40b32d56..fd15a8ad95 100644 --- a/command/agent/agent_endpoint_test.go +++ b/command/agent/agent_endpoint_test.go @@ -1503,24 +1503,28 @@ func TestAgent_RegisterService_InvalidAddress(t *testing.T) { defer srv.Shutdown() defer srv.agent.Shutdown() - req, err := http.NewRequest("GET", "/v1/agent/service/register?token=abc123", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - args := &ServiceDefinition{ - Name: "test", - Address: "0.0.0.0", - Port: 8000, - } - req.Body = encodeReq(args) + for _, addr := range []string{"0.0.0.0", "::", "[::]"} { + t.Run("addr "+addr, func(t *testing.T) { + req, err := http.NewRequest("GET", "/v1/agent/service/register?token=abc123", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + args := &ServiceDefinition{ + Name: "test", + Address: addr, + Port: 8000, + } + req.Body = encodeReq(args) - resp := httptest.NewRecorder() - _, err = srv.AgentRegisterService(resp, req) - if got, want := resp.Code, 400; got != want { - t.Fatalf("got code %d want %d", got, want) - } - if got, want := resp.Body.String(), "Invalid service address"; got != want { - t.Fatalf("got body %q want %q", got, want) + resp := httptest.NewRecorder() + _, err = srv.AgentRegisterService(resp, req) + if got, want := resp.Code, 400; got != want { + t.Fatalf("got code %d want %d", got, want) + } + if got, want := resp.Body.String(), "Invalid service address"; got != want { + t.Fatalf("got body %q want %q", got, want) + } + }) } } diff --git a/command/agent/catalog_endpoint_test.go b/command/agent/catalog_endpoint_test.go index 6d04a94a13..f42a55dd87 100644 --- a/command/agent/catalog_endpoint_test.go +++ b/command/agent/catalog_endpoint_test.go @@ -63,25 +63,28 @@ func TestCatalogRegister_Service_InvalidAddress(t *testing.T) { testrpc.WaitForLeader(t, srv.agent.RPC, "dc1") - // Register node - req, err := http.NewRequest("GET", "/v1/catalog/register", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - args := &structs.RegisterRequest{ - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - Service: "test", - Address: "0.0.0.0", - Port: 8080, - }, - } - req.Body = encodeReq(args) + for _, addr := range []string{"0.0.0.0", "::", "[::]"} { + t.Run("addr "+addr, func(t *testing.T) { + req, err := http.NewRequest("GET", "/v1/catalog/register", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + args := &structs.RegisterRequest{ + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "test", + Address: addr, + Port: 8080, + }, + } + req.Body = encodeReq(args) - _, err = srv.CatalogRegister(nil, req) - if err == nil || err.Error() != "Invalid service address" { - t.Fatalf("err: %v", err) + _, err = srv.CatalogRegister(nil, req) + if err == nil || err.Error() != "Invalid service address" { + t.Fatalf("err: %v", err) + } + }) } } diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index 4e54245372..04b4eac146 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -54,7 +54,7 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error // Check the service address here and in the agent endpoint // since service registration isn't sychronous. - if args.Service.Address == "0.0.0.0" { + if args.Service.Address == "0.0.0.0" || args.Service.Address == "::" || args.Service.Address == "[::]" { return fmt.Errorf("Invalid service address") } diff --git a/consul/catalog_endpoint_test.go b/consul/catalog_endpoint_test.go index 9737ce2fe2..86f1622fa3 100644 --- a/consul/catalog_endpoint_test.go +++ b/consul/catalog_endpoint_test.go @@ -53,21 +53,25 @@ func TestCatalog_RegisterService_InvalidAddress(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - arg := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - Service: "db", - Address: "0.0.0.0", - Port: 8000, - }, - } - var out struct{} + for _, addr := range []string{"0.0.0.0", "::", "[::]"} { + t.Run("addr "+addr, func(t *testing.T) { + arg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "db", + Address: addr, + Port: 8000, + }, + } + var out struct{} - err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out) - if err == nil || err.Error() != "Invalid service address" { - t.Fatalf("got error %v want 'Invalid service address'", err) + err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out) + if err == nil || err.Error() != "Invalid service address" { + t.Fatalf("got error %v want 'Invalid service address'", err) + } + }) } }