diff --git a/agent/config_endpoint_test.go b/agent/config_endpoint_test.go index e6d0dd217f..9e5798089a 100644 --- a/agent/config_endpoint_test.go +++ b/agent/config_endpoint_test.go @@ -213,17 +213,36 @@ func TestConfig_Apply_TerminatingGateway(t *testing.T) { require.NoError(t, err) require.Equal(t, 200, resp.Code, "!200 Response Code: %s", resp.Body.String()) - // Get the remaining entry. + // Attempt to create an entry for a separate gateway that also routes to web + body = bytes.NewBuffer([]byte(` + { + "Kind": "terminating-gateway", + "Name": "east-gw-01", + "Services": [ + { + "Name": "web", + } + ] + }`)) + + req, _ = http.NewRequest("PUT", "/v1/config", body) + resp = httptest.NewRecorder() + _, err = a.srv.ConfigApply(resp, req) + require.Error(t, err, "service \"web\" is associated with a different gateway") + require.Equal(t, 200, resp.Code, "!200 Response Code: %s", resp.Body.String()) + + // List all entries, there should only be one { args := structs.ConfigEntryQuery{ Kind: structs.TerminatingGateway, - Name: "west-gw-01", Datacenter: "dc1", } - var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out)) - require.NotNil(t, out.Entry) - got := out.Entry.(*structs.TerminatingGatewayConfigEntry) + var out structs.IndexedConfigEntries + require.NoError(t, a.RPC("ConfigEntry.List", &args, &out)) + require.NotNil(t, out) + require.Len(t, out.Entries, 1) + + got := out.Entries[0].(*structs.TerminatingGatewayConfigEntry) expect := []structs.LinkedService{ { Name: "web", diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 9e256e95d1..e1e45056a3 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -1702,6 +1702,24 @@ func (f *aclFilter) filterServiceList(services *structs.ServiceList) { *services = ret } +// filterGatewayServices is used to filter gateway to service mappings based on ACL rules. +func (f *aclFilter) filterGatewayServices(mappings *structs.GatewayServices) { + ret := make(structs.GatewayServices, 0, len(*mappings)) + for _, s := range *mappings { + // This filter only checks ServiceRead on the linked service. + // ServiceRead on the gateway is checked in the GatewayServices endpoint before filtering. + var authzContext acl.AuthorizerContext + s.Service.FillAuthzContext(&authzContext) + + if f.authorizer.ServiceRead(s.Service.ID, &authzContext) != acl.Allow { + f.logger.Debug("dropping service from result due to ACLs", "service", s.Service.String()) + continue + } + ret = append(ret, s) + } + *mappings = ret +} + func (r *ACLResolver) filterACLWithAuthorizer(authorizer acl.Authorizer, subj interface{}) error { if authorizer == nil { return nil @@ -1786,6 +1804,10 @@ func (r *ACLResolver) filterACLWithAuthorizer(authorizer acl.Authorizer, subj in case *structs.IndexedServiceList: filt.filterServiceList(&v.Services) + + case *structs.GatewayServices: + filt.filterGatewayServices(v) + default: panic(fmt.Errorf("Unhandled type passed to ACL filter: %T %#v", subj, subj)) } diff --git a/agent/consul/catalog_endpoint_test.go b/agent/consul/catalog_endpoint_test.go index 3309785f01..55e39067f9 100644 --- a/agent/consul/catalog_endpoint_test.go +++ b/agent/consul/catalog_endpoint_test.go @@ -2071,6 +2071,122 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) { assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName) } +func TestCatalog_ServiceNodes_Gateway(t *testing.T) { + t.Parallel() + + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + { + var out struct{} + + // Register a service "api" + args := structs.TestRegisterRequest(t) + args.Service.Service = "api" + args.Check = &structs.HealthCheck{ + Name: "api", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a proxy for api + args = structs.TestRegisterRequestProxy(t) + args.Service.Service = "api-proxy" + args.Service.Proxy.DestinationServiceName = "api" + args.Check = &structs.HealthCheck{ + Name: "api-proxy", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "web" + args = structs.TestRegisterRequest(t) + args.Check = &structs.HealthCheck{ + Name: "web", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a proxy for web + args = structs.TestRegisterRequestProxy(t) + args.Check = &structs.HealthCheck{ + Name: "web-proxy", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a gateway for web + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindTerminatingGateway, + Service: "gateway", + Port: 443, + }, + Check: &structs.HealthCheck{ + Name: "gateway", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + }, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + entryArgs := &structs.ConfigEntryRequest{ + Op: structs.ConfigEntryUpsert, + Datacenter: "dc1", + Entry: &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "web", + }, + }, + }, + } + var entryResp bool + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &entryArgs, &entryResp)) + } + + retry.Run(t, func(r *retry.R) { + // List should return both the terminating-gateway and the connect-proxy associated with web + req := structs.ServiceSpecificRequest{ + Connect: true, + Datacenter: "dc1", + ServiceName: "web", + } + var resp structs.IndexedServiceNodes + assert.Nil(r, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) + assert.Len(r, resp.ServiceNodes, 2) + + // Check sidecar + assert.Equal(r, structs.ServiceKindConnectProxy, resp.ServiceNodes[0].ServiceKind) + assert.Equal(r, "foo", resp.ServiceNodes[0].Node) + assert.Equal(r, "web-proxy", resp.ServiceNodes[0].ServiceName) + assert.Equal(r, "web-proxy", resp.ServiceNodes[0].ServiceID) + assert.Equal(r, "web", resp.ServiceNodes[0].ServiceProxy.DestinationServiceName) + assert.Equal(r, 2222, resp.ServiceNodes[0].ServicePort) + + // Check gateway + assert.Equal(r, structs.ServiceKindTerminatingGateway, resp.ServiceNodes[1].ServiceKind) + assert.Equal(r, "foo", resp.ServiceNodes[1].Node) + assert.Equal(r, "gateway", resp.ServiceNodes[1].ServiceName) + assert.Equal(r, "gateway", resp.ServiceNodes[1].ServiceID) + assert.Equal(r, 443, resp.ServiceNodes[1].ServicePort) + }) +} + func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) { t.Parallel() diff --git a/agent/consul/health_endpoint_test.go b/agent/consul/health_endpoint_test.go index 9974bfc9f6..92b6290a80 100644 --- a/agent/consul/health_endpoint_test.go +++ b/agent/consul/health_endpoint_test.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/types" "github.com/hashicorp/net-rpc-msgpackrpc" @@ -1026,6 +1027,122 @@ service "foo" { assert.Len(resp.Nodes, 1) } +func TestHealth_ServiceNodes_Gateway(t *testing.T) { + t.Parallel() + + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + { + var out struct{} + + // Register a service "api" + args := structs.TestRegisterRequest(t) + args.Service.Service = "api" + args.Check = &structs.HealthCheck{ + Name: "api", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a proxy for api + args = structs.TestRegisterRequestProxy(t) + args.Service.Service = "api-proxy" + args.Service.Proxy.DestinationServiceName = "api" + args.Check = &structs.HealthCheck{ + Name: "api-proxy", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "web" + args = structs.TestRegisterRequest(t) + args.Check = &structs.HealthCheck{ + Name: "web", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a proxy for web + args = structs.TestRegisterRequestProxy(t) + args.Check = &structs.HealthCheck{ + Name: "proxy", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a gateway for web + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindTerminatingGateway, + Service: "gateway", + Port: 443, + }, + Check: &structs.HealthCheck{ + Name: "gateway", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + }, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + entryArgs := &structs.ConfigEntryRequest{ + Op: structs.ConfigEntryUpsert, + Datacenter: "dc1", + Entry: &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "web", + }, + }, + }, + } + var entryResp bool + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &entryArgs, &entryResp)) + } + + retry.Run(t, func(r *retry.R) { + // List should return both the terminating-gateway and the connect-proxy associated with web + req := structs.ServiceSpecificRequest{ + Connect: true, + Datacenter: "dc1", + ServiceName: "web", + } + var resp structs.IndexedCheckServiceNodes + assert.Nil(r, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp)) + assert.Len(r, resp.Nodes, 2) + + // Check sidecar + assert.Equal(r, structs.ServiceKindConnectProxy, resp.Nodes[0].Service.Kind) + assert.Equal(r, "foo", resp.Nodes[0].Node.Node) + assert.Equal(r, "web-proxy", resp.Nodes[0].Service.Service) + assert.Equal(r, "web-proxy", resp.Nodes[0].Service.ID) + assert.Equal(r, "web", resp.Nodes[0].Service.Proxy.DestinationServiceName) + assert.Equal(r, 2222, resp.Nodes[0].Service.Port) + + // Check gateway + assert.Equal(r, structs.ServiceKindTerminatingGateway, resp.Nodes[1].Service.Kind) + assert.Equal(r, "foo", resp.Nodes[1].Node.Node) + assert.Equal(r, "gateway", resp.Nodes[1].Service.Service) + assert.Equal(r, "gateway", resp.Nodes[1].Service.ID) + assert.Equal(r, 443, resp.Nodes[1].Service.Port) + }) +} + func TestHealth_NodeChecks_FilterACL(t *testing.T) { t.Parallel() dir, token, srv, codec := testACLFilterServer(t) diff --git a/agent/consul/internal_endpoint.go b/agent/consul/internal_endpoint.go index ff3e0d52a2..b577f42a3f 100644 --- a/agent/consul/internal_endpoint.go +++ b/agent/consul/internal_endpoint.go @@ -296,3 +296,46 @@ func (m *Internal) aclAccessorID(secretID string) string { } return ident.ID() } + +func (m *Internal) GatewayServices(args *structs.ServiceSpecificRequest, reply *structs.IndexedGatewayServices) error { + if done, err := m.srv.forward("Internal.GatewayServices", args, args, reply); done { + return err + } + + var authzContext acl.AuthorizerContext + authz, err := m.srv.ResolveTokenAndDefaultMeta(args.Token, &args.EnterpriseMeta, &authzContext) + if err != nil { + return err + } + + if err := m.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { + return err + } + + if authz != nil && authz.ServiceRead(args.ServiceName, &authzContext) != acl.Allow { + return acl.ErrPermissionDenied + } + + return m.srv.blockingQuery( + &args.QueryOptions, + &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + var index uint64 + var services structs.GatewayServices + + switch args.ServiceKind { + case structs.ServiceKindTerminatingGateway: + index, services, err = state.TerminatingGatewayServices(ws, args.ServiceName, &args.EnterpriseMeta) + if err != nil { + return err + } + } + + if err := m.srv.filterACL(args.Token, &services); err != nil { + return err + } + + reply.Index, reply.Services = index, services + return nil + }) +} diff --git a/agent/consul/internal_endpoint_test.go b/agent/consul/internal_endpoint_test.go index bb61e91083..46fa10a40e 100644 --- a/agent/consul/internal_endpoint_test.go +++ b/agent/consul/internal_endpoint_test.go @@ -2,6 +2,7 @@ package consul import ( "encoding/base64" + "github.com/hashicorp/consul/sdk/testutil/retry" "os" "strings" "testing" @@ -12,7 +13,7 @@ import ( "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/testrpc" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" - + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -654,3 +655,322 @@ func TestInternal_ServiceDump_Kind(t *testing.T) { require.Equal(t, "web-proxy", nodes[0].Service.ID) }) } + +func TestInternal_TerminatingGatewayServices(t *testing.T) { + t.Parallel() + + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + { + var out struct{} + + // Register a service "api" + args := structs.TestRegisterRequest(t) + args.Service.Service = "api" + args.Check = &structs.HealthCheck{ + Name: "api", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "db" + args = structs.TestRegisterRequest(t) + args.Service.Service = "db" + args.Check = &structs.HealthCheck{ + Name: "db", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "redis" + args = structs.TestRegisterRequest(t) + args.Service.Service = "redis" + args.Check = &structs.HealthCheck{ + Name: "redis", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a gateway + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindTerminatingGateway, + Service: "gateway", + Port: 443, + }, + Check: &structs.HealthCheck{ + Name: "gateway", + Status: api.HealthPassing, + ServiceID: "gateway", + }, + } + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + entryArgs := &structs.ConfigEntryRequest{ + Op: structs.ConfigEntryUpsert, + Datacenter: "dc1", + Entry: &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "api", + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Name: "db", + }, + { + Name: "*", + CAFile: "ca.crt", + CertFile: "client.crt", + KeyFile: "client.key", + }, + }, + }, + } + var entryResp bool + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &entryArgs, &entryResp)) + } + + retry.Run(t, func(r *retry.R) { + // List should return all three services + req := structs.ServiceSpecificRequest{ + Datacenter: "dc1", + ServiceName: "gateway", + ServiceKind: structs.ServiceKindTerminatingGateway, + } + var resp structs.IndexedGatewayServices + assert.Nil(r, msgpackrpc.CallWithCodec(codec, "Internal.GatewayServices", &req, &resp)) + assert.Len(r, resp.Services, 3) + + expect := structs.GatewayServices{ + { + Service: structs.NewServiceID("api", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "", + CertFile: "", + KeyFile: "", + }, + { + Service: structs.NewServiceID("redis", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "ca.crt", + CertFile: "client.crt", + KeyFile: "client.key", + }, + } + assert.Equal(r, expect, resp.Services) + }) +} + +func TestInternal_TerminatingGatewayServices_ACLFiltering(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLEnforceVersion8 = true + c.ACLMasterToken = "root" + c.ACLDefaultPolicy = "deny" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForTestAgent(t, s1.RPC, "dc1", testrpc.WithToken("root")) + + { + var out struct{} + + // Register a service "api" + args := structs.TestRegisterRequest(t) + args.Service.Service = "api" + args.Check = &structs.HealthCheck{ + Name: "api", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + args.Token = "root" + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "db" + args = structs.TestRegisterRequest(t) + args.Service.Service = "db" + args.Check = &structs.HealthCheck{ + Name: "db", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + args.Token = "root" + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a service "redis" + args = structs.TestRegisterRequest(t) + args.Service.Service = "redis" + args.Check = &structs.HealthCheck{ + Name: "redis", + Status: api.HealthPassing, + ServiceID: args.Service.Service, + } + args.Token = "root" + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + // Register a gateway + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindTerminatingGateway, + Service: "gateway", + Port: 443, + }, + Check: &structs.HealthCheck{ + Name: "gateway", + Status: api.HealthPassing, + ServiceID: "gateway", + }, + } + args.Token = "root" + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) + + entryArgs := &structs.ConfigEntryRequest{ + Op: structs.ConfigEntryUpsert, + Datacenter: "dc1", + Entry: &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "api", + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Name: "db", + }, + { + Name: "db_replica", + }, + { + Name: "*", + CAFile: "ca.crt", + CertFile: "client.crt", + KeyFile: "client.key", + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var entryResp bool + assert.Nil(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &entryArgs, &entryResp)) + } + + rules := ` +service_prefix "db" { + policy = "read" +} +` + svcToken, err := upsertTestTokenWithPolicyRules(codec, "root", "dc1", rules) + require.NoError(t, err) + + retry.Run(t, func(r *retry.R) { + // List should return an empty list, since we do not have read on the gateway + req := structs.ServiceSpecificRequest{ + Datacenter: "dc1", + ServiceName: "gateway", + ServiceKind: structs.ServiceKindTerminatingGateway, + QueryOptions: structs.QueryOptions{Token: svcToken.SecretID}, + } + var resp structs.IndexedGatewayServices + err := msgpackrpc.CallWithCodec(codec, "Internal.GatewayServices", &req, &resp) + require.True(r, acl.IsErrPermissionDenied(err)) + }) + + rules = ` +service "gateway" { + policy = "read" +} +` + gwToken, err := upsertTestTokenWithPolicyRules(codec, "root", "dc1", rules) + require.NoError(t, err) + + retry.Run(t, func(r *retry.R) { + // List should return an empty list, since we do not have read on db + req := structs.ServiceSpecificRequest{ + Datacenter: "dc1", + ServiceName: "gateway", + ServiceKind: structs.ServiceKindTerminatingGateway, + QueryOptions: structs.QueryOptions{Token: gwToken.SecretID}, + } + var resp structs.IndexedGatewayServices + assert.Nil(r, msgpackrpc.CallWithCodec(codec, "Internal.GatewayServices", &req, &resp)) + assert.Len(r, resp.Services, 0) + }) + + rules = ` +service_prefix "db" { + policy = "read" +} +service "gateway" { + policy = "read" +} +` + validToken, err := upsertTestTokenWithPolicyRules(codec, "root", "dc1", rules) + require.NoError(t, err) + + retry.Run(t, func(r *retry.R) { + // List should return db entry since we have read on db and gateway + req := structs.ServiceSpecificRequest{ + Datacenter: "dc1", + ServiceName: "gateway", + ServiceKind: structs.ServiceKindTerminatingGateway, + QueryOptions: structs.QueryOptions{Token: validToken.SecretID}, + } + var resp structs.IndexedGatewayServices + assert.Nil(r, msgpackrpc.CallWithCodec(codec, "Internal.GatewayServices", &req, &resp)) + assert.Len(r, resp.Services, 2) + + expect := structs.GatewayServices{ + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + { + Service: structs.NewServiceID("db_replica", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + } + assert.Equal(r, expect, resp.Services) + }) +} diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 3f045ad2dc..78bc0ab7d5 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -2,6 +2,7 @@ package state import ( "fmt" + "reflect" "strings" "github.com/hashicorp/consul/agent/structs" @@ -12,7 +13,8 @@ import ( ) const ( - servicesTableName = "services" + servicesTableName = "services" + terminatingGatewayServicesTableName = "terminating-gateway-services" // serviceLastExtinctionIndexName keeps track of the last raft index when the last instance // of any service was unregistered. This is used by blocking queries on missing services. @@ -55,10 +57,108 @@ func nodesTableSchema() *memdb.TableSchema { } } +// terminatingGatewayServicesTableSchema returns a new table schema used to store information +// about services associated with terminating gateways. +func terminatingGatewayServicesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: terminatingGatewayServicesTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &ServiceIDIndex{ + Field: "Gateway", + }, + &ServiceIDIndex{ + Field: "Service", + }, + }, + }, + }, + "gateway": { + Name: "gateway", + AllowMissing: false, + Unique: false, + Indexer: &ServiceIDIndex{ + Field: "Gateway", + }, + }, + "service": { + Name: "service", + AllowMissing: true, + Unique: false, + Indexer: &ServiceIDIndex{ + Field: "Service", + }, + }, + }, + } +} + +type ServiceIDIndex struct { + Field string +} + +func (index *ServiceIDIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(index.Field) + isPtr := fv.Kind() == reflect.Ptr + fv = reflect.Indirect(fv) + if !isPtr && !fv.IsValid() || !fv.CanInterface() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid %v ", index.Field, obj, isPtr) + } + + sid, ok := fv.Interface().(structs.ServiceID) + if !ok { + return false, nil, fmt.Errorf("Field 'ServiceID' is not of type structs.ServiceID") + } + + // Enforce lowercase and add null character as terminator + id := strings.ToLower(sid.String()) + "\x00" + + return true, []byte(id), nil +} + +func (index *ServiceIDIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + sid, ok := args[0].(structs.ServiceID) + if !ok { + return nil, fmt.Errorf("argument must be of type structs.ServiceID: %#v", args[0]) + } + + // Enforce lowercase and add null character as terminator + id := strings.ToLower(sid.String()) + "\x00" + + return []byte(strings.ToLower(id)), nil +} + +func (index *ServiceIDIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := index.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + func init() { registerSchema(nodesTableSchema) registerSchema(servicesTableSchema) registerSchema(checksTableSchema) + registerSchema(terminatingGatewayServicesTableSchema) } const ( @@ -674,6 +774,18 @@ func (s *Store) ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *st if err = structs.ValidateServiceMetadata(svc.Kind, svc.Meta, false); err != nil { return fmt.Errorf("Invalid Service Meta for node %s and serviceID %s: %v", node, svc.ID, err) } + + // Check if this service is covered by a terminating gateway's wildcard specifier + gateway, err := s.serviceTerminatingGateway(tx, structs.WildcardSpecifier, &svc.EnterpriseMeta) + if err != nil { + return fmt.Errorf("failed gateway lookup for %q: %s", svc.Service, err) + } + if gatewaySvc, ok := gateway.(*structs.GatewayService); ok && gatewaySvc != nil { + if err = s.updateTerminatingGatewayService(tx, idx, gatewaySvc.Gateway, svc.Service, &svc.EnterpriseMeta); err != nil { + return fmt.Errorf("Failed to associate service %q with gateway %q", gatewaySvc.Service.String(), gatewaySvc.Gateway.String()) + } + } + // Create the service node entry and populate the indexes. Note that // conversion doesn't populate any of the node-specific information. // That's always populated when we read from the state store. @@ -922,6 +1034,22 @@ func (s *Store) serviceNodes(ws memdb.WatchSet, serviceName string, connect bool results = append(results, service.(*structs.ServiceNode)) } + // If we are querying for Connect nodes, the associated proxy might be a gateway. + // Gateways are tracked in a separate table, and we append them to the result set. + // We append rather than replace since it allows users to migrate a service + // to the mesh with a mix of sidecars and gateways until all its instances have a sidecar. + if connect { + // Look up gateway nodes associated with the service + nodes, ch, err := s.serviceTerminatingGatewayNodes(tx, serviceName, entMeta) + if err != nil { + return 0, nil, fmt.Errorf("failed gateway nodes lookup: %v", err) + } + ws.Add(ch) + for i := 0; i < len(nodes); i++ { + results = append(results, nodes[i]) + } + } + // Fill in the node details. results, err = s.parseServiceNodes(tx, ws, results) if err != nil { @@ -1330,10 +1458,25 @@ func (s *Store) deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeName, serviceID if err := s.catalogUpdateServiceExtinctionIndex(tx, idx, entMeta); err != nil { return err } + + // Clean up association between service name and gateway + gateway, err := s.serviceTerminatingGateway(tx, svc.ServiceName, &svc.EnterpriseMeta) + if err != nil { + return fmt.Errorf("failed gateway lookup for %q: %s", svc.ServiceName, err) + } + if gateway != nil { + if err := tx.Delete(terminatingGatewayServicesTableName, gateway); err != nil { + return fmt.Errorf("failed to delete gateway mapping for %q: %v", svc.ServiceName, err) + } + if err := indexUpdateMaxTxn(tx, idx, terminatingGatewayServicesTableName); err != nil { + return fmt.Errorf("failed updating terminating-gateway-services index: %v", err) + } + } } } else { return fmt.Errorf("Could not find any service %s: %s", svc.ServiceName, err) } + return nil } @@ -1836,6 +1979,22 @@ func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect serviceNames[sn.ServiceName] = struct{}{} } + // If we are querying for Connect nodes, the associated proxy might be a gateway. + // Gateways are tracked in a separate table, and we append them to the result set. + // We append rather than replace since it allows users to migrate a service + // to the mesh with a mix of sidecars and gateways until all its instances have a sidecar. + if connect { + // Look up gateway nodes associated with the service + nodes, _, err := s.serviceTerminatingGatewayNodes(tx, serviceName, entMeta) + if err != nil { + return 0, nil, fmt.Errorf("failed gateway nodes lookup: %v", err) + } + for i := 0; i < len(nodes); i++ { + results = append(results, nodes[i]) + serviceNames[nodes[i].ServiceName] = struct{}{} + } + } + // watchOptimized tracks if we meet the necessary condition to optimize // WatchSet size. That is that every service name represented in the result // set must have a service-specific index we can watch instead of many radix @@ -1936,6 +2095,30 @@ func (s *Store) CheckServiceTagNodes(ws memdb.WatchSet, serviceName string, tags return s.parseCheckServiceNodes(tx, ws, idx, serviceName, results, err) } +// TerminatingGatewayServices is used to query all services associated with a terminating gateway +func (s *Store) TerminatingGatewayServices(ws memdb.WatchSet, gateway string, entMeta *structs.EnterpriseMeta) (uint64, structs.GatewayServices, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := s.terminatingGatewayServices(tx, gateway, entMeta) + if err != nil { + return 0, nil, fmt.Errorf("failed gateway services lookup: %s", err) + } + ws.Add(iter.WatchCh()) + + var results structs.GatewayServices + for service := iter.Next(); service != nil; service = iter.Next() { + svc := service.(*structs.GatewayService) + + if svc.Service.ID != structs.WildcardSpecifier { + results = append(results, svc) + } + } + + idx := maxIndexTxn(tx, terminatingGatewayServicesTableName) + return idx, results, nil +} + // parseCheckServiceNodes is used to parse through a given set of services, // and query for an associated node and a set of checks. This is the inner // method used to return a rich set of results from a more simple query. @@ -2179,3 +2362,213 @@ func checkSessionsTxn(tx *memdb.Txn, hc *structs.HealthCheck) ([]*sessionCheck, } return sessions, nil } + +// updateGatewayService associates services with gateways as specified in a terminating-gateway config entry +func (s *Store) updateTerminatingGatewayServices(tx *memdb.Txn, idx uint64, conf structs.ConfigEntry, entMeta *structs.EnterpriseMeta) error { + entry, ok := conf.(*structs.TerminatingGatewayConfigEntry) + if !ok { + return fmt.Errorf("unexpected config entry type: %T", conf) + } + + // Check if service list matches the last known list for the config entry, if it does, skip the update + _, c, err := s.configEntryTxn(tx, nil, conf.GetKind(), conf.GetName(), entMeta) + if err != nil { + return fmt.Errorf("failed to get config entry: %v", err) + } + if cfg, ok := c.(*structs.TerminatingGatewayConfigEntry); ok && cfg != nil { + if reflect.DeepEqual(cfg.Services, entry.Services) { + // Services are the same, nothing to update + return nil + } + } + + // Delete all associated with gateway first, to avoid keeping mappings that were removed + if _, err := tx.DeleteAll(terminatingGatewayServicesTableName, "gateway", structs.NewServiceID(entry.Name, entMeta)); err != nil { + return fmt.Errorf("failed to truncate gateway services table: %v", err) + } + + gatewayID := structs.NewServiceID(entry.Name, &entry.EnterpriseMeta) + for _, svc := range entry.Services { + // If the service is a wildcard we need to target all services within the namespace + if svc.Name == structs.WildcardSpecifier { + if err := s.updateTerminatingGatewayNamespace(tx, gatewayID, svc, entMeta); err != nil { + return fmt.Errorf("failed to associate gateway %q with wildcard: %v", gatewayID.String(), err) + } + // Skip service-specific update below if there was a wildcard update + continue + } + + // Check if the non-wildcard service is already associated with a gateway + existing, err := s.serviceTerminatingGateway(tx, svc.Name, &svc.EnterpriseMeta) + if err != nil { + return fmt.Errorf("gateway service lookup failed: %s", err) + } + if gs, ok := existing.(*structs.GatewayService); ok && gs != nil { + // Only return an error if the stored gateway does not match the one from the config entry + if !gs.Gateway.Matches(&gatewayID) { + return fmt.Errorf("service %q is associated with different gateway, %q", gs.Service.String(), gs.Gateway.String()) + } + } + + // Since this service was specified on its own, and not with a wildcard, + // if there is an existing entry, we overwrite it. The service entry is the source of truth. + // + // By extension, if TLS creds are provided with a wildcard but are not provided in + // the service entry, the service does not inherit the creds from the wildcard. + mapping := &structs.GatewayService{ + Gateway: gatewayID, + Service: structs.NewServiceID(svc.Name, &svc.EnterpriseMeta), + GatewayKind: structs.ServiceKindTerminatingGateway, + KeyFile: svc.KeyFile, + CertFile: svc.CertFile, + CAFile: svc.CAFile, + } + if err := tx.Insert(terminatingGatewayServicesTableName, mapping); err != nil { + return fmt.Errorf("failed inserting gateway service mapping: %s", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, terminatingGatewayServicesTableName); err != nil { + return fmt.Errorf("failed updating terminating-gateway-services index: %v", err) + } + return nil +} + +// updateTerminatingGatewayNamespace is used to target all services within a namespace with a set of TLS certificates +func (s *Store) updateTerminatingGatewayNamespace(tx *memdb.Txn, gateway structs.ServiceID, service structs.LinkedService, entMeta *structs.EnterpriseMeta) error { + services, err := s.catalogServiceListByKind(tx, structs.ServiceKindTypical, entMeta) + if err != nil { + return fmt.Errorf("failed querying services: %s", err) + } + + // Iterate over services in namespace and insert mapping for each + for svc := services.Next(); svc != nil; svc = services.Next() { + sn := svc.(*structs.ServiceNode) + + // Only associate non-consul services with gateways + if sn.ServiceName == "consul" { + continue + } + + existing, err := s.serviceTerminatingGateway(tx, sn.ServiceName, &sn.EnterpriseMeta) + if err != nil { + return fmt.Errorf("gateway service lookup failed: %s", err) + } + + if gs, ok := existing.(*structs.GatewayService); ok && gs != nil { + // Return an error if the wildcard is attempting to cover a service specified by a different gateway's config entry + if !gs.Gateway.Matches(&gateway) { + return fmt.Errorf("service %q is associated with different gateway, %q", gs.Service.String(), gs.Gateway.String()) + } + + // If there's an existing service associated with this gateway then we skip it. + // This means the service was specified on its own, and the service entry overrides the wildcard entry. + continue + } + + mapping := &structs.GatewayService{ + Gateway: gateway, + Service: structs.NewServiceID(sn.ServiceName, &service.EnterpriseMeta), + GatewayKind: structs.ServiceKindTerminatingGateway, + KeyFile: service.KeyFile, + CertFile: service.CertFile, + CAFile: service.CAFile, + } + if err := tx.Insert(terminatingGatewayServicesTableName, mapping); err != nil { + return fmt.Errorf("failed inserting gateway service mapping: %s", err) + } + } + + // Also store a mapping for the wildcard so that the TLS creds can be pulled + // for new services registered in its namespace + mapping := &structs.GatewayService{ + Gateway: gateway, + Service: structs.NewServiceID(service.Name, &service.EnterpriseMeta), + GatewayKind: structs.ServiceKindTerminatingGateway, + KeyFile: service.KeyFile, + CertFile: service.CertFile, + CAFile: service.CAFile, + } + if err := tx.Insert(terminatingGatewayServicesTableName, mapping); err != nil { + return fmt.Errorf("failed inserting gateway service mapping: %s", err) + } + return nil +} + +// updateGatewayService associates services with gateways after an eligible event +// ie. Registering a service in a namespace targeted by a gateway +func (s *Store) updateTerminatingGatewayService(tx *memdb.Txn, idx uint64, gateway structs.ServiceID, service string, entMeta *structs.EnterpriseMeta) error { + mapping := &structs.GatewayService{ + Gateway: gateway, + Service: structs.NewServiceID(service, entMeta), + GatewayKind: structs.ServiceKindTerminatingGateway, + } + + // If a wildcard specifier is registered for that namespace, use its TLS config + wc, err := s.serviceTerminatingGateway(tx, structs.WildcardSpecifier, entMeta) + if err != nil { + return fmt.Errorf("gateway service lookup failed: %s", err) + } + if wc != nil { + cfg := wc.(*structs.GatewayService) + mapping.CAFile = cfg.CAFile + mapping.CertFile = cfg.CertFile + mapping.KeyFile = cfg.KeyFile + } + + // Check if mapping already exists in table if it's already in the table + // Avoid insert if nothing changed + existing, err := s.serviceTerminatingGateway(tx, service, entMeta) + if err != nil { + return fmt.Errorf("gateway service lookup failed: %s", err) + } + if gs, ok := existing.(*structs.GatewayService); ok && gs != nil { + if gs.IsSame(mapping) { + return nil + } + } + + if err := tx.Insert(terminatingGatewayServicesTableName, mapping); err != nil { + return fmt.Errorf("failed inserting gateway service mapping: %s", err) + } + + if err := indexUpdateMaxTxn(tx, idx, terminatingGatewayServicesTableName); err != nil { + return fmt.Errorf("failed updating terminating-gateway-services index: %v", err) + } + return nil +} + +func (s *Store) serviceTerminatingGateway(tx *memdb.Txn, name string, entMeta *structs.EnterpriseMeta) (interface{}, error) { + return tx.First(terminatingGatewayServicesTableName, "service", structs.NewServiceID(name, entMeta)) +} + +func (s *Store) terminatingGatewayServices(tx *memdb.Txn, name string, entMeta *structs.EnterpriseMeta) (memdb.ResultIterator, error) { + return tx.Get(terminatingGatewayServicesTableName, "gateway", structs.NewServiceID(name, entMeta)) +} + +func (s *Store) serviceTerminatingGatewayNodes(tx *memdb.Txn, service string, entMeta *structs.EnterpriseMeta) (structs.ServiceNodes, <-chan struct{}, error) { + // Look up gateway name associated with the service + gw, err := s.serviceTerminatingGateway(tx, service, entMeta) + if err != nil { + return nil, nil, fmt.Errorf("failed gateway lookup: %s", err) + } + + var ret structs.ServiceNodes + var watchChan <-chan struct{} + + if gw != nil { + mapping := gw.(*structs.GatewayService) + + // Look up nodes for gateway + gateways, err := s.catalogServiceNodeList(tx, mapping.Gateway.ID, "service", &mapping.Gateway.EnterpriseMeta) + if err != nil { + return nil, nil, fmt.Errorf("failed service lookup: %s", err) + } + for gateway := gateways.Next(); gateway != nil; gateway = gateways.Next() { + sn := gateway.(*structs.ServiceNode) + ret = append(ret, sn) + } + watchChan = gateways.WatchCh() + } + return ret, watchChan, nil +} diff --git a/agent/consul/state/catalog_oss.go b/agent/consul/state/catalog_oss.go index 8c89238bf7..b6f20db65d 100644 --- a/agent/consul/state/catalog_oss.go +++ b/agent/consul/state/catalog_oss.go @@ -4,7 +4,6 @@ package state import ( "fmt" - "github.com/hashicorp/consul/agent/structs" memdb "github.com/hashicorp/go-memdb" ) diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index e72cd3dc94..3fb9ab9e8c 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -2126,6 +2126,87 @@ func TestStateStore_ConnectServiceNodes(t *testing.T) { assert.True(watchFired(ws)) } +func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) { + assert := assert.New(t) + s := testStateStore(t) + + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(0)) + assert.Len(nodes, 0) + + // Create some nodes and services. + assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) + assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) + + // Typical services + assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) + assert.False(watchFired(ws)) + + // Register a sidecar for db + assert.Nil(s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) + assert.True(watchFired(ws)) + + // Associate gateway with db + assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443})) + assert.Nil(s.EnsureConfigEntry(17, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "db", + }, + }, + }, nil)) + assert.True(watchFired(ws)) + + // Read everything back. + ws = memdb.NewWatchSet() + idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(14)) + assert.Len(nodes, 2) + + // Check sidecar + assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind) + assert.Equal("foo", nodes[0].Node) + assert.Equal("proxy", nodes[0].ServiceName) + assert.Equal("proxy", nodes[0].ServiceID) + assert.Equal("db", nodes[0].ServiceProxy.DestinationServiceName) + assert.Equal(8000, nodes[0].ServicePort) + + // Check gateway + assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind) + assert.Equal("bar", nodes[1].Node) + assert.Equal("gateway", nodes[1].ServiceName) + assert.Equal("gateway", nodes[1].ServiceID) + assert.Equal(443, nodes[1].ServicePort) + + // Watch should fire when another gateway instance is registered + assert.Nil(s.EnsureService(18, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443})) + assert.True(watchFired(ws)) + + // Watch should fire when a gateway instance is de-registered + assert.Nil(s.DeleteService(29, "bar", "gateway", nil)) + assert.True(watchFired(ws)) + + idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(14)) + assert.Len(nodes, 2) + + // Check the new gateway + assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind) + assert.Equal("foo", nodes[1].Node) + assert.Equal("gateway", nodes[1].ServiceName) + assert.Equal("gateway-2", nodes[1].ServiceID) + assert.Equal(443, nodes[1].ServicePort) +} + func TestStateStore_Service_Snapshot(t *testing.T) { s := testStateStore(t) @@ -3464,6 +3545,97 @@ func TestStateStore_CheckConnectServiceNodes(t *testing.T) { } } +func TestStateStore_CheckConnectServiceNodes_Gateways(t *testing.T) { + assert := assert.New(t) + s := testStateStore(t) + + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(0)) + assert.Len(nodes, 0) + + // Create some nodes and services. + assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) + assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) + + // Typical services + assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) + assert.False(watchFired(ws)) + + // Register a sidecar and a gateway for db + assert.Nil(s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) + assert.True(watchFired(ws)) + + assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443})) + assert.True(watchFired(ws)) + + // Register node checks + testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing) + testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing) + + // Register checks against the services. + testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing) + testRegisterCheck(t, s, 20, "bar", "gateway", "check4", api.HealthPassing) + + // Associate gateway with db + assert.Nil(s.EnsureConfigEntry(21, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "db", + }, + }, + }, nil)) + assert.True(watchFired(ws)) + + // Read everything back. + ws = memdb.NewWatchSet() + idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(20)) + assert.Len(nodes, 2) + + // Check sidecar + assert.Equal(structs.ServiceKindConnectProxy, nodes[0].Service.Kind) + assert.Equal("foo", nodes[0].Node.Node) + assert.Equal("proxy", nodes[0].Service.Service) + assert.Equal("proxy", nodes[0].Service.ID) + assert.Equal("db", nodes[0].Service.Proxy.DestinationServiceName) + assert.Equal(8000, nodes[0].Service.Port) + + // Check gateway + assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind) + assert.Equal("bar", nodes[1].Node.Node) + assert.Equal("gateway", nodes[1].Service.Service) + assert.Equal("gateway", nodes[1].Service.ID) + assert.Equal(443, nodes[1].Service.Port) + + // Watch should fire when another gateway instance is registered + assert.Nil(s.EnsureService(22, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443})) + assert.True(watchFired(ws)) + + // Watch should fire when a gateway instance is de-registered + assert.Nil(s.DeleteService(23, "bar", "gateway", nil)) + assert.True(watchFired(ws)) + + idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) + assert.Nil(err) + assert.Equal(idx, uint64(23)) + assert.Len(nodes, 2) + + // Check new gateway + assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind) + assert.Equal("foo", nodes[1].Node.Node) + assert.Equal("gateway", nodes[1].Service.Service) + assert.Equal("gateway-2", nodes[1].Service.ID) + assert.Equal(443, nodes[1].Service.Port) +} + func BenchmarkCheckServiceNodes(b *testing.B) { s, err := NewStateStore(nil) if err != nil { @@ -4210,3 +4382,239 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { require.Equal(t, uint64(7), nsRead.ModifyIndex) tx.Commit() } + +func TestStateStore_TerminatingGatewayServices(t *testing.T) { + s := testStateStore(t) + + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.TerminatingGatewayServices(ws, "db", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(0)) + assert.Len(t, nodes, 0) + + // Create some nodes + assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) + assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) + assert.Nil(t, s.EnsureNode(12, &structs.Node{Node: "baz", Address: "127.0.0.2"})) + + // Typical services and some consul services spread across two nodes + assert.Nil(t, s.EnsureService(13, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) + assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) + assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil})) + assert.Nil(t, s.EnsureService(17, "bar", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil})) + + // Add ingress gateway and a connect proxy, neither should get picked up by terminating gateway + ingressNS := &structs.NodeService{ + Kind: structs.ServiceKindIngressGateway, + ID: "ingress", + Service: "ingress", + Port: 8443, + } + assert.Nil(t, s.EnsureService(18, "baz", ingressNS)) + + proxyNS := &structs.NodeService{ + Kind: structs.ServiceKindConnectProxy, + ID: "db proxy", + Service: "db proxy", + Proxy: structs.ConnectProxyConfig{ + DestinationServiceName: "db", + }, + Port: 8000, + } + assert.Nil(t, s.EnsureService(19, "foo", proxyNS)) + + // Register a gateway + assert.Nil(t, s.EnsureService(20, "baz", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443})) + + // Associate gateway with db and api + assert.Nil(t, s.EnsureConfigEntry(21, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "db", + }, + { + Name: "api", + }, + }, + }, nil)) + assert.True(t, watchFired(ws)) + + // Read everything back. + ws = memdb.NewWatchSet() + idx, out, err := s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(21)) + assert.Len(t, out, 2) + + expect := structs.GatewayServices{ + { + Service: structs.NewServiceID("api", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + } + assert.Equal(t, expect, out) + + // Associate gateway with a wildcard and add TLS config + assert.Nil(t, s.EnsureConfigEntry(22, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "api", + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Name: "db", + }, + { + Name: "*", + CAFile: "ca.crt", + CertFile: "client.crt", + KeyFile: "client.key", + }, + }, + }, nil)) + assert.True(t, watchFired(ws)) + + // Read everything back. + ws = memdb.NewWatchSet() + idx, out, err = s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(22)) + assert.Len(t, out, 2) + + expect = structs.GatewayServices{ + { + Service: structs.NewServiceID("api", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + } + assert.Equal(t, expect, out) + + // Add a service covered by wildcard + assert.Nil(t, s.EnsureService(23, "bar", &structs.NodeService{ID: "redis", Service: "redis", Tags: nil, Address: "", Port: 6379})) + assert.True(t, watchFired(ws)) + + idx, out, err = s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(23)) + assert.Len(t, out, 3) + + expect = structs.GatewayServices{ + { + Service: structs.NewServiceID("api", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + { + Service: structs.NewServiceID("redis", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "ca.crt", + CertFile: "client.crt", + KeyFile: "client.key", + }, + } + assert.Equal(t, expect, out) + + // Delete a service covered by wildcard + assert.Nil(t, s.DeleteService(24, "bar", "redis", nil)) + assert.True(t, watchFired(ws)) + + idx, out, err = s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(24)) + assert.Len(t, out, 2) + + expect = structs.GatewayServices{ + { + Service: structs.NewServiceID("api", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + CAFile: "api/ca.crt", + CertFile: "api/client.crt", + KeyFile: "api/client.key", + }, + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + } + assert.Equal(t, expect, out) + + // Create a new entry that only leaves one service + assert.Nil(t, s.EnsureConfigEntry(25, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway", + Services: []structs.LinkedService{ + { + Name: "db", + }, + }, + }, nil)) + assert.True(t, watchFired(ws)) + + idx, out, err = s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(25)) + assert.Len(t, out, 1) + + // previously associated services should not be present + expect = structs.GatewayServices{ + { + Service: structs.NewServiceID("db", nil), + Gateway: structs.NewServiceID("gateway", nil), + GatewayKind: structs.ServiceKindTerminatingGateway, + }, + } + assert.Equal(t, expect, out) + + // Attempt to associate a different gateway with services that include db + assert.Error(t, s.EnsureConfigEntry(26, &structs.TerminatingGatewayConfigEntry{ + Kind: "terminating-gateway", + Name: "gateway2", + Services: []structs.LinkedService{ + { + Name: "*", + }, + }, + }, nil), "service \"db\" is associated with different gateway") + + // Deleting the config entry should remove existing mappings + assert.Nil(t, s.DeleteConfigEntry(26, "terminating-gateway", "gateway", nil)) + assert.True(t, watchFired(ws)) + + idx, out, err = s.TerminatingGatewayServices(ws, "gateway", nil) + assert.Nil(t, err) + assert.Equal(t, idx, uint64(26)) + assert.Len(t, out, 0) +} diff --git a/agent/consul/state/config_entry.go b/agent/consul/state/config_entry.go index fd5007c924..1b7ce5282f 100644 --- a/agent/consul/state/config_entry.go +++ b/agent/consul/state/config_entry.go @@ -214,6 +214,15 @@ func (s *Store) ensureConfigEntryTxn(tx *memdb.Txn, idx uint64, conf structs.Con return err // Err is already sufficiently decorated. } + // If the config entry is for terminating gateways we update the memdb table + // that associates gateways <-> services. + if conf.GetKind() == structs.TerminatingGateway { + err = s.updateTerminatingGatewayServices(tx, idx, conf, entMeta) + if err != nil { + return fmt.Errorf("failed to associate services to gateway: %v", err) + } + } + // Insert the config entry and update the index if err := s.insertConfigEntryWithTxn(tx, conf); err != nil { return fmt.Errorf("failed inserting config entry: %s", err) @@ -273,6 +282,17 @@ func (s *Store) DeleteConfigEntry(idx uint64, kind, name string, entMeta *struct return nil } + // If the config entry is for terminating gateways we delete entries from the memdb table + // that associates gateways <-> services. + if kind == structs.TerminatingGateway { + if _, err := tx.DeleteAll(terminatingGatewayServicesTableName, "gateway", structs.NewServiceID(name, entMeta)); err != nil { + return fmt.Errorf("failed to truncate gateway services table: %v", err) + } + if err := indexUpdateMaxTxn(tx, idx, terminatingGatewayServicesTableName); err != nil { + return fmt.Errorf("failed updating terminating-gateway-services index: %v", err) + } + } + err = s.validateProposedConfigEntryInGraph( tx, idx, diff --git a/agent/structs/config_entry_gateways.go b/agent/structs/config_entry_gateways.go index db41f2a7c3..43d11e45db 100644 --- a/agent/structs/config_entry_gateways.go +++ b/agent/structs/config_entry_gateways.go @@ -277,3 +277,24 @@ func (e *TerminatingGatewayConfigEntry) GetEnterpriseMeta() *EnterpriseMeta { return &e.EnterpriseMeta } + +// GatewayService is used to associate gateways with their linked services. +type GatewayService struct { + Gateway ServiceID + Service ServiceID + GatewayKind ServiceKind + CAFile string + CertFile string + KeyFile string +} + +type GatewayServices []*GatewayService + +func (g *GatewayService) IsSame(o *GatewayService) bool { + return g.Gateway.Matches(&o.Gateway) && + g.Service.Matches(&o.Service) && + g.GatewayKind == o.GatewayKind && + g.CAFile == o.CAFile && + g.CertFile == o.CertFile && + g.KeyFile == o.KeyFile +} diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 0f4708e278..ea2bec108b 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -502,6 +502,7 @@ type ServiceSpecificRequest struct { Datacenter string NodeMetaFilters map[string]string ServiceName string + ServiceKind ServiceKind // DEPRECATED (singular-service-tag) - remove this when backwards RPC compat // with 1.2.x is not required. ServiceTag string @@ -1757,6 +1758,11 @@ type IndexedNodeDump struct { QueryMeta } +type IndexedGatewayServices struct { + Services GatewayServices + QueryMeta +} + // IndexedConfigEntries has its own encoding logic which differs from // ConfigEntryRequest as it has to send a slice of ConfigEntry. type IndexedConfigEntries struct { diff --git a/agent/structs/structs_test.go b/agent/structs/structs_test.go index 96cb8081b9..61d1113255 100644 --- a/agent/structs/structs_test.go +++ b/agent/structs/structs_test.go @@ -2132,6 +2132,58 @@ func TestSnapshotRequestResponse_MsgpackEncodeDecode(t *testing.T) { } +func TestGatewayService_IsSame(t *testing.T) { + gateway := NewServiceID("gateway", nil) + svc := NewServiceID("web", nil) + kind := ServiceKindTerminatingGateway + ca := "ca.pem" + cert := "client.pem" + key := "tls.key" + + g := &GatewayService{ + Gateway: gateway, + Service: svc, + GatewayKind: kind, + CAFile: ca, + CertFile: cert, + KeyFile: key, + } + other := &GatewayService{ + Gateway: gateway, + Service: svc, + GatewayKind: kind, + CAFile: ca, + CertFile: cert, + KeyFile: key, + } + check := func(twiddle, restore func()) { + t.Helper() + if !g.IsSame(other) || !other.IsSame(g) { + t.Fatalf("should be the same") + } + + twiddle() + if g.IsSame(other) || other.IsSame(g) { + t.Fatalf("should be different, was %#v VS %#v", g, other) + } + + restore() + if !g.IsSame(other) || !other.IsSame(g) { + t.Fatalf("should be the same") + } + } + check(func() { other.Gateway = NewServiceID("other", nil) }, func() { other.Gateway = gateway }) + check(func() { other.Service = NewServiceID("other", nil) }, func() { other.Service = svc }) + check(func() { other.GatewayKind = ServiceKindIngressGateway }, func() { other.GatewayKind = kind }) + check(func() { other.CAFile = "/certs/cert.pem" }, func() { other.CAFile = ca }) + check(func() { other.CertFile = "/certs/cert.pem" }, func() { other.CertFile = cert }) + check(func() { other.KeyFile = "/certs/cert.pem" }, func() { other.KeyFile = key }) + + if !g.IsSame(other) { + t.Fatalf("should be equal, was %#v VS %#v", g, other) + } +} + func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) { t.Helper() if err == nil { diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index 0bffefc5d1..b9d95cc558 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -12,11 +12,10 @@ import ( "path/filepath" "testing" - "github.com/hashicorp/consul/testrpc" - "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/stretchr/testify/require" ) diff --git a/api/config_entry_gateways_test.go b/api/config_entry_gateways_test.go index 9db4fa496f..dcb20a10a2 100644 --- a/api/config_entry_gateways_test.go +++ b/api/config_entry_gateways_test.go @@ -189,7 +189,29 @@ func TestAPI_ConfigEntries_TerminatingGateway(t *testing.T) { require.NotEqual(t, 0, wm.RequestTime) require.True(t, written) - // update no cas + // re-setting should not yield an error + _, wm, err = configEntries.Set(terminating1, nil) + require.NoError(t, err) + require.NotNil(t, wm) + require.NotEqual(t, 0, wm.RequestTime) + + // web is associated with the other gateway, should get an error + terminating2.Services = []LinkedService{ + { + Name: "*", + CAFile: "/etc/certs/ca.crt", + CertFile: "/etc/certs/client.crt", + KeyFile: "/etc/certs/tls.key", + }, + { + Name: "web", + }, + } + _, wm, err = configEntries.Set(terminating2, nil) + require.Error(t, err, "service \"web\" is associated with a different gateway") + require.Nil(t, wm) + + // try again without web terminating2.Services = []LinkedService{ { Name: "*",