diff --git a/command/agent/prepared_query_endpoint.go b/command/agent/prepared_query_endpoint.go index 5d9c07212d..4c549c6b0f 100644 --- a/command/agent/prepared_query_endpoint.go +++ b/command/agent/prepared_query_endpoint.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/hashicorp/consul/consul" "github.com/hashicorp/consul/consul/structs" ) @@ -101,6 +102,13 @@ func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.R var reply structs.PreparedQueryExecuteResponse if err := s.agent.RPC(endpoint+".Execute", &args, &reply); err != nil { + // We have to check the string since the RPC sheds + // the specific error type. + if err.Error() == consul.ErrQueryNotFound.Error() { + resp.WriteHeader(404) + resp.Write([]byte(err.Error())) + return nil, nil + } return nil, err } return reply, nil @@ -114,6 +122,13 @@ func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.R var reply structs.IndexedPreparedQueries if err := s.agent.RPC(endpoint+".Get", &args, &reply); err != nil { + // We have to check the string since the RPC sheds + // the specific error type. + if err.Error() == consul.ErrQueryNotFound.Error() { + resp.WriteHeader(404) + resp.Write([]byte(err.Error())) + return nil, nil + } return nil, err } return reply.Queries, nil diff --git a/command/agent/prepared_query_endpoint_test.go b/command/agent/prepared_query_endpoint_test.go index 23a39b196c..41b905249a 100644 --- a/command/agent/prepared_query_endpoint_test.go +++ b/command/agent/prepared_query_endpoint_test.go @@ -245,6 +245,23 @@ func TestPreparedQuery_Execute(t *testing.T) { t.Fatalf("bad: %v", r) } }) + + httpTest(t, func(srv *HTTPServer) { + body := bytes.NewBuffer(nil) + req, err := http.NewRequest("GET", "/v1/query/not-there/execute", body) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + _, err = srv.PreparedQuerySpecific(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 404 { + t.Fatalf("bad code: %d", resp.Code) + } + }) } func TestPreparedQuery_Get(t *testing.T) { @@ -296,6 +313,23 @@ func TestPreparedQuery_Get(t *testing.T) { t.Fatalf("bad: %v", r) } }) + + httpTest(t, func(srv *HTTPServer) { + body := bytes.NewBuffer(nil) + req, err := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + _, err = srv.PreparedQuerySpecific(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 404 { + t.Fatalf("bad code: %d", resp.Code) + } + }) } func TestPreparedQuery_Update(t *testing.T) { diff --git a/consul/prepared_query_endpoint.go b/consul/prepared_query_endpoint.go index 77b8e68bb2..3cdb9c213b 100644 --- a/consul/prepared_query_endpoint.go +++ b/consul/prepared_query_endpoint.go @@ -206,18 +206,17 @@ func (p *PreparedQuery) Get(args *structs.PreparedQuerySpecificRequest, if err != nil { return err } + if query == nil { + return ErrQueryNotFound + } - if (query != nil) && (query.Token != args.Token) && (acl != nil && !acl.QueryList()) { + if (query.Token != args.Token) && (acl != nil && !acl.QueryList()) { p.srv.logger.Printf("[WARN] consul.prepared_query: Request to get prepared query '%s' denied because ACL didn't match ACL used to create the query, and a management token wasn't supplied", args.QueryID) return permissionDeniedErr } reply.Index = index - if query != nil { - reply.Queries = structs.PreparedQueries{query} - } else { - reply.Queries = nil - } + reply.Queries = structs.PreparedQueries{query} return nil }) diff --git a/consul/prepared_query_endpoint_test.go b/consul/prepared_query_endpoint_test.go index adb7dceea6..59d2297ea4 100644 --- a/consul/prepared_query_endpoint_test.go +++ b/consul/prepared_query_endpoint_test.go @@ -184,7 +184,9 @@ func TestPreparedQuery_Apply(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - t.Fatalf("err: %v", err) + if err.Error() != ErrQueryNotFound.Error() { + t.Fatalf("err: %v", err) + } } if len(resp.Queries) != 0 { @@ -363,7 +365,9 @@ func TestPreparedQuery_Apply_ACLDeny(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - t.Fatalf("err: %v", err) + if err.Error() != ErrQueryNotFound.Error() { + t.Fatalf("err: %v", err) + } } if len(resp.Queries) != 0 { @@ -492,7 +496,9 @@ func TestPreparedQuery_Apply_ACLDeny(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - t.Fatalf("err: %v", err) + if err.Error() != ErrQueryNotFound.Error() { + t.Fatalf("err: %v", err) + } } if len(resp.Queries) != 0 { @@ -792,7 +798,9 @@ func TestPreparedQuery_Get(t *testing.T) { } var resp structs.IndexedPreparedQueries if err := msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - t.Fatalf("err: %v", err) + if err.Error() != ErrQueryNotFound.Error() { + t.Fatalf("err: %v", err) + } } if len(resp.Queries) != 0 {