diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index 49027ffffe..7ece2b06e4 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -3,6 +3,7 @@ package agent import ( "fmt" "net/http" + "net/url" "strconv" "strings" @@ -186,6 +187,20 @@ func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Requ prefix = "/v1/health/connect/" } + // Check for ingress request only when requesting connect services + if connect { + ingress, err := getBoolQueryParam(params, "ingress") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + fmt.Fprint(resp, "Invalid value for ?ingress") + return nil, nil + } + if ingress { + args.Connect = false + args.Ingress = true + } + } + // Pull out the service name args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix) if args.ServiceName == "" { @@ -224,26 +239,15 @@ func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Requ out.ConsistencyLevel = args.QueryOptions.ConsistencyLevel() // Filter to only passing if specified - if _, ok := params[api.HealthPassing]; ok { - val := params.Get(api.HealthPassing) - // Backwards-compat to allow users to specify ?passing without a value. This - // should be removed in Consul 0.10. - var filter bool - if val == "" { - filter = true - } else { - var err error - filter, err = strconv.ParseBool(val) - if err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Invalid value for ?passing") - return nil, nil - } - } + filter, err := getBoolQueryParam(params, api.HealthPassing) + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + fmt.Fprint(resp, "Invalid value for ?passing") + return nil, nil + } - if filter { - out.Nodes = filterNonPassing(out.Nodes) - } + if filter { + out.Nodes = filterNonPassing(out.Nodes) } // Translate addresses after filtering so we don't waste effort. @@ -273,6 +277,27 @@ func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Requ return out.Nodes, nil } +func getBoolQueryParam(params url.Values, key string) (bool, error) { + var param bool + if _, ok := params[key]; ok { + val := params.Get(key) + // Orginally a comment declared this check should be removed after Consul + // 0.10, to no longer support using ?passing without a value. However, I + // think this is a reasonable experience for a user and so am keeping it + // here. + if val == "" { + param = true + } else { + var err error + param, err = strconv.ParseBool(val) + if err != nil { + return false, err + } + } + } + return param, nil +} + // filterNonPassing is used to filter out any nodes that have check that are not passing func filterNonPassing(nodes structs.CheckServiceNodes) structs.CheckServiceNodes { n := len(nodes) diff --git a/agent/health_endpoint_test.go b/agent/health_endpoint_test.go index 59ef97679d..03fa9c7479 100644 --- a/agent/health_endpoint_test.go +++ b/agent/health_endpoint_test.go @@ -1139,6 +1139,105 @@ func TestHealthConnectServiceNodes(t *testing.T) { assert.Len(nodes[0].Checks, 0) } +func TestHealthConnectServiceNodes_Ingress(t *testing.T) { + t.Parallel() + + a := NewTestAgent(t, "") + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + // Register gateway + gatewayArgs := structs.TestRegisterIngressGateway(t) + gatewayArgs.Service.Address = "127.0.0.27" + var out struct{} + require.NoError(t, a.RPC("Catalog.Register", gatewayArgs, &out)) + + args := structs.TestRegisterRequest(t) + require.NoError(t, a.RPC("Catalog.Register", args, &out)) + + // Associate service to gateway + cfgArgs := &structs.IngressGatewayConfigEntry{ + Name: "ingress-gateway", + Kind: structs.IngressGateway, + Listeners: []structs.IngressListener{ + { + Port: 8888, + Protocol: "tcp", + Services: []structs.IngressService{ + {Name: args.Service.Service}, + }, + }, + }, + } + + req := structs.ConfigEntryRequest{ + Op: structs.ConfigEntryUpsert, + Datacenter: "dc1", + Entry: cfgArgs, + } + var outB bool + require.Nil(t, a.RPC("ConfigEntry.Apply", req, &outB)) + require.True(t, outB) + + t.Run("no_query_value", func(t *testing.T) { + assert := assert.New(t) + req, _ := http.NewRequest("GET", fmt.Sprintf( + "/v1/health/connect/%s?ingress", args.Service.Service), nil) + resp := httptest.NewRecorder() + obj, err := a.srv.HealthConnectServiceNodes(resp, req) + assert.Nil(err) + assertIndex(t, resp) + + nodes := obj.(structs.CheckServiceNodes) + require.Len(t, nodes, 1) + require.Equal(t, structs.ServiceKindIngressGateway, nodes[0].Service.Kind) + require.Equal(t, gatewayArgs.Service.Address, nodes[0].Service.Address) + require.Equal(t, gatewayArgs.Service.Proxy, nodes[0].Service.Proxy) + }) + + t.Run("true_value", func(t *testing.T) { + assert := assert.New(t) + req, _ := http.NewRequest("GET", fmt.Sprintf( + "/v1/health/connect/%s?ingress=true", args.Service.Service), nil) + resp := httptest.NewRecorder() + obj, err := a.srv.HealthConnectServiceNodes(resp, req) + assert.Nil(err) + assertIndex(t, resp) + + nodes := obj.(structs.CheckServiceNodes) + require.Len(t, nodes, 1) + require.Equal(t, structs.ServiceKindIngressGateway, nodes[0].Service.Kind) + require.Equal(t, gatewayArgs.Service.Address, nodes[0].Service.Address) + require.Equal(t, gatewayArgs.Service.Proxy, nodes[0].Service.Proxy) + }) + + t.Run("false_value", func(t *testing.T) { + assert := assert.New(t) + req, _ := http.NewRequest("GET", fmt.Sprintf( + "/v1/health/connect/%s?ingress=false", args.Service.Service), nil) + resp := httptest.NewRecorder() + obj, err := a.srv.HealthConnectServiceNodes(resp, req) + assert.Nil(err) + assertIndex(t, resp) + + nodes := obj.(structs.CheckServiceNodes) + require.Len(t, nodes, 0) + }) + + t.Run("invalid_value", func(t *testing.T) { + assert := assert.New(t) + req, _ := http.NewRequest("GET", fmt.Sprintf( + "/v1/health/connect/%s?ingress=notabool", args.Service.Service), nil) + resp := httptest.NewRecorder() + _, err := a.srv.HealthConnectServiceNodes(resp, req) + assert.Equal(400, resp.Code) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(err) + assert.True(bytes.Contains(body, []byte("Invalid value for ?ingress"))) + }) +} + func TestHealthConnectServiceNodes_Filter(t *testing.T) { t.Parallel()