From 2144bd7fbde71c419342f80a09db659f0cc0d9ff Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Mon, 8 Apr 2019 12:05:51 -0500 Subject: [PATCH 1/5] acl: tokens can be created with an optional expiration time (#5353) --- agent/acl_endpoint.go | 30 +- agent/consul/acl.go | 11 +- agent/consul/acl_endpoint.go | 57 +- agent/consul/acl_endpoint_legacy.go | 15 +- agent/consul/acl_endpoint_test.go | 669 ++++++++++++++++++-- agent/consul/acl_replication.go | 3 +- agent/consul/acl_replication_legacy.go | 7 +- agent/consul/acl_server.go | 8 +- agent/consul/acl_token_exp.go | 144 +++++ agent/consul/acl_token_exp_test.go | 219 +++++++ agent/consul/config.go | 12 + agent/consul/fsm/commands_oss.go | 3 +- agent/consul/fsm/snapshot_oss.go | 2 + agent/consul/leader.go | 17 +- agent/consul/server.go | 6 + agent/consul/state/acl.go | 123 +++- agent/consul/state/acl_test.go | 262 ++++++-- agent/structs/acl.go | 88 ++- agent/structs/acl_test.go | 2 +- api/acl.go | 39 +- command/acl/acl_helpers.go | 38 +- command/acl/token/clone/token_clone_test.go | 10 +- command/acl/token/create/token_create.go | 23 +- command/acl/token/update/token_update.go | 6 +- 24 files changed, 1582 insertions(+), 212 deletions(-) create mode 100644 agent/consul/acl_token_exp.go create mode 100644 agent/consul/acl_token_exp_test.go diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index 64eaad0de8..d34690b328 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -268,15 +268,35 @@ func (s *HTTPServer) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request return s.ACLPolicyWrite(resp, req, "") } -// fixCreateTimeAndHash is used to help in decoding the CreateTime and Hash +// fixTimeAndHashFields is used to help in decoding the ExpirationTTL, ExpirationTime, CreateTime, and Hash // attributes from the ACL Token/Policy create/update requests. It is needed // to help mapstructure decode things properly when decodeBody is used. -func fixCreateTimeAndHash(raw interface{}) error { +func fixTimeAndHashFields(raw interface{}) error { rawMap, ok := raw.(map[string]interface{}) if !ok { return nil } + if val, ok := rawMap["ExpirationTTL"]; ok { + if sval, ok := val.(string); ok { + d, err := time.ParseDuration(sval) + if err != nil { + return err + } + rawMap["ExpirationTTL"] = d + } + } + + if val, ok := rawMap["ExpirationTime"]; ok { + if sval, ok := val.(string); ok { + t, err := time.Parse(time.RFC3339, sval) + if err != nil { + return err + } + rawMap["ExpirationTime"] = t + } + } + if val, ok := rawMap["CreateTime"]; ok { if sval, ok := val.(string); ok { t, err := time.Parse(time.RFC3339, sval) @@ -301,7 +321,7 @@ func (s *HTTPServer) ACLPolicyWrite(resp http.ResponseWriter, req *http.Request, } s.parseToken(req, &args.Token) - if err := decodeBody(req, &args.Policy, fixCreateTimeAndHash); err != nil { + if err := decodeBody(req, &args.Policy, fixTimeAndHashFields); err != nil { return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)} } @@ -472,7 +492,7 @@ func (s *HTTPServer) ACLTokenSet(resp http.ResponseWriter, req *http.Request, to } s.parseToken(req, &args.Token) - if err := decodeBody(req, &args.ACLToken, fixCreateTimeAndHash); err != nil { + if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil { return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} } @@ -513,7 +533,7 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request, Datacenter: s.agent.config.Datacenter, } - if err := decodeBody(req, &args.ACLToken, fixCreateTimeAndHash); err != nil && err.Error() != "EOF" { + if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil && err.Error() != "EOF" { return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} } s.parseToken(req, &args.Token) diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 7553291fef..4f1061e5d1 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -31,9 +31,16 @@ const ( // with all tokens in it. aclUpgradeBatchSize = 128 - // aclUpgradeRateLimit is the number of batch upgrade requests per second. + // aclUpgradeRateLimit is the number of batch upgrade requests per second allowed. aclUpgradeRateLimit rate.Limit = 1.0 + // aclTokenReapingRateLimit is the number of batch token reaping requests per second allowed. + aclTokenReapingRateLimit rate.Limit = 1.0 + + // aclTokenReapingBurst is the number of batch token reaping requests per second + // that can burst after a period of idleness. + aclTokenReapingBurst = 5 + // aclBatchDeleteSize is the number of deletions to send in a single batch operation. 4096 should produce a batch that is <150KB // in size but should be sufficiently large to handle 1 replication round in a single batch aclBatchDeleteSize = 4096 @@ -608,6 +615,8 @@ func (r *ACLResolver) resolveTokenToIdentityAndPolicies(token string) (structs.A return nil, nil, err } else if identity == nil { return nil, nil, acl.ErrNotFound + } else if identity.IsExpired(time.Now()) { + return nil, nil, acl.ErrNotFound } lastIdentity = identity diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index 5f9738c9cc..c19ff29764 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -221,6 +221,10 @@ func (a *ACL) TokenRead(args *structs.ACLTokenGetRequest, reply *structs.ACLToke index, token, err = state.ACLTokenGetBySecret(ws, args.TokenID) } + if token != nil && token.IsExpired(time.Now()) { + token = nil + } + if err != nil { return err } @@ -256,7 +260,7 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok _, token, err := a.srv.fsm.State().ACLTokenGetByAccessor(nil, args.ACLToken.AccessorID) if err != nil { return err - } else if token == nil { + } else if token == nil || token.IsExpired(time.Now()) { return acl.ErrNotFound } else if !a.srv.InACLDatacenter() && !token.Local { // global token writes must be forwarded to the primary DC @@ -271,9 +275,10 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok cloneReq := structs.ACLTokenSetRequest{ Datacenter: args.Datacenter, ACLToken: structs.ACLToken{ - Policies: token.Policies, - Local: token.Local, - Description: token.Description, + Policies: token.Policies, + Local: token.Local, + Description: token.Description, + ExpirationTime: token.ExpirationTime, }, WriteRequest: args.WriteRequest, } @@ -342,6 +347,34 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. } token.CreateTime = time.Now() + + // Ensure an ExpirationTTL is valid if provided. + if token.ExpirationTTL != 0 { + if token.ExpirationTTL < 0 { + return fmt.Errorf("Token Expiration TTL '%s' should be > 0", token.ExpirationTTL) + } + if !token.ExpirationTime.IsZero() { + return fmt.Errorf("Token Expiration TTL and Expiration Time cannot both be set") + } + + token.ExpirationTime = token.CreateTime.Add(token.ExpirationTTL) + token.ExpirationTTL = 0 + } + + if !token.ExpirationTime.IsZero() { + if token.CreateTime.After(token.ExpirationTime) { + return fmt.Errorf("ExpirationTime cannot be before CreateTime") + } + + expiresIn := token.ExpirationTime.Sub(token.CreateTime) + if expiresIn > a.srv.config.ACLTokenMaxExpirationTTL { + return fmt.Errorf("ExpirationTime cannot be more than %s in the future (was %s)", + a.srv.config.ACLTokenMaxExpirationTTL, expiresIn) + } else if expiresIn < a.srv.config.ACLTokenMinExpirationTTL { + return fmt.Errorf("ExpirationTime cannot be less than %s in the future (was %s)", + a.srv.config.ACLTokenMinExpirationTTL, expiresIn) + } + } } else { // Token Update if _, err := uuid.ParseUUID(token.AccessorID); err != nil { @@ -365,7 +398,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. if err != nil { return fmt.Errorf("Failed to lookup the acl token %q: %v", token.AccessorID, err) } - if existing == nil { + if existing == nil || existing.IsExpired(time.Now()) { return fmt.Errorf("Cannot find token %q", token.AccessorID) } if token.SecretID == "" { @@ -379,6 +412,10 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID) } + if token.ExpirationTTL != 0 || !token.ExpirationTime.Equal(existing.ExpirationTime) { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } + if upgrade { token.CreateTime = time.Now() } else { @@ -440,6 +477,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return respErr } + // Don't check expiration times here as it doesn't really matter. if _, updatedToken, err := a.srv.fsm.State().ACLTokenGetByAccessor(nil, token.AccessorID); err == nil && token != nil { *reply = *updatedToken } else { @@ -490,6 +528,8 @@ func (a *ACL) TokenDelete(args *structs.ACLTokenDeleteRequest, reply *string) er return fmt.Errorf("Deletion of the request's authorization token is not permitted") } + // No need to check expiration time because it's being deleted. + if !a.srv.InACLDatacenter() && !token.Local { args.Datacenter = a.srv.config.ACLDatacenter return a.srv.forwardDC("ACL.TokenDelete", a.srv.config.ACLDatacenter, args, reply) @@ -553,8 +593,13 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok return err } + now := time.Now() + stubs := make([]*structs.ACLTokenListStub, 0, len(tokens)) for _, token := range tokens { + if token.IsExpired(now) { + continue + } stubs = append(stubs, token.Stub()) } reply.Index, reply.Tokens = index, stubs @@ -589,6 +634,8 @@ func (a *ACL) TokenBatchRead(args *structs.ACLTokenBatchGetRequest, reply *struc return err } + // This RPC is used for replication, so don't filter out expired tokens here. + a.srv.filterACLWithAuthorizer(rule, &tokens) reply.Index, reply.Tokens = index, tokens diff --git a/agent/consul/acl_endpoint_legacy.go b/agent/consul/acl_endpoint_legacy.go index 48867fdb3a..ad264e2dcb 100644 --- a/agent/consul/acl_endpoint_legacy.go +++ b/agent/consul/acl_endpoint_legacy.go @@ -93,8 +93,9 @@ func aclApplyInternal(srv *Server, args *structs.ACLRequest, reply *string) erro return fmt.Errorf("Invalid ACL Type") } + // No need to check expiration times as those did not exist in legacy tokens. _, existing, _ := srv.fsm.State().ACLTokenGetBySecret(nil, args.ACL.ID) - if existing != nil && len(existing.Policies) > 0 { + if existing != nil && existing.UsesNonLegacyFields() { return fmt.Errorf("Cannot use legacy endpoint to modify a non-legacy token") } @@ -210,8 +211,13 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, return err } - // converting an ACLToken to an ACL will return nil and an error + // Converting an ACLToken to an ACL will return nil and an error // (which we ignore) when it is unconvertible. + // + // This also means we won't have to check expiration times since + // any legacy tokens never had expiration times and no non-legacy + // tokens can be converted. + var acl *structs.ACL if token != nil { acl, _ = token.Convert() @@ -254,8 +260,13 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return err } + now := time.Now() + var acls structs.ACLs for _, token := range tokens { + if token.IsExpired(now) { + continue + } if acl, err := token.Convert(); err == nil && acl != nil { acls = append(acls, acl) } diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index c88ba1563f..0ad3c5d3bb 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -630,6 +630,8 @@ func TestACLEndpoint_TokenRead(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -638,14 +640,12 @@ func TestACLEndpoint_TokenRead(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - token, err := upsertTestToken(codec, "root", "dc1") - if err != nil { - t.Fatalf("err: %v", err) - } - acl := ACL{srv: s1} t.Run("exists and matches what we created", func(t *testing.T) { + token, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) + req := structs.ACLTokenGetRequest{ Datacenter: "dc1", TokenID: token.AccessorID, @@ -655,7 +655,7 @@ func TestACLEndpoint_TokenRead(t *testing.T) { resp := structs.ACLTokenResponse{} - err := acl.TokenRead(&req, &resp) + err = acl.TokenRead(&req, &resp) require.NoError(t, err) if !reflect.DeepEqual(resp.Token, token) { @@ -663,6 +663,44 @@ func TestACLEndpoint_TokenRead(t *testing.T) { } }) + t.Run("expired tokens are filtered", func(t *testing.T) { + // insert a token that will expire + token, err := upsertTestToken(codec, "root", "dc1", func(t *structs.ACLToken) { + t.ExpirationTTL = 20 * time.Millisecond + }) + require.NoError(t, err) + + t.Run("readable until expiration", func(t *testing.T) { + req := structs.ACLTokenGetRequest{ + Datacenter: "dc1", + TokenID: token.AccessorID, + TokenIDType: structs.ACLTokenAccessor, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenResponse{} + + require.NoError(t, acl.TokenRead(&req, &resp)) + require.Equal(t, token, resp.Token) + }) + + time.Sleep(50 * time.Millisecond) + + t.Run("not returned when expired", func(t *testing.T) { + req := structs.ACLTokenGetRequest{ + Datacenter: "dc1", + TokenID: token.AccessorID, + TokenIDType: structs.ACLTokenAccessor, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenResponse{} + + require.NoError(t, acl.TokenRead(&req, &resp)) + require.Nil(t, resp.Token) + }) + }) + t.Run("nil when token does not exist", func(t *testing.T) { fakeID, err := uuid.GenerateUUID() require.NoError(t, err) @@ -704,6 +742,8 @@ func TestACLEndpoint_TokenClone(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -712,28 +752,52 @@ func TestACLEndpoint_TokenClone(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - acl := ACL{srv: s1} + endpoint := ACL{srv: s1} - req := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, - WriteRequest: structs.WriteRequest{Token: "root"}, - } + t.Run("normal", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, + WriteRequest: structs.WriteRequest{Token: "root"}, + } - t2 := structs.ACLToken{} + t2 := structs.ACLToken{} - err = acl.TokenClone(&req, &t2) - require.NoError(t, err) + err = endpoint.TokenClone(&req, &t2) + require.NoError(t, err) - require.Equal(t, t1.Description, t2.Description) - require.Equal(t, t1.Policies, t2.Policies) - require.Equal(t, t1.Rules, t2.Rules) - require.Equal(t, t1.Local, t2.Local) - require.NotEqual(t, t1.AccessorID, t2.AccessorID) - require.NotEqual(t, t1.SecretID, t2.SecretID) + require.Equal(t, t1.Description, t2.Description) + require.Equal(t, t1.Policies, t2.Policies) + require.Equal(t, t1.Rules, t2.Rules) + require.Equal(t, t1.Local, t2.Local) + require.NotEqual(t, t1.AccessorID, t2.AccessorID) + require.NotEqual(t, t1.SecretID, t2.SecretID) + }) + + t.Run("can't clone expired token", func(t *testing.T) { + // insert a token that will expire + t1, err := upsertTestToken(codec, "root", "dc1", func(t *structs.ACLToken) { + t.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(30 * time.Millisecond) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + t2 := structs.ACLToken{} + + err = endpoint.TokenClone(&req, &t2) + require.Error(t, err) + require.Equal(t, acl.ErrNotFound, err) + }) } func TestACLEndpoint_TokenSet(t *testing.T) { @@ -743,6 +807,8 @@ func TestACLEndpoint_TokenSet(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -752,6 +818,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") acl := ACL{srv: s1} + var tokenID string t.Run("Create it", func(t *testing.T) { @@ -806,6 +873,262 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Equal(t, token.Description, "new-description") require.Equal(t, token.AccessorID, resp.AccessorID) }) + + t.Run("Create it using Policies linked by id and name", func(t *testing.T) { + policy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + policy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy1.ID, + }, + structs.ACLTokenPolicyLink{ + Name: policy2.Name, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Delete both policies to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy1.ID)) + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy2.ID)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + + require.Len(t, token.Policies, 0) + }) + + for _, test := range []struct { + name string + offset time.Duration + errString string + errStringTTL string + }{ + {"before create time", -5 * time.Minute, "ExpirationTime cannot be before CreateTime", ""}, + {"too soon", 1 * time.Millisecond, "ExpirationTime cannot be less than", "ExpirationTime cannot be less than"}, + {"too distant", 25 * time.Hour, "ExpirationTime cannot be more than", "ExpirationTime cannot be more than"}, + } { + t.Run("Create it with an expiration time that is "+test.name, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: time.Now().Add(test.offset), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.errString != "" { + requireErrorContains(t, err, test.errString) + } else { + require.NotNil(t, err) + } + }) + + t.Run("Create it with an expiration TTL that is "+test.name, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTTL: test.offset, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.errString != "" { + requireErrorContains(t, err, test.errStringTTL) + } else { + require.NotNil(t, err) + } + }) + } + + t.Run("Create it with expiration time AND expiration TTL set (error)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: time.Now().Add(4 * time.Second), + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Expiration TTL and Expiration Time cannot both be set") + }) + + t.Run("Create it with expiration time using TTLs", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + expectExpTime := resp.CreateTime.Add(4 * time.Second) + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, expectExpTime, resp.ExpirationTime) + + tokenID = token.AccessorID + }) + + var expTime time.Time + t.Run("Create it with expiration time", func(t *testing.T) { + expTime = time.Now().Add(4 * time.Second) + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: expTime, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, expTime, resp.ExpirationTime) + + tokenID = token.AccessorID + }) + + // do not insert another test at this point: these tests need to be serial + + t.Run("Update expiration time is not allowed", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description", + AccessorID: tokenID, + ExpirationTime: expTime.Add(-1 * time.Second), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot change expiration time") + }) + + // do not insert another test at this point: these tests need to be serial + + t.Run("Update anything except expiration time is ok", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description", + AccessorID: tokenID, + ExpirationTime: expTime, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "new-description") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, expTime, resp.ExpirationTime) + }) + + t.Run("cannot update a token that is past its expiration time", func(t *testing.T) { + // create a token that will expire + expiringToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(20 * time.Millisecond) // now 'expiringToken' is expired + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description", + AccessorID: expiringToken.AccessorID, + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot find token") + }) } func TestACLEndpoint_TokenSet_anon(t *testing.T) { @@ -857,6 +1180,8 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -869,6 +1194,8 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second // token replication is required to test deleting non-local tokens in secondary dc c.ACLTokenReplication = true }) @@ -885,12 +1212,74 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { // Try to join joinWAN(t, s2, s1) - existingToken, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} acl2 := ACL{srv: s2} + existingToken, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) + + t.Run("deletes a token that has an expiration time in the future", func(t *testing.T) { + // create a token that will expire + testToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 4 * time.Second + }) + require.NoError(t, err) + + // Make sure the token is listable + tokenResp, err := retrieveTestToken(codec, "root", "dc1", testToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, tokenResp.Token) + + // Now try to delete it (this should work). + req := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: testToken.AccessorID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.TokenDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the token is gone + tokenResp, err = retrieveTestToken(codec, "root", "dc1", testToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + }) + + t.Run("deletes a token that is past its expiration time", func(t *testing.T) { + // create a token that will expire + expiringToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(20 * time.Millisecond) // now 'expiringToken' is expired + + // Make sure the token is not listable (filtered due to expiry) + tokenResp, err := retrieveTestToken(codec, "root", "dc1", expiringToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + + // Now try to delete it (this should work). + req := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: expiringToken.AccessorID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.TokenDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the token is still gone (this time it's actually gone) + tokenResp, err = retrieveTestToken(codec, "root", "dc1", expiringToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + }) + t.Run("deletes a token", func(t *testing.T) { req := structs.ACLTokenDeleteRequest{ Datacenter: "dc1", @@ -919,7 +1308,7 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { var out structs.ACLTokenResponse - err := msgpackrpc.CallWithCodec(codec, "ACL.TokenRead", &readReq, &out) + err := acl.TokenRead(&readReq, &out) require.NoError(t, err) @@ -1019,6 +1408,8 @@ func TestACLEndpoint_TokenList(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -1027,31 +1418,74 @@ func TestACLEndpoint_TokenList(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - - t2, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} - req := structs.ACLTokenListRequest{ - Datacenter: "dc1", - QueryOptions: structs.QueryOptions{Token: "root"}, - } - - resp := structs.ACLTokenListResponse{} - - err = acl.TokenList(&req, &resp) + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - tokens := []string{t1.AccessorID, t2.AccessorID} - var retrievedTokens []string + t2, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.Subset(t, retrievedTokens, tokens) + t3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") + require.NoError(t, err) + + t.Run("normal", func(t *testing.T) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenListResponse{} + + err = acl.TokenList(&req, &resp) + require.NoError(t, err) + + tokens := []string{ + masterTokenAccessorID, + structs.ACLTokenAnonymousID, + t1.AccessorID, + t2.AccessorID, + t3.AccessorID, + } + + var retrievedTokens []string + for _, v := range resp.Tokens { + retrievedTokens = append(retrievedTokens, v.AccessorID) + } + require.ElementsMatch(t, retrievedTokens, tokens) + }) + + time.Sleep(20 * time.Millisecond) // now 't3' is expired + + t.Run("filter expired", func(t *testing.T) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenListResponse{} + + err = acl.TokenList(&req, &resp) + require.NoError(t, err) + + tokens := []string{ + masterTokenAccessorID, + structs.ACLTokenAnonymousID, + t1.AccessorID, + t2.AccessorID, + } + + var retrievedTokens []string + for _, v := range resp.Tokens { + retrievedTokens = append(retrievedTokens, v.AccessorID) + } + require.ElementsMatch(t, retrievedTokens, tokens) + }) } func TestACLEndpoint_TokenBatchRead(t *testing.T) { @@ -1061,6 +1495,8 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -1069,32 +1505,64 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - - t2, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} - tokens := []string{t1.AccessorID, t2.AccessorID} - req := structs.ACLTokenBatchGetRequest{ - Datacenter: "dc1", - AccessorIDs: tokens, - QueryOptions: structs.QueryOptions{Token: "root"}, - } - - resp := structs.ACLTokenBatchResponse{} - - err = acl.TokenBatchRead(&req, &resp) + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - var retrievedTokens []string + t2, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + t3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 4 * time.Second + }) + require.NoError(t, err) + + t.Run("normal", func(t *testing.T) { + tokens := []string{t1.AccessorID, t2.AccessorID, t3.AccessorID} + + req := structs.ACLTokenBatchGetRequest{ + Datacenter: "dc1", + AccessorIDs: tokens, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenBatchResponse{} + + err = acl.TokenBatchRead(&req, &resp) + require.NoError(t, err) + + var retrievedTokens []string + + for _, v := range resp.Tokens { + retrievedTokens = append(retrievedTokens, v.AccessorID) + } + require.EqualValues(t, retrievedTokens, tokens) + }) + + time.Sleep(20 * time.Millisecond) // now 't3' is expired + + t.Run("returns expired tokens", func(t *testing.T) { + tokens := []string{t1.AccessorID, t2.AccessorID, t3.AccessorID} + + req := structs.ACLTokenBatchGetRequest{ + Datacenter: "dc1", + AccessorIDs: tokens, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenBatchResponse{} + + err = acl.TokenBatchRead(&req, &resp) + require.NoError(t, err) + + var retrievedTokens []string + + for _, v := range resp.Tokens { + retrievedTokens = append(retrievedTokens, v.AccessorID) + } + require.EqualValues(t, retrievedTokens, tokens) + }) } func TestACLEndpoint_PolicyRead(t *testing.T) { @@ -1419,13 +1887,17 @@ func TestACLEndpoint_PolicyList(t *testing.T) { err = acl.PolicyList(&req, &resp) require.NoError(t, err) - policies := []string{p1.ID, p2.ID} + policies := []string{ + structs.ACLPolicyGlobalManagementID, + p1.ID, + p2.ID, + } var retrievedPolicies []string for _, v := range resp.Policies { retrievedPolicies = append(retrievedPolicies, v.ID) } - require.Subset(t, retrievedPolicies, policies) + require.ElementsMatch(t, retrievedPolicies, policies) } func TestACLEndpoint_PolicyResolve(t *testing.T) { @@ -1491,7 +1963,8 @@ func TestACLEndpoint_PolicyResolve(t *testing.T) { } // upsertTestToken creates a token for testing purposes -func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLToken, error) { +func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter string, + tokenModificationFn func(token *structs.ACLToken)) (*structs.ACLToken, error) { arg := structs.ACLTokenSetRequest{ Datacenter: datacenter, ACLToken: structs.ACLToken{ @@ -1502,6 +1975,10 @@ func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter strin WriteRequest: structs.WriteRequest{Token: masterToken}, } + if tokenModificationFn != nil { + tokenModificationFn(&arg.ACLToken) + } + var out structs.ACLToken err := msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &out) @@ -1517,6 +1994,29 @@ func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter strin return &out, nil } +func retrieveTestTokenAccessorForSecret(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (string, error) { + arg := structs.ACLTokenGetRequest{ + TokenID: "root", + TokenIDType: structs.ACLTokenSecret, + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + var out structs.ACLTokenResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.TokenRead", &arg, &out) + + if err != nil { + return "", err + } + + if out.Token == nil { + return "", nil + } + + return out.Token.AccessorID, nil +} + // retrieveTestToken returns a policy for testing purposes func retrieveTestToken(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (*structs.ACLTokenResponse, error) { arg := structs.ACLTokenGetRequest{ @@ -1537,6 +2037,18 @@ func retrieveTestToken(codec rpc.ClientCodec, masterToken string, datacenter str return &out, nil } +func deleteTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter string, policyID string) error { + arg := structs.ACLPolicyDeleteRequest{ + Datacenter: datacenter, + PolicyID: policyID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.PolicyDelete", &arg, &ignored) + return err +} + // upsertTestPolicy creates a policy for testing purposes func upsertTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLPolicy, error) { // Make sure test policies can't collide @@ -1586,3 +2098,20 @@ func retrieveTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter st return &out, nil } + +func requireTimeEquals(t *testing.T, expect, got time.Time) { + t.Helper() + if !expect.Equal(got) { + t.Fatalf("expected=%q != got=%q", expect, got) + } +} + +func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) { + t.Helper() + if err == nil { + t.Fatal("An error is expected but got nil.") + } + if !strings.Contains(err.Error(), expectedErrorMessage) { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/agent/consul/acl_replication.go b/agent/consul/acl_replication.go index d691895b67..047705a60f 100644 --- a/agent/consul/acl_replication.go +++ b/agent/consul/acl_replication.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" ) @@ -468,6 +468,7 @@ func (s *Server) replicateACLTokens(lastRemoteIndex uint64, ctx context.Context) if err != nil { return 0, false, fmt.Errorf("failed to retrieve local ACL tokens: %v", err) } + // Do not filter by expiration times. Wait until the tokens are explicitly deleted. // If the remote index ever goes backwards, it's a good indication that // the remote side was rebuilt and we should do a full sync since we diff --git a/agent/consul/acl_replication_legacy.go b/agent/consul/acl_replication_legacy.go index 182e206208..3fc2ce5eb3 100644 --- a/agent/consul/acl_replication_legacy.go +++ b/agent/consul/acl_replication_legacy.go @@ -6,7 +6,7 @@ import ( "sort" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" ) @@ -143,8 +143,13 @@ func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) { return nil, err } + now := time.Now() + var acls structs.ACLs for _, token := range local { + if token.IsExpired(now) { + continue + } if acl, err := token.Convert(); err == nil && acl != nil { acls = append(acls, acl) } diff --git a/agent/consul/acl_server.go b/agent/consul/acl_server.go index 1eaf474c2b..e8213af4fc 100644 --- a/agent/consul/acl_server.go +++ b/agent/consul/acl_server.go @@ -2,6 +2,7 @@ package consul import ( "sync/atomic" + "time" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" @@ -29,6 +30,11 @@ var serverACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ func (s *Server) checkTokenUUID(id string) (bool, error) { state := s.fsm.State() + + // We won't check expiration times here. If we generate a UUID that matches + // a token that hasn't been reaped yet, then we won't be able to insert the + // new token due to a collision. + if _, token, err := state.ACLTokenGetByAccessor(nil, id); err != nil { return false, err } else if token != nil { @@ -145,7 +151,7 @@ func (s *Server) ResolveIdentityFromToken(token string) (bool, structs.ACLIdenti index, aclToken, err := s.fsm.State().ACLTokenGetBySecret(nil, token) if err != nil { return true, nil, err - } else if aclToken != nil { + } else if aclToken != nil && !aclToken.IsExpired(time.Now()) { return true, aclToken, nil } diff --git a/agent/consul/acl_token_exp.go b/agent/consul/acl_token_exp.go new file mode 100644 index 0000000000..f1d5cb52d5 --- /dev/null +++ b/agent/consul/acl_token_exp.go @@ -0,0 +1,144 @@ +package consul + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/consul/agent/structs" + "golang.org/x/time/rate" +) + +func (s *Server) startACLTokenReaping() { + s.aclTokenReapLock.Lock() + defer s.aclTokenReapLock.Unlock() + + if s.aclTokenReapEnabled { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + s.aclTokenReapCancel = cancel + + // Do a quick check for config settings that would imply the goroutine + // below will just spin forever. + // + // We can only check the config settings here that cannot change without a + // restart, so we omit the check for a non-empty replication token as that + // can be changed at runtime. + if !s.InACLDatacenter() && !s.config.ACLTokenReplication { + return + } + + go func() { + limiter := rate.NewLimiter(aclTokenReapingRateLimit, aclTokenReapingBurst) + + for { + if err := limiter.Wait(ctx); err != nil { + return + } + + if s.LocalTokensEnabled() { + if _, err := s.reapExpiredLocalACLTokens(); err != nil { + s.logger.Printf("[ERR] acl: error reaping expired local ACL tokens: %v", err) + } + } + if s.InACLDatacenter() { + if _, err := s.reapExpiredGlobalACLTokens(); err != nil { + s.logger.Printf("[ERR] acl: error reaping expired global ACL tokens: %v", err) + } + } + } + }() + + s.aclTokenReapEnabled = true +} + +func (s *Server) stopACLTokenReaping() { + s.aclTokenReapLock.Lock() + defer s.aclTokenReapLock.Unlock() + + if !s.aclTokenReapEnabled { + return + } + + s.aclTokenReapCancel() + s.aclTokenReapCancel = nil + s.aclTokenReapEnabled = false +} + +func (s *Server) reapExpiredGlobalACLTokens() (int, error) { + return s.reapExpiredACLTokens(false, true) +} +func (s *Server) reapExpiredLocalACLTokens() (int, error) { + return s.reapExpiredACLTokens(true, false) +} +func (s *Server) reapExpiredACLTokens(local, global bool) (int, error) { + if !s.ACLsEnabled() { + return 0, nil + } + if s.UseLegacyACLs() { + return 0, nil + } + if local == global { + return 0, fmt.Errorf("cannot reap both local and global tokens in the same request") + } + + locality := localityName(local) + + minExpiredTime, err := s.fsm.State().ACLTokenMinExpirationTime(local) + if err != nil { + return 0, err + } + + now := time.Now() + + if minExpiredTime.After(now) { + return 0, nil // nothing to do + } + + tokens, _, err := s.fsm.State().ACLTokenListExpired(local, now, aclBatchDeleteSize) + if err != nil { + return 0, err + } + + if len(tokens) == 0 { + return 0, nil + } + + var ( + secretIDs []string + req structs.ACLTokenBatchDeleteRequest + ) + for _, token := range tokens { + if token.Local != local { + return 0, fmt.Errorf("expired index for local=%v returned a mismatched token with local=%v: %s", local, token.Local, token.AccessorID) + } + req.TokenIDs = append(req.TokenIDs, token.AccessorID) + secretIDs = append(secretIDs, token.SecretID) + } + + s.logger.Printf("[INFO] acl: deleting %d expired %s tokens", len(req.TokenIDs), locality) + resp, err := s.raftApply(structs.ACLTokenDeleteRequestType, &req) + if err != nil { + return 0, fmt.Errorf("Failed to apply token expiration deletions: %v", err) + } + + // Purge the identities from the cache + for _, secretID := range secretIDs { + s.acls.cache.RemoveIdentity(secretID) + } + + if respErr, ok := resp.(error); ok { + return 0, respErr + } + + return len(req.TokenIDs), nil +} + +func localityName(local bool) string { + if local { + return "local" + } + return "global" +} diff --git a/agent/consul/acl_token_exp_test.go b/agent/consul/acl_token_exp_test.go new file mode 100644 index 0000000000..20ae878afc --- /dev/null +++ b/agent/consul/acl_token_exp_test.go @@ -0,0 +1,219 @@ +package consul + +import ( + "os" + "testing" + "time" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/testrpc" + "github.com/stretchr/testify/require" +) + +func TestACLTokenReap_Primary(t *testing.T) { + t.Parallel() + + t.Run("global", func(t *testing.T) { + t.Parallel() + testACLTokenReap_Primary(t, false, true) + }) + t.Run("local", func(t *testing.T) { + t.Parallel() + testACLTokenReap_Primary(t, true, false) + }) +} + +func testACLTokenReap_Primary(t *testing.T, local, global bool) { + // ------------------------------------------- + // A word of caution when testing reapExpiredACLTokens(): + // + // The underlying memdb index used for reaping has a minimum granularity of + // 1 second as it delegates to `time.Unix()`. This test will have to be + // deliberately slow to allow for necessary sleeps. If you try to make it + // operate faster (using expiration ttls of milliseconds) it will be flaky. + // ------------------------------------------- + + t.Helper() + require.NotEqual(t, local, global) + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 8 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + codec := rpcClient(t, s1) + defer codec.Close() + + acl := ACL{s1} + + masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") + require.NoError(t, err) + + listTokens := func() (localTokens, globalTokens []string, err error) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + var res structs.ACLTokenListResponse + err = acl.TokenList(&req, &res) + if err != nil { + return nil, nil, err + } + + for _, tok := range res.Tokens { + if tok.Local { + localTokens = append(localTokens, tok.AccessorID) + } else { + globalTokens = append(globalTokens, tok.AccessorID) + } + } + + return localTokens, globalTokens, nil + } + + requireTokenMatch := func(t *testing.T, expect []string) { + t.Helper() + + var expectLocal, expectGlobal []string + // The master token and the anonymous token are always going to be + // present and global. + expectGlobal = append(expectGlobal, masterTokenAccessorID) + expectGlobal = append(expectGlobal, structs.ACLTokenAnonymousID) + + if local { + expectLocal = append(expectLocal, expect...) + } else { + expectGlobal = append(expectGlobal, expect...) + } + + localTokens, globalTokens, err := listTokens() + require.NoError(t, err) + require.ElementsMatch(t, expectLocal, localTokens) + require.ElementsMatch(t, expectGlobal, globalTokens) + } + + // initial sanity check + requireTokenMatch(t, []string{}) + + t.Run("no tokens", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{}) + }) + + // 2 normal + token1, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + token2, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + }) + + t.Run("only normal tokens", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + }) + }) + + // 2 expiring + token3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 1 * time.Second + token.Local = local + }) + require.NoError(t, err) + token4, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 5 * time.Second + token.Local = local + }) + require.NoError(t, err) + + // 2 more normal + token5, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + token6, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + + t.Run("mixed but nothing expired yet", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) + + time.Sleep(token3.ExpirationTime.Sub(time.Now()) + 10*time.Millisecond) + + t.Run("one should be reaped", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 1, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + // token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) + + time.Sleep(token4.ExpirationTime.Sub(time.Now()) + 10*time.Millisecond) + + t.Run("two should be reaped", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 1, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + // token3.AccessorID, + // token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) +} diff --git a/agent/consul/config.go b/agent/consul/config.go index baf9e63a29..cce70f392b 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -313,6 +313,16 @@ type Config struct { // Minimum Session TTL SessionTTLMin time.Duration + // maxTokenExpirationDuration is the maximum difference allowed between + // ACLToken CreateTime and ExpirationTime values if ExpirationTime is set + // on a token. + ACLTokenMaxExpirationTTL time.Duration + + // ACLTokenMinExpirationTTL is the minimum difference allowed between + // ACLToken CreateTime and ExpirationTime values if ExpirationTime is set + // on a token. + ACLTokenMinExpirationTTL time.Duration + // ServerUp callback can be used to trigger a notification that // a Consul server is now up and known about. ServerUp func() @@ -473,6 +483,8 @@ func DefaultConfig() *Config { TombstoneTTL: 15 * time.Minute, TombstoneTTLGranularity: 30 * time.Second, SessionTTLMin: 10 * time.Second, + ACLTokenMinExpirationTTL: 1 * time.Minute, + ACLTokenMaxExpirationTTL: 24 * time.Hour, // These are tuned to provide a total throughput of 128 updates // per second. If you update these, you should update the client- diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 1e3651cba4..f9d75e83c8 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -4,7 +4,7 @@ import ( "fmt" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" ) @@ -165,6 +165,7 @@ func (c *FSM) applyACLOperation(buf []byte, index uint64) interface{} { return err } + // No need to check expiration times as those did not exist in legacy tokens. if _, token, err := c.state.ACLTokenGetBySecret(nil, req.ACL.ID); err != nil { return err } else { diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 0c7713753e..4195b8c422 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -178,6 +178,8 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, return err } + // Don't check expiration times. Wait for explicit deletions. + for token := tokens.Next(); token != nil; token = tokens.Next() { if _, err := sink.Write([]byte{byte(structs.ACLTokenSetRequestType)}); err != nil { return err diff --git a/agent/consul/leader.go b/agent/consul/leader.go index 3264892095..d42fddd798 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/connect" ca "github.com/hashicorp/consul/agent/connect/ca" @@ -22,7 +22,7 @@ import ( "github.com/hashicorp/consul/types" memdb "github.com/hashicorp/go-memdb" uuid "github.com/hashicorp/go-uuid" - "github.com/hashicorp/go-version" + version "github.com/hashicorp/go-version" "github.com/hashicorp/raft" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" @@ -308,6 +308,8 @@ func (s *Server) revokeLeadership() error { s.setCAProvider(nil, nil) + s.stopACLTokenReaping() + s.stopACLUpgrade() s.resetConsistentReadReady() @@ -329,6 +331,7 @@ func (s *Server) initializeLegacyACL() error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { req := structs.ACLRequest{ Datacenter: authDC, @@ -352,6 +355,7 @@ func (s *Server) initializeLegacyACL() error { if err != nil { return fmt.Errorf("failed to get master token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { req := structs.ACLRequest{ Datacenter: authDC, @@ -482,6 +486,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get master token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { accessor, err := lib.GenerateUUID(s.checkTokenUUID) if err != nil { @@ -543,6 +548,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { // DEPRECATED (ACL-Legacy-Compat) - Don't need to query for previous "anonymous" token // check for legacy token that needs an upgrade @@ -550,6 +556,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. // the token upgrade routine will take care of upgrading the token if a legacy version exists if legacyToken == nil { @@ -572,6 +579,7 @@ func (s *Server) initializeACLs(upgrade bool) error { s.logger.Printf("[INFO] consul: Created ACL anonymous token from configuration") } } + // launch the upgrade go routine to generate accessors for everything s.startACLUpgrade() } else { if s.UseLegacyACLs() && !upgrade { @@ -588,7 +596,7 @@ func (s *Server) initializeACLs(upgrade bool) error { s.startACLReplication() } - // launch the upgrade go routine to generate accessors for everything + s.startACLTokenReaping() return nil } @@ -617,6 +625,7 @@ func (s *Server) startACLUpgrade() { if err != nil { s.logger.Printf("[WARN] acl: encountered an error while searching for tokens without accessor ids: %v", err) } + // No need to check expiration time here, as that only exists for v2 tokens. if len(tokens) == 0 { ws := memdb.NewWatchSet() @@ -797,10 +806,10 @@ func (s *Server) startACLReplication() { if s.config.ACLTokenReplication { replicationType = structs.ACLReplicateTokens - go func() { var failedAttempts uint limiter := rate.NewLimiter(rate.Limit(s.config.ACLReplicationRate), s.config.ACLReplicationBurst) + var lastRemoteIndex uint64 for { if err := limiter.Wait(ctx); err != nil { diff --git a/agent/consul/server.go b/agent/consul/server.go index f19a907455..b44b417816 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -109,6 +109,12 @@ type Server struct { aclReplicationLock sync.RWMutex aclReplicationEnabled bool + // aclTokenReapCancel is used to shut down the ACL Token expiration reap + // goroutine when we lose leadership. + aclTokenReapCancel context.CancelFunc + aclTokenReapLock sync.RWMutex + aclTokenReapEnabled bool + // DEPRECATED (ACL-Legacy-Compat) - only needed while we support both // useNewACLs is used to determine whether we can use new ACLs or not useNewACLs int32 diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 5c2e83d57f..36c0a03e46 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -1,10 +1,12 @@ package state import ( + "encoding/binary" "fmt" + "time" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/go-memdb" + memdb "github.com/hashicorp/go-memdb" ) type TokenPoliciesIndex struct { @@ -58,6 +60,54 @@ func (s *TokenPoliciesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) return val, nil } +type TokenExpirationIndex struct { + LocalFilter bool +} + +func (s *TokenExpirationIndex) encodeTime(t time.Time) []byte { + val := t.Unix() + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(val)) + return buf +} + +func (s *TokenExpirationIndex) FromObject(obj interface{}) (bool, []byte, error) { + token, ok := obj.(*structs.ACLToken) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLToken") + } + if s.LocalFilter != token.Local { + return false, nil, nil + } + if token.ExpirationTime.IsZero() { + return false, nil, nil + } + if token.ExpirationTime.Unix() < 0 { + return false, nil, fmt.Errorf("token expiration time cannot be before the unix epoch: %s", token.ExpirationTime) + } + + buf := s.encodeTime(token.ExpirationTime) + + return true, buf, nil +} + +func (s *TokenExpirationIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(time.Time) + if !ok { + return nil, fmt.Errorf("argument must be a time.Time: %#v", args[0]) + } + if arg.Unix() < 0 { + return nil, fmt.Errorf("argument must be a time.Time after the unix epoch: %s", args[0]) + } + + buf := s.encodeTime(arg) + + return buf, nil +} + func tokensTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ Name: "acl-tokens", @@ -100,6 +150,18 @@ func tokensTableSchema() *memdb.TableSchema { }, }, }, + "expires-global": { + Name: "expires-global", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: false}, + }, + "expires-local": { + Name: "expires-local", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: true}, + }, //DEPRECATED (ACL-Legacy-Compat) - This index is only needed while we support upgrading v1 to v2 acls // This table indexes all the ACL tokens that do not have an AccessorID @@ -405,7 +467,7 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke } if legacy && original != nil { - if len(original.Policies) > 0 || original.Type == "" { + if original.UsesNonLegacyFields() { return fmt.Errorf("failed inserting acl token: cannot use legacy endpoint to modify a non-legacy token") } @@ -586,6 +648,63 @@ func (s *Store) ACLTokenListUpgradeable(max int) (structs.ACLTokens, <-chan stru return tokens, iter.WatchCh(), nil } +func (s *Store) ACLTokenMinExpirationTime(local bool) (time.Time, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + item, err := tx.First("acl-tokens", s.expiresIndexName(local)) + if err != nil { + return time.Time{}, fmt.Errorf("failed acl token listing: %v", err) + } + + if item == nil { + return time.Time{}, nil + } + + token := item.(*structs.ACLToken) + + return token.ExpirationTime, nil +} + +// ACLTokenListExpires lists tokens that are expires as of the provided time. +// The returned set will be no larger than the max value provided. +func (s *Store) ACLTokenListExpired(local bool, asOf time.Time, max int) (structs.ACLTokens, <-chan struct{}, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("acl-tokens", s.expiresIndexName(local)) + if err != nil { + return nil, nil, fmt.Errorf("failed acl token listing: %v", err) + } + + var ( + tokens structs.ACLTokens + i int + ) + for raw := iter.Next(); raw != nil; raw = iter.Next() { + token := raw.(*structs.ACLToken) + + if !token.ExpirationTime.Before(asOf) { + return tokens, nil, nil + } + + tokens = append(tokens, token) + i += 1 + if i >= max { + return tokens, nil, nil + } + } + + return tokens, iter.WatchCh(), nil +} + +func (s *Store) expiresIndexName(local bool) string { + if local { + return "expires-local" + } + return "expires-global" +} + // ACLTokenDeleteBySecret is used to remove an existing ACL from the state store. If // the ACL does not exist this is a no-op and no error is returned. func (s *Store) ACLTokenDeleteBySecret(idx uint64, secret string) error { diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index b6546b5bda..0e01fbcfe3 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -1,11 +1,16 @@ package state import ( + "math/rand" + "strconv" "testing" "time" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/lib" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" ) @@ -47,6 +52,13 @@ func setupExtraPolicies(t *testing.T, s *Store) { Rules: `node_prefix "" { policy = "read" }`, Syntax: acl.SyntaxCurrent, }, + &structs.ACLPolicy{ + ID: "9386ecae-6677-4686-bcd4-5ab9d86cca1d", + Name: "agent-read", + Description: "Allows reading all node information", + Rules: `agent_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, } for _, policy := range policies { @@ -94,23 +106,6 @@ func TestStateStore_ACLBootstrap(t *testing.T) { Type: structs.ACLTokenTypeManagement, } - stripIrrelevantFields := func(token *structs.ACLToken) *structs.ACLToken { - tokenCopy := token.Clone() - // When comparing the tokens disregard the policy link names. This - // data is not cleanly updated in a variety of scenarios and should not - // be relied upon. - for i, _ := range tokenCopy.Policies { - tokenCopy.Policies[i].Name = "" - } - // The raft indexes won't match either because the requester will not - // have access to that. - tokenCopy.RaftIndex = structs.RaftIndex{} - return tokenCopy - } - compareTokens := func(expected, actual *structs.ACLToken) { - require.Equal(t, stripIrrelevantFields(expected), stripIrrelevantFields(actual)) - } - s := testStateStore(t) setupGlobalManagement(t, s) @@ -143,7 +138,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { _, tokens, err := s.ACLTokenList(nil, true, true, "") require.NoError(t, err) require.Len(t, tokens, 1) - compareTokens(token1, tokens[0]) + compareTokens(t, token1, tokens[0]) // bootstrap reset err = s.ACLBootstrap(32, index-1, token2.Clone(), false) @@ -175,10 +170,10 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) // legacy flag is set so it should disallow setting this token - err := s.ACLTokenSet(3, token, true) + err := s.ACLTokenSet(3, token.Clone(), true) require.Error(t, err) }) @@ -190,10 +185,10 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { SecretID: "c0056225-5785-43b3-9b77-3954f06d6aee", } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) // legacy flag is set so it should disallow setting this token - err := s.ACLTokenSet(3, token, true) + err := s.ACLTokenSet(3, token.Clone(), true) require.Error(t, err) }) @@ -206,7 +201,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: `service "" { policy = "read" }`, } - require.NoError(t, s.ACLTokenSet(2, token, true)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), true)) idx, rtoken, err := s.ACLTokenGetBySecret(nil, token.SecretID) require.NoError(t, err) @@ -230,7 +225,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: `service "" { policy = "read" }`, } - require.NoError(t, s.ACLTokenSet(2, original, true)) + require.NoError(t, s.ACLTokenSet(2, original.Clone(), true)) updatedRules := `service "" { policy = "read" } service "foo" { policy = "deny"}` update := &structs.ACLToken{ @@ -239,7 +234,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: updatedRules, } - require.NoError(t, s.ACLTokenSet(3, update, true)) + require.NoError(t, s.ACLTokenSet(3, update.Clone(), true)) idx, rtoken, err := s.ACLTokenGetBySecret(nil, original.SecretID) require.NoError(t, err) @@ -265,7 +260,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { AccessorID: "39171632-6f34-4411-827f-9416403687f4", } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) require.Equal(t, ErrMissingACLTokenSecret, err) }) @@ -277,7 +272,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) require.Equal(t, ErrMissingACLTokenAccessor, err) }) @@ -295,7 +290,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) }) @@ -312,7 +307,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) }) @@ -329,13 +324,12 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") require.NoError(t, err) require.Equal(t, uint64(2), idx) - // pointer equality - require.True(t, rtoken == token) + compareTokens(t, token, rtoken) require.Equal(t, uint64(2), rtoken.CreateIndex) require.Equal(t, uint64(2), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) @@ -355,7 +349,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) updated := &structs.ACLToken{ AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", @@ -367,13 +361,12 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } - require.NoError(t, s.ACLTokenSet(3, updated, false)) + require.NoError(t, s.ACLTokenSet(3, updated.Clone(), false)) idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") require.NoError(t, err) require.Equal(t, uint64(3), idx) - // pointer equality - require.True(t, rtoken == updated) + compareTokens(t, updated, rtoken) require.Equal(t, uint64(2), rtoken.CreateIndex) require.Equal(t, uint64(3), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) @@ -945,7 +938,7 @@ func TestStateStore_ACLToken_Delete(t *testing.T) { Local: true, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) _, rtoken, err := s.ACLTokenGetByAccessor(nil, "f1093997-b6c7-496d-bfb8-6b1b1895641b") require.NoError(t, err) @@ -973,7 +966,7 @@ func TestStateStore_ACLToken_Delete(t *testing.T) { Local: true, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) _, rtoken, err := s.ACLTokenGetByAccessor(nil, "f1093997-b6c7-496d-bfb8-6b1b1895641b") require.NoError(t, err) @@ -1183,7 +1176,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { // this creates the node read policy which we can update s := testACLTokensStateStore(t) - update := structs.ACLPolicy{ + update := &structs.ACLPolicy{ ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", Name: "node-read-modified", Description: "Modified", @@ -1192,19 +1185,29 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { Datacenters: []string{"dc1", "dc2"}, } - require.NoError(t, s.ACLPolicySet(3, &update)) + require.NoError(t, s.ACLPolicySet(3, update.Clone())) + expect := update.Clone() + expect.CreateIndex = 2 + expect.ModifyIndex = 3 + + // policy found via id idx, rpolicy, err := s.ACLPolicyGetByID(nil, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Equal(t, expect, rpolicy) + + // policy no longer found via old name + idx, rpolicy, err = s.ACLPolicyGetByName(nil, "node-read") require.Equal(t, uint64(3), idx) require.NoError(t, err) - require.NotNil(t, rpolicy) - require.Equal(t, "node-read-modified", rpolicy.Name) - require.Equal(t, "Modified", rpolicy.Description) - require.Equal(t, `node_prefix "" { policy = "read" } node "secret" { policy = "deny" }`, rpolicy.Rules) - require.Equal(t, acl.SyntaxCurrent, rpolicy.Syntax) - require.ElementsMatch(t, []string{"dc1", "dc2"}, rpolicy.Datacenters) - require.Equal(t, uint64(2), rpolicy.CreateIndex) - require.Equal(t, uint64(3), rpolicy.ModifyIndex) + require.Nil(t, rpolicy) + + // policy is found via new name + idx, rpolicy, err = s.ACLPolicyGetByName(nil, "node-read-modified") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Equal(t, expect, rpolicy) }) } @@ -1632,3 +1635,166 @@ func TestStateStore_ACLPolicies_Snapshot_Restore(t *testing.T) { require.Equal(t, uint64(2), s.maxIndex("acl-policies")) }() } + +func TestTokenPoliciesIndex(t *testing.T) { + lib.SeedMathRand() + + idIndex := &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "AccessorID", Lowercase: false}, + } + globalIndex := &memdb.IndexSchema{ + Name: "global", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: false}, + } + localIndex := &memdb.IndexSchema{ + Name: "local", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: true}, + } + schema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + "test": &memdb.TableSchema{ + Name: "test", + Indexes: map[string]*memdb.IndexSchema{ + "id": idIndex, + "global": globalIndex, + "local": localIndex, + }, + }, + }, + } + + knownUUIDs := make(map[string]struct{}) + newUUID := func() string { + for { + ret, err := uuid.GenerateUUID() + require.NoError(t, err) + if _, ok := knownUUIDs[ret]; !ok { + knownUUIDs[ret] = struct{}{} + return ret + } + } + } + + baseTime := time.Date(2010, 12, 31, 11, 30, 7, 0, time.UTC) + + newToken := func(local bool, desc string, expTime time.Time) *structs.ACLToken { + return &structs.ACLToken{ + AccessorID: newUUID(), + SecretID: newUUID(), + Description: desc, + Local: local, + ExpirationTime: expTime, + CreateTime: baseTime, + RaftIndex: structs.RaftIndex{ + CreateIndex: 9, + ModifyIndex: 10, + }, + } + } + + db, err := memdb.NewMemDB(schema) + require.NoError(t, err) + + dumpItems := func(index string) ([]string, error) { + tx := db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("test", index) + if err != nil { + return nil, err + } + + var out []string + for raw := iter.Next(); raw != nil; raw = iter.Next() { + tok := raw.(*structs.ACLToken) + out = append(out, tok.Description) + } + + return out, nil + } + + { // insert things with no expiration time + tx := db.Txn(true) + for i := 0; i < 10; i++ { + tok := newToken(i%2 != 1, "tok["+strconv.Itoa(i)+"]", time.Time{}) + + require.NoError(t, tx.Insert("test", tok)) + } + tx.Commit() + } + + t.Run("no expiration", func(t *testing.T) { + dump, err := dumpItems("local") + require.NoError(t, err) + require.Len(t, dump, 0) + + dump, err = dumpItems("global") + require.NoError(t, err) + require.Len(t, dump, 0) + }) + + { // insert things with laddered expiration time, inserted in random order + var tokens []*structs.ACLToken + for i := 0; i < 10; i++ { + expTime := baseTime.Add(time.Duration(i+1) * time.Minute) + tok := newToken(i%2 == 0, "exp-tok["+strconv.Itoa(i)+"]", expTime) + tokens = append(tokens, tok) + } + rand.Shuffle(len(tokens), func(i, j int) { + tokens[i], tokens[j] = tokens[j], tokens[i] + }) + + tx := db.Txn(true) + for _, tok := range tokens { + require.NoError(t, tx.Insert("test", tok)) + } + tx.Commit() + } + + t.Run("mixed expiration", func(t *testing.T) { + dump, err := dumpItems("local") + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "exp-tok[0]", + "exp-tok[2]", + "exp-tok[4]", + "exp-tok[6]", + "exp-tok[8]", + }, dump) + + dump, err = dumpItems("global") + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "exp-tok[1]", + "exp-tok[3]", + "exp-tok[5]", + "exp-tok[7]", + "exp-tok[9]", + }, dump) + }) +} + +func stripIrrelevantTokenFields(token *structs.ACLToken) *structs.ACLToken { + tokenCopy := token.Clone() + // When comparing the tokens disregard the policy link names. This + // data is not cleanly updated in a variety of scenarios and should not + // be relied upon. + for i, _ := range tokenCopy.Policies { + tokenCopy.Policies[i].Name = "" + } + // The raft indexes won't match either because the requester will not + // have access to that. + tokenCopy.RaftIndex = structs.RaftIndex{} + return tokenCopy +} + +func compareTokens(t *testing.T, expected, actual *structs.ACLToken) { + require.Equal(t, stripIrrelevantTokenFields(expected), stripIrrelevantTokenFields(actual)) +} diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 5e052b414f..84e1793661 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -113,6 +113,7 @@ type ACLIdentity interface { SecretToken() string PolicyIDs() []string EmbeddedPolicy() *ACLPolicy + IsExpired(asOf time.Time) bool } type ACLTokenPolicyLink struct { @@ -150,6 +151,19 @@ type ACLToken struct { // to the ACL datacenter and replicated to others. Local bool + // ExpirationTime represents the point after which a token should be + // considered revoked and is eligible for destruction. The zero value + // represents NO expiration. + ExpirationTime time.Time `json:",omitempty"` + + // ExpirationTTL is a convenience field for helping set ExpirationTime to a + // value of CreateTime+ExpirationTTL. This can only be set during + // TokenCreate and is cleared and used to initialize the ExpirationTime + // field before being persisted to the state store or raft log. + // + // This is a string version of a time.Duration like "2m". + ExpirationTTL time.Duration `json:",omitempty"` + // The time when this token was created CreateTime time.Time `json:",omitempty"` @@ -191,6 +205,20 @@ func (t *ACLToken) PolicyIDs() []string { return ids } +func (t *ACLToken) IsExpired(asOf time.Time) bool { + if asOf.IsZero() || t.ExpirationTime.IsZero() { + return false + } + return t.ExpirationTime.Before(asOf) +} + +func (t *ACLToken) UsesNonLegacyFields() bool { + return len(t.Policies) > 0 || + t.Type == "" || + !t.ExpirationTime.IsZero() || + t.ExpirationTTL != 0 +} + func (t *ACLToken) EmbeddedPolicy() *ACLPolicy { // DEPRECATED (ACL-Legacy-Compat) // @@ -229,6 +257,14 @@ func (t *ACLToken) SetHash(force bool) []byte { panic(err) } + // Any non-immutable "content" fields should be involved with the + // overall hash. The IDs are immutable which is why they aren't here. + // The raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a token + // has changed and should be updated locally. + // Write all the user set fields hash.Write([]byte(t.Description)) hash.Write([]byte(t.Type)) @@ -254,8 +290,8 @@ func (t *ACLToken) SetHash(force bool) []byte { } func (t *ACLToken) EstimateSize() int { - // 33 = 16 (RaftIndex) + 8 (Hash) + 8 (CreateTime) + 1 (Local) - size := 33 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + // 41 = 16 (RaftIndex) + 8 (Hash) + 8 (ExpirationTime) + 8 (CreateTime) + 1 (Local) + size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } @@ -266,30 +302,32 @@ func (t *ACLToken) EstimateSize() int { type ACLTokens []*ACLToken type ACLTokenListStub struct { - AccessorID string - Description string - Policies []ACLTokenPolicyLink - Local bool - CreateTime time.Time `json:",omitempty"` - Hash []byte - CreateIndex uint64 - ModifyIndex uint64 - Legacy bool `json:",omitempty"` + AccessorID string + Description string + Policies []ACLTokenPolicyLink + Local bool + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte + CreateIndex uint64 + ModifyIndex uint64 + Legacy bool `json:",omitempty"` } type ACLTokenListStubs []*ACLTokenListStub func (token *ACLToken) Stub() *ACLTokenListStub { return &ACLTokenListStub{ - AccessorID: token.AccessorID, - Description: token.Description, - Policies: token.Policies, - Local: token.Local, - CreateTime: token.CreateTime, - Hash: token.Hash, - CreateIndex: token.CreateIndex, - ModifyIndex: token.ModifyIndex, - Legacy: token.Rules != "", + AccessorID: token.AccessorID, + Description: token.Description, + Policies: token.Policies, + Local: token.Local, + ExpirationTime: token.ExpirationTime, + CreateTime: token.CreateTime, + Hash: token.Hash, + CreateIndex: token.CreateIndex, + ModifyIndex: token.ModifyIndex, + Legacy: token.Rules != "", } } @@ -384,6 +422,14 @@ func (p *ACLPolicy) SetHash(force bool) []byte { panic(err) } + // Any non-immutable "content" fields should be involved with the + // overall hash. The ID is immutable which is why it isn't here. The + // raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a policy + // has changed and should be updated locally. + // Write all the user set fields hash.Write([]byte(p.Name)) hash.Write([]byte(p.Description)) @@ -414,7 +460,7 @@ func (p *ACLPolicy) EstimateSize() int { return size } -// ACLPolicyListHash returns a consistent hash for a set of policies. +// HashKey returns a consistent hash for a set of policies. func (policies ACLPolicies) HashKey() string { cacheKeyHash, err := blake2b.New256(nil) if err != nil { diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index fba38545bf..2e7e9edcdb 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -208,7 +208,7 @@ func TestStructs_ACLToken_EstimateSize(t *testing.T) { // this test is very contrived. Basically just tests that the // math is okay and returns the value. - require.Equal(t, 120, token.EstimateSize()) + require.Equal(t, 128, token.EstimateSize()) } func TestStructs_ACLToken_Stub(t *testing.T) { diff --git a/api/acl.go b/api/acl.go index 53a052363e..667d215482 100644 --- a/api/acl.go +++ b/api/acl.go @@ -22,15 +22,17 @@ type ACLTokenPolicyLink struct { // ACLToken represents an ACL Token type ACLToken struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - SecretID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - CreateTime time.Time `json:",omitempty"` - Hash []byte `json:",omitempty"` + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + SecretID string + Description string + Policies []*ACLTokenPolicyLink + Local bool + ExpirationTTL time.Duration `json:",omitempty"` + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte `json:",omitempty"` // DEPRECATED (ACL-Legacy-Compat) // Rules will only be present for legacy tokens returned via the new APIs @@ -38,15 +40,16 @@ type ACLToken struct { } type ACLTokenListEntry struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - CreateTime time.Time - Hash []byte - Legacy bool + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + Description string + Policies []*ACLTokenPolicyLink + Local bool + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time + Hash []byte + Legacy bool } // ACLEntry is used to represent a legacy ACL token diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 96d8ec57c9..1b9f9cee74 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -10,15 +10,18 @@ import ( ) func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { - ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) - ui.Info(fmt.Sprintf("SecretID: %s", token.SecretID)) - ui.Info(fmt.Sprintf("Description: %s", token.Description)) - ui.Info(fmt.Sprintf("Local: %t", token.Local)) - ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) + ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) + ui.Info(fmt.Sprintf("SecretID: %s", token.SecretID)) + ui.Info(fmt.Sprintf("Description: %s", token.Description)) + ui.Info(fmt.Sprintf("Local: %t", token.Local)) + ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) + if !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", token.ExpirationTime)) + } if showMeta { - ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) - ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) - ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } ui.Info(fmt.Sprintf("Policies:")) for _, policy := range token.Policies { @@ -31,15 +34,18 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { } func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) { - ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) - ui.Info(fmt.Sprintf("Description: %s", token.Description)) - ui.Info(fmt.Sprintf("Local: %t", token.Local)) - ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) - ui.Info(fmt.Sprintf("Legacy: %t", token.Legacy)) + ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) + ui.Info(fmt.Sprintf("Description: %s", token.Description)) + ui.Info(fmt.Sprintf("Local: %t", token.Local)) + ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) + if !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", token.ExpirationTime)) + } + ui.Info(fmt.Sprintf("Legacy: %t", token.Legacy)) if showMeta { - ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) - ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) - ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } ui.Info(fmt.Sprintf("Policies:")) for _, policy := range token.Policies { diff --git a/command/acl/token/clone/token_clone_test.go b/command/acl/token/clone/token_clone_test.go index 0baa0c4f2e..55a81a682d 100644 --- a/command/acl/token/clone/token_clone_test.go +++ b/command/acl/token/clone/token_clone_test.go @@ -19,11 +19,11 @@ import ( func parseCloneOutput(t *testing.T, output string) *api.ACLToken { // This will only work for non-legacy tokens re := regexp.MustCompile("Token cloned successfully.\n" + - "AccessorID: ([a-zA-Z0-9\\-]{36})\n" + - "SecretID: ([a-zA-Z0-9\\-]{36})\n" + - "Description: ([^\n]*)\n" + - "Local: (true|false)\n" + - "Create Time: ([^\n]+)\n" + + "AccessorID: ([a-zA-Z0-9\\-]{36})\n" + + "SecretID: ([a-zA-Z0-9\\-]{36})\n" + + "Description: ([^\n]*)\n" + + "Local: (true|false)\n" + + "Create Time: ([^\n]+)\n" + "Policies:\n" + "( [a-zA-Z0-9\\-]{36} - [^\n]+\n)*") diff --git a/command/acl/token/create/token_create.go b/command/acl/token/create/token_create.go index 83a17cd612..624ec8e647 100644 --- a/command/acl/token/create/token_create.go +++ b/command/acl/token/create/token_create.go @@ -3,6 +3,7 @@ package tokencreate import ( "flag" "fmt" + "time" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/acl" @@ -22,11 +23,12 @@ type cmd struct { http *flags.HTTPFlags help string - policyIDs []string - policyNames []string - description string - local bool - showMeta bool + policyIDs []string + policyNames []string + expirationTTL time.Duration + description string + local bool + showMeta bool } func (c *cmd) init() { @@ -39,6 +41,8 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.DurationVar(&c.expirationTTL, "expires-ttl", 0, "Duration of time this "+ + "token should be valid for") c.http = &flags.HTTPFlags{} flags.Merge(c.flags, c.http.ClientFlags()) flags.Merge(c.flags, c.http.ServerFlags()) @@ -65,6 +69,9 @@ func (c *cmd) Run(args []string) int { Description: c.description, Local: c.local, } + if c.expirationTTL > 0 { + newToken.ExpirationTTL = c.expirationTTL + } for _, policyName := range c.policyNames { // We could resolve names to IDs here but there isn't any reason why its would be better @@ -109,7 +116,7 @@ Usage: consul acl token create [options] Create a new token: - $ consul acl token create -description "Replication token" - -policy-id b52fc3de-5 - -policy-name "acl-replication" + $ consul acl token create -description "Replication token" \ + -policy-id b52fc3de-5 \ + -policy-name "acl-replication" ` diff --git a/command/acl/token/update/token_update.go b/command/acl/token/update/token_update.go index 09df170b5a..0849808796 100644 --- a/command/acl/token/update/token_update.go +++ b/command/acl/token/update/token_update.go @@ -192,7 +192,9 @@ Usage: consul acl token update [options] $ consul acl token update -id abcd -description "replication" -merge-policies - Update all editable fields of the token: + Update all editable fields of the token: - $ consul acl token update -id abcd -description "replication" -policy-name "token-replication" + $ consul acl token update -id abcd \ + -description "replication" \ + -policy-name "token-replication" ` From db43fc3a20afdf9924164c3c39e9ce99c640f116 Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Mon, 8 Apr 2019 13:19:09 -0500 Subject: [PATCH 2/5] acl: ACL Tokens can now be assigned an optional set of service identities (#5390) These act like a special cased version of a Policy Template for granting a token the privileges necessary to register a service and its connect proxy, and read upstreams from the catalog. --- agent/consul/acl.go | 104 ++++++++++++++- agent/consul/acl_endpoint.go | 44 +++++-- agent/consul/acl_endpoint_test.go | 118 +++++++++++++++++ agent/consul/acl_test.go | 89 +++++++++++++ agent/consul/leader.go | 4 +- agent/consul/state/acl.go | 6 + agent/consul/state/acl_test.go | 51 +++++++ agent/structs/acl.go | 161 +++++++++++++++++++---- agent/structs/acl_legacy.go | 17 +-- agent/structs/acl_test.go | 63 +++++++++ api/acl.go | 52 +++++--- command/acl/acl_helpers.go | 37 ++++++ command/acl/token/create/token_create.go | 19 ++- command/acl/token/update/token_update.go | 47 ++++++- 14 files changed, 730 insertions(+), 82 deletions(-) diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 4f1061e5d1..2e89eba2f2 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -4,10 +4,11 @@ import ( "fmt" "log" "os" + "sort" "sync" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" @@ -133,7 +134,7 @@ type ACLResolverConfig struct { // - Resolving policies remotely via an ACL.PolicyResolve RPC // // Remote Resolution: -// Remote resolution can be done syncrhonously or asynchronously depending +// Remote resolution can be done synchronously or asynchronously depending // on the ACLDownPolicy in the Config passed to the resolver. // // When the down policy is set to async-cache and we have already cached values @@ -503,7 +504,9 @@ func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) struct func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) (structs.ACLPolicies, error) { policyIDs := identity.PolicyIDs() - if len(policyIDs) == 0 { + serviceIdentities := identity.ServiceIdentityList() + + if len(policyIDs) == 0 && len(serviceIdentities) == 0 { policy := identity.EmbeddedPolicy() if policy != nil { return []*structs.ACLPolicy{policy}, nil @@ -513,9 +516,96 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( return nil, nil } + syntheticPolicies := r.synthesizePoliciesForServiceIdentities(serviceIdentities) + // For the new ACLs policy replication is mandatory for correct operation on servers. Therefore // we only attempt to resolve policies locally - policies := make([]*structs.ACLPolicy, 0, len(policyIDs)) + policies, err := r.collectPoliciesForIdentity(identity, policyIDs, len(syntheticPolicies)) + if err != nil { + return nil, err + } + + policies = append(policies, syntheticPolicies...) + filtered := r.filterPoliciesByScope(policies) + return filtered, nil +} + +func (r *ACLResolver) synthesizePoliciesForServiceIdentities(serviceIdentities []*structs.ACLServiceIdentity) []*structs.ACLPolicy { + if len(serviceIdentities) == 0 { + return nil + } + + // Collect and dedupe service identities. Prefer increasing datacenter scope. + serviceIdentities = dedupeServiceIdentities(serviceIdentities) + + syntheticPolicies := make([]*structs.ACLPolicy, 0, len(serviceIdentities)) + for _, s := range serviceIdentities { + syntheticPolicies = append(syntheticPolicies, s.SyntheticPolicy()) + } + + return syntheticPolicies +} + +func dedupeServiceIdentities(in []*structs.ACLServiceIdentity) []*structs.ACLServiceIdentity { + // From: https://github.com/golang/go/wiki/SliceTricks#in-place-deduplicate-comparable + + if len(in) <= 1 { + return in + } + + sort.Slice(in, func(i, j int) bool { + return in[i].ServiceName < in[j].ServiceName + }) + + j := 0 + for i := 1; i < len(in); i++ { + if in[j].ServiceName == in[i].ServiceName { + // Prefer increasing scope. + if len(in[j].Datacenters) == 0 || len(in[i].Datacenters) == 0 { + in[j].Datacenters = nil + } else { + in[j].Datacenters = mergeStringSlice(in[j].Datacenters, in[i].Datacenters) + } + continue + } + j++ + in[j] = in[i] + } + + // Discard the skipped items. + for i := j + 1; i < len(in); i++ { + in[i] = nil + } + + return in[:j+1] +} + +func mergeStringSlice(a, b []string) []string { + out := make([]string, 0, len(a)+len(b)) + out = append(out, a...) + out = append(out, b...) + return dedupeStringSlice(out) +} + +func dedupeStringSlice(in []string) []string { + // From: https://github.com/golang/go/wiki/SliceTricks#in-place-deduplicate-comparable + + sort.Strings(in) + + j := 0 + for i := 1; i < len(in); i++ { + if in[j] == in[i] { + continue + } + j++ + in[j] = in[i] + } + + return in[:j+1] +} + +func (r *ACLResolver) collectPoliciesForIdentity(identity structs.ACLIdentity, policyIDs []string, extraCap int) ([]*structs.ACLPolicy, error) { + policies := make([]*structs.ACLPolicy, 0, len(policyIDs)+extraCap) // Get all associated policies var missing []string @@ -559,7 +649,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( // Hot-path if we have no missing or expired policies if len(missing)+len(expired) == 0 { - return r.filterPoliciesByScope(policies), nil + return policies, nil } hasMissing := len(missing) > 0 @@ -579,7 +669,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( if !waitForResult { // waitForResult being false requires that all the policies were cached already policies = append(policies, expired...) - return r.filterPoliciesByScope(policies), nil + return policies, nil } res := <-waitChan @@ -596,7 +686,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( } } - return r.filterPoliciesByScope(policies), nil + return policies, nil } func (r *ACLResolver) resolveTokenToPolicies(token string) (structs.ACLPolicies, error) { diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index c19ff29764..65be9b9e8a 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -8,13 +8,13 @@ import ( "regexp" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" - "github.com/hashicorp/go-memdb" - "github.com/hashicorp/go-uuid" + memdb "github.com/hashicorp/go-memdb" + uuid "github.com/hashicorp/go-uuid" ) const ( @@ -24,7 +24,11 @@ const ( ) // Regex for matching -var validPolicyName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) +var ( + validPolicyName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) + validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`) + serviceIdentityNameMaxLength = 256 +) // ACL endpoint is used to manipulate ACLs type ACL struct { @@ -275,10 +279,11 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok cloneReq := structs.ACLTokenSetRequest{ Datacenter: args.Datacenter, ACLToken: structs.ACLToken{ - Policies: token.Policies, - Local: token.Local, - Description: token.Description, - ExpirationTime: token.ExpirationTime, + Policies: token.Policies, + ServiceIdentities: token.ServiceIdentities, + Local: token.Local, + Description: token.Description, + ExpirationTime: token.ExpirationTime, }, WriteRequest: args.WriteRequest, } @@ -450,6 +455,18 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. } token.Policies = policies + for _, svcid := range token.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Service identity is missing the service name field on this token") + } + if token.Local && len(svcid.Datacenters) > 0 { + return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName) + } + if !isValidServiceIdentityName(svcid.ServiceName) { + return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) + } + } + if token.Rules != "" { return fmt.Errorf("Rules cannot be specified for this token") } @@ -487,6 +504,17 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return nil } +// isValidServiceIdentityName returns true if the provided name can be used as +// an ACLServiceIdentity ServiceName. This is more restrictive than standard +// catalog registration, which basically takes the view that "everything is +// valid". +func isValidServiceIdentityName(name string) bool { + if len(name) < 1 || len(name) > serviceIdentityNameMaxLength { + return false + } + return validServiceIdentityName.MatchString(name) +} + func (a *ACL) TokenDelete(args *structs.ACLTokenDeleteRequest, reply *string) error { if err := a.aclPreCheck(); err != nil { return err diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index 0ad3c5d3bb..f2d17fc49a 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -919,6 +919,124 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Len(t, token.Policies, 0) }) + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: ""}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Service identity is missing the service name field") + }) + + t.Run("Create it with invalid service identity (too large)", func(t *testing.T) { + long := strings.Repeat("x", serviceIdentityNameMaxLength+1) + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: long}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NotNil(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {"-abc", false}, + {"abc-", false}, + {"a-bc", true}, + {"_abc", false}, + {"abc_", false}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", false}, + {"aBc", false}, + {"abC", false}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create it with valid service identity (by regex): " + test.name + } else { + testName = "Create it with invalid service identity (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: test.name}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.ElementsMatch(t, req.ACLToken.ServiceIdentities, token.ServiceIdentities) + } else { + require.NotNil(t, err) + } + }) + } + + t.Run("Create it with invalid service identity (datacenters set on local token)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: true, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "foo", Datacenters: []string{"dc2"}}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "cannot specify a list of datacenters on a local token") + }) + for _, test := range []struct { name string offset time.Duration diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 38b723cbdb..9ff773ca41 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -2861,3 +2861,92 @@ service "service" { t.Fatalf("err: %v", err) } } + +func TestDedupeServiceIdentities(t *testing.T) { + srvid := func(name string, datacenters ...string) *structs.ACLServiceIdentity { + return &structs.ACLServiceIdentity{ + ServiceName: name, + Datacenters: datacenters, + } + } + + tests := []struct { + name string + in []*structs.ACLServiceIdentity + expect []*structs.ACLServiceIdentity + }{ + { + name: "empty", + in: nil, + expect: nil, + }, + { + name: "one", + in: []*structs.ACLServiceIdentity{ + srvid("foo"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("foo"), + }, + }, + { + name: "just names", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ"), + srvid("fooA"), + srvid("fooY"), + srvid("fooB"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + }, + { + name: "just names with dupes", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ"), + srvid("fooA"), + srvid("fooY"), + srvid("fooB"), + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + }, + { + name: "names with dupes and datacenters", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ", "dc2", "dc4"), + srvid("fooA"), + srvid("fooY", "dc1"), + srvid("fooB"), + srvid("fooA", "dc9", "dc8"), + srvid("fooB"), + srvid("fooY", "dc1"), + srvid("fooZ", "dc3", "dc4"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY", "dc1"), + srvid("fooZ", "dc2", "dc3", "dc4"), + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := dedupeServiceIdentities(test.in) + require.ElementsMatch(t, test.expect, got) + }) + } +} diff --git a/agent/consul/leader.go b/agent/consul/leader.go index d42fddd798..fe58c3ad7e 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -658,7 +658,9 @@ func (s *Server) startACLUpgrade() { } // Assign the global-management policy to legacy management tokens - if len(newToken.Policies) == 0 && newToken.Type == structs.ACLTokenTypeManagement { + if len(newToken.Policies) == 0 && + len(newToken.ServiceIdentities) == 0 && + newToken.Type == structs.ACLTokenTypeManagement { newToken.Policies = append(newToken.Policies, structs.ACLTokenPolicyLink{ID: structs.ACLPolicyGlobalManagementID}) } diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 36c0a03e46..abeb1f7e22 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -478,6 +478,12 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke return err } + for _, svcid := range token.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Encountered a Token with an empty service identity name in the state store") + } + } + // Set the indexes if original != nil { if original.AccessorID != "" && token.AccessorID != original.AccessorID { diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index 0e01fbcfe3..9ec57674bc 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -277,6 +277,38 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Equal(t, ErrMissingACLTokenAccessor, err) }) + t.Run("Missing Service Identity Fields", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{}, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + + t.Run("Missing Service Identity Name", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + Datacenters: []string{"dc1"}, + }, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("Missing Policy ID", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -322,6 +354,11 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", }, }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", + }, + }, } require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) @@ -334,6 +371,8 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Equal(t, uint64(2), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) require.Equal(t, "node-read", rtoken.Policies[0].Name) + require.Len(t, rtoken.ServiceIdentities, 1) + require.Equal(t, "web", rtoken.ServiceIdentities[0].ServiceName) }) t.Run("Update", func(t *testing.T) { @@ -347,6 +386,11 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", }, }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", + }, + }, } require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) @@ -359,6 +403,11 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { ID: structs.ACLPolicyGlobalManagementID, }, }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "db", + }, + }, } require.NoError(t, s.ACLTokenSet(3, updated.Clone(), false)) @@ -372,6 +421,8 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Len(t, rtoken.Policies, 1) require.Equal(t, structs.ACLPolicyGlobalManagementID, rtoken.Policies[0].ID) require.Equal(t, "global-management", rtoken.Policies[0].Name) + require.Len(t, rtoken.ServiceIdentities, 1) + require.Equal(t, "db", rtoken.ServiceIdentities[0].ServiceName) }) } diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 84e1793661..fee2cd85ea 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "hash" "hash/fnv" "sort" "strings" @@ -84,6 +85,22 @@ session_prefix "" { // This is the policy ID for anonymous access. This is configurable by the // user. ACLTokenAnonymousID = "00000000-0000-0000-0000-000000000002" + + // aclPolicyTemplateServiceIdentity is the template used for synthesizing + // policies for service identities. + aclPolicyTemplateServiceIdentity = ` +service "%s" { + policy = "write" +} +service "%s-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}` ) func ACLIDReserved(id string) bool { @@ -113,6 +130,7 @@ type ACLIdentity interface { SecretToken() string PolicyIDs() []string EmbeddedPolicy() *ACLPolicy + ServiceIdentityList() []*ACLServiceIdentity IsExpired(asOf time.Time) bool } @@ -121,6 +139,49 @@ type ACLTokenPolicyLink struct { Name string `hash:"ignore"` } +// ACLServiceIdentity represents a high-level grant of all necessary privileges +// to assume the identity of the named Service in the Catalog and within +// Connect. +type ACLServiceIdentity struct { + ServiceName string + + // Datacenters that the synthetic policy will be valid within. + // - No wildcards allowed + // - If empty then the synthetic policy is valid within all datacenters + // + // Only valid for global tokens. It is an error to specify this for local tokens. + Datacenters []string `json:",omitempty"` +} + +func (s *ACLServiceIdentity) Clone() *ACLServiceIdentity { + s2 := *s + s2.Datacenters = cloneStringSlice(s.Datacenters) + return &s2 +} + +func (s *ACLServiceIdentity) AddToHash(h hash.Hash) { + h.Write([]byte(s.ServiceName)) + for _, dc := range s.Datacenters { + h.Write([]byte(dc)) + } +} + +func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy { + // Given that we validate this string name before persisting, we do not + // have to escape it before doing the following interpolation. + rules := fmt.Sprintf(aclPolicyTemplateServiceIdentity, s.ServiceName, s.ServiceName) + + hasher := fnv.New128a() + policy := &ACLPolicy{} + policy.ID = fmt.Sprintf("%x", hasher.Sum([]byte(rules))) + policy.Name = fmt.Sprintf("synthetic-policy-%s", policy.ID) + policy.Rules = rules + policy.Syntax = acl.SyntaxCurrent + policy.Datacenters = s.Datacenters + policy.SetHash(true) + return policy +} + type ACLToken struct { // This is the UUID used for tracking and management purposes AccessorID string @@ -131,10 +192,13 @@ type ACLToken struct { // Human readable string to display for the token (Optional) Description string - // List of policy links - nil/empty for legacy tokens + // List of policy links - nil/empty for legacy tokens or if service identities are in use. // Note this is the list of IDs and not the names. Prior to token creation // the list of policy names gets validated and the policy IDs get stored herein - Policies []ACLTokenPolicyLink + Policies []ACLTokenPolicyLink `json:",omitempty"` + + // List of services to generate synthetic policies for. + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` // Type is the V1 Token Type // DEPRECATED (ACL-Legacy-Compat) - remove once we no longer support v1 ACL compat @@ -181,11 +245,18 @@ type ACLToken struct { func (t *ACLToken) Clone() *ACLToken { t2 := *t t2.Policies = nil + t2.ServiceIdentities = nil if len(t.Policies) > 0 { t2.Policies = make([]ACLTokenPolicyLink, len(t.Policies)) copy(t2.Policies, t.Policies) } + if len(t.ServiceIdentities) > 0 { + t2.ServiceIdentities = make([]*ACLServiceIdentity, len(t.ServiceIdentities)) + for i, s := range t.ServiceIdentities { + t2.ServiceIdentities[i] = s.Clone() + } + } return &t2 } @@ -198,13 +269,29 @@ func (t *ACLToken) SecretToken() string { } func (t *ACLToken) PolicyIDs() []string { - var ids []string + if len(t.Policies) == 0 { + return nil + } + + ids := make([]string, 0, len(t.Policies)) for _, link := range t.Policies { ids = append(ids, link.ID) } return ids } +func (t *ACLToken) ServiceIdentityList() []*ACLServiceIdentity { + if len(t.ServiceIdentities) == 0 { + return nil + } + + out := make([]*ACLServiceIdentity, 0, len(t.ServiceIdentities)) + for _, s := range t.ServiceIdentities { + out = append(out, s.Clone()) + } + return out +} + func (t *ACLToken) IsExpired(asOf time.Time) bool { if asOf.IsZero() || t.ExpirationTime.IsZero() { return false @@ -214,6 +301,7 @@ func (t *ACLToken) IsExpired(asOf time.Time) bool { func (t *ACLToken) UsesNonLegacyFields() bool { return len(t.Policies) > 0 || + len(t.ServiceIdentities) > 0 || t.Type == "" || !t.ExpirationTime.IsZero() || t.ExpirationTTL != 0 @@ -280,6 +368,10 @@ func (t *ACLToken) SetHash(force bool) []byte { hash.Write([]byte(link.ID)) } + for _, srvid := range t.ServiceIdentities { + srvid.AddToHash(hash) + } + // Finalize the hash hashVal := hash.Sum(nil) @@ -295,6 +387,12 @@ func (t *ACLToken) EstimateSize() int { for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } + for _, srvid := range t.ServiceIdentities { + size += len(srvid.ServiceName) + for _, dc := range srvid.Datacenters { + size += len(dc) + } + } return size } @@ -302,32 +400,34 @@ func (t *ACLToken) EstimateSize() int { type ACLTokens []*ACLToken type ACLTokenListStub struct { - AccessorID string - Description string - Policies []ACLTokenPolicyLink - Local bool - ExpirationTime time.Time `json:",omitempty"` - CreateTime time.Time `json:",omitempty"` - Hash []byte - CreateIndex uint64 - ModifyIndex uint64 - Legacy bool `json:",omitempty"` + AccessorID string + Description string + Policies []ACLTokenPolicyLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte + CreateIndex uint64 + ModifyIndex uint64 + Legacy bool `json:",omitempty"` } type ACLTokenListStubs []*ACLTokenListStub func (token *ACLToken) Stub() *ACLTokenListStub { return &ACLTokenListStub{ - AccessorID: token.AccessorID, - Description: token.Description, - Policies: token.Policies, - Local: token.Local, - ExpirationTime: token.ExpirationTime, - CreateTime: token.CreateTime, - Hash: token.Hash, - CreateIndex: token.CreateIndex, - ModifyIndex: token.ModifyIndex, - Legacy: token.Rules != "", + AccessorID: token.AccessorID, + Description: token.Description, + Policies: token.Policies, + ServiceIdentities: token.ServiceIdentities, + Local: token.Local, + ExpirationTime: token.ExpirationTime, + CreateTime: token.CreateTime, + Hash: token.Hash, + CreateIndex: token.CreateIndex, + ModifyIndex: token.ModifyIndex, + Legacy: token.Rules != "", } } @@ -381,11 +481,7 @@ type ACLPolicy struct { func (p *ACLPolicy) Clone() *ACLPolicy { p2 := *p - p2.Datacenters = nil - if len(p.Datacenters) > 0 { - p2.Datacenters = make([]string, len(p.Datacenters)) - copy(p2.Datacenters, p.Datacenters) - } + p2.Datacenters = cloneStringSlice(p.Datacenters) return &p2 } @@ -765,3 +861,12 @@ type ACLPolicyBatchSetRequest struct { type ACLPolicyBatchDeleteRequest struct { PolicyIDs []string } + +func cloneStringSlice(s []string) []string { + if len(s) == 0 { + return nil + } + out := make([]string, len(s)) + copy(out, s) + return out +} diff --git a/agent/structs/acl_legacy.go b/agent/structs/acl_legacy.go index ebd8ece82b..fa182d8884 100644 --- a/agent/structs/acl_legacy.go +++ b/agent/structs/acl_legacy.go @@ -73,14 +73,15 @@ func (a *ACL) Convert() *ACLToken { } return &ACLToken{ - AccessorID: "", - SecretID: a.ID, - Description: a.Name, - Policies: nil, - Type: a.Type, - Rules: a.Rules, - Local: false, - RaftIndex: a.RaftIndex, + AccessorID: "", + SecretID: a.ID, + Description: a.Name, + Policies: nil, + ServiceIdentities: nil, + Type: a.Type, + Rules: a.Rules, + Local: false, + RaftIndex: a.RaftIndex, } } diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index 2e7e9edcdb..bfc585a55d 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -140,6 +140,69 @@ func TestStructs_ACLToken_EmbeddedPolicy(t *testing.T) { }) } +func TestStructs_ACLServiceIdentity_SyntheticPolicy(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + serviceName string + datacenters []string + expectRules string + }{ + {"web", nil, ` +service "web" { + policy = "write" +} +service "web-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}`}, + {"companion-cube-99", []string{"dc1", "dc2"}, ` +service "companion-cube-99" { + policy = "write" +} +service "companion-cube-99-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}`}, + } { + name := test.serviceName + if len(test.datacenters) > 0 { + name += " [" + strings.Join(test.datacenters, ", ") + "]" + } + t.Run(name, func(t *testing.T) { + svcid := &ACLServiceIdentity{ + ServiceName: test.serviceName, + Datacenters: test.datacenters, + } + + expect := &ACLPolicy{ + Syntax: acl.SyntaxCurrent, + Datacenters: test.datacenters, + Rules: test.expectRules, + } + + got := svcid.SyntheticPolicy() + require.NotEmpty(t, got.ID) + require.Equal(t, got.Name, "synthetic-policy-"+got.ID) + // strip irrelevant fields before equality + got.ID = "" + got.Name = "" + got.Hash = nil + require.Equal(t, expect, got) + }) + } +} + func TestStructs_ACLToken_SetHash(t *testing.T) { t.Parallel() diff --git a/api/acl.go b/api/acl.go index 667d215482..3cf0227621 100644 --- a/api/acl.go +++ b/api/acl.go @@ -22,17 +22,18 @@ type ACLTokenPolicyLink struct { // ACLToken represents an ACL Token type ACLToken struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - SecretID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - ExpirationTTL time.Duration `json:",omitempty"` - ExpirationTime time.Time `json:",omitempty"` - CreateTime time.Time `json:",omitempty"` - Hash []byte `json:",omitempty"` + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + SecretID string + Description string + Policies []*ACLTokenPolicyLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + ExpirationTTL time.Duration `json:",omitempty"` + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte `json:",omitempty"` // DEPRECATED (ACL-Legacy-Compat) // Rules will only be present for legacy tokens returned via the new APIs @@ -40,16 +41,17 @@ type ACLToken struct { } type ACLTokenListEntry struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - ExpirationTime time.Time `json:",omitempty"` - CreateTime time.Time - Hash []byte - Legacy bool + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + Description string + Policies []*ACLTokenPolicyLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + ExpirationTime time.Time `json:",omitempty"` + CreateTime time.Time + Hash []byte + Legacy bool } // ACLEntry is used to represent a legacy ACL token @@ -75,6 +77,14 @@ type ACLReplicationStatus struct { LastError time.Time } +// ACLServiceIdentity represents a high-level grant of all necessary privileges +// to assume the identity of the named Service in the Catalog and within +// Connect. +type ACLServiceIdentity struct { + ServiceName string + Datacenters []string `json:",omitempty"` +} + // ACLPolicy represents an ACL Policy. type ACLPolicy struct { ID string diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 1b9f9cee74..d1e364acd6 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -27,6 +27,14 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { for _, policy := range token.Policies { ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } if token.Rules != "" { ui.Info(fmt.Sprintf("Rules:")) ui.Info(token.Rules) @@ -51,6 +59,14 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) for _, policy := range token.Policies { ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } } func PrintPolicy(policy *api.ACLPolicy, ui cli.Ui, showMeta bool) { @@ -191,3 +207,24 @@ func GetRulesFromLegacyToken(client *api.Client, tokenID string, isSecret bool) return token.Rules, nil } + +func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) { + var out []*api.ACLServiceIdentity + for _, svcidRaw := range serviceIdents { + parts := strings.Split(svcidRaw, ":") + switch len(parts) { + case 2: + out = append(out, &api.ACLServiceIdentity{ + ServiceName: parts[0], + Datacenters: strings.Split(parts[1], ","), + }) + case 1: + out = append(out, &api.ACLServiceIdentity{ + ServiceName: parts[0], + }) + default: + return nil, fmt.Errorf("Malformed -service-identity argument: %q", svcidRaw) + } + } + return out, nil +} diff --git a/command/acl/token/create/token_create.go b/command/acl/token/create/token_create.go index 624ec8e647..0760c18f07 100644 --- a/command/acl/token/create/token_create.go +++ b/command/acl/token/create/token_create.go @@ -25,6 +25,7 @@ type cmd struct { policyIDs []string policyNames []string + serviceIdents []string expirationTTL time.Duration description string local bool @@ -41,6 +42,9 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this token. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") c.flags.DurationVar(&c.expirationTTL, "expires-ttl", 0, "Duration of time this "+ "token should be valid for") c.http = &flags.HTTPFlags{} @@ -54,8 +58,9 @@ func (c *cmd) Run(args []string) int { return 1 } - if len(c.policyNames) == 0 && len(c.policyIDs) == 0 { - c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name or -policy-id at least once")) + if len(c.policyNames) == 0 && len(c.policyIDs) == 0 && + len(c.serviceIdents) == 0 { + c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name, -policy-id, or -service-identity at least once")) return 1 } @@ -73,6 +78,13 @@ func (c *cmd) Run(args []string) int { newToken.ExpirationTTL = c.expirationTTL } + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + newToken.ServiceIdentities = parsedServiceIdents + for _, policyName := range c.policyNames { // We could resolve names to IDs here but there isn't any reason why its would be better // than allowing the agent to do it. @@ -119,4 +131,7 @@ Usage: consul acl token create [options] $ consul acl token create -description "Replication token" \ -policy-id b52fc3de-5 \ -policy-name "acl-replication" + -policy-name "acl-replication" \ + -service-identity "web" \ + -service-identity "db:east,west" ` diff --git a/command/acl/token/update/token_update.go b/command/acl/token/update/token_update.go index 0849808796..c2ad084688 100644 --- a/command/acl/token/update/token_update.go +++ b/command/acl/token/update/token_update.go @@ -22,13 +22,15 @@ type cmd struct { http *flags.HTTPFlags help string - tokenID string - policyIDs []string - policyNames []string - description string - mergePolicies bool - showMeta bool - upgradeLegacy bool + tokenID string + policyIDs []string + policyNames []string + serviceIdents []string + description string + mergePolicies bool + mergeServiceIdents bool + showMeta bool + upgradeLegacy bool } func (c *cmd) init() { @@ -37,6 +39,8 @@ func (c *cmd) init() { "as the content hash and raft indices should be shown for each entry") c.flags.BoolVar(&c.mergePolicies, "merge-policies", false, "Merge the new policies "+ "with the existing policies") + c.flags.BoolVar(&c.mergeServiceIdents, "merge-service-identities", false, "Merge the new service identities "+ + "with the existing service identities") c.flags.StringVar(&c.tokenID, "id", "", "The Accessor ID of the token to read. "+ "It may be specified as a unique ID prefix but will error if the prefix "+ "matches multiple token Accessor IDs") @@ -45,6 +49,9 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this token. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") c.flags.BoolVar(&c.upgradeLegacy, "upgrade-legacy", false, "Add new polices "+ "to a legacy token replacing all existing rules. This will cause the legacy "+ "token to behave exactly like a new token but keep the same Secret.\n"+ @@ -107,6 +114,12 @@ func (c *cmd) Run(args []string) int { token.Description = c.description } + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + if c.mergePolicies { for _, policyName := range c.policyNames { found := false @@ -162,6 +175,26 @@ func (c *cmd) Run(args []string) int { } } + if c.mergeServiceIdents { + for _, svcid := range parsedServiceIdents { + found := -1 + for i, link := range token.ServiceIdentities { + if link.ServiceName == svcid.ServiceName { + found = i + break + } + } + + if found != -1 { + token.ServiceIdentities[found] = svcid + } else { + token.ServiceIdentities = append(token.ServiceIdentities, svcid) + } + } + } else { + token.ServiceIdentities = parsedServiceIdents + } + token, _, err = client.ACL().TokenUpdate(token, nil) if err != nil { c.UI.Error(fmt.Sprintf("Failed to update token %s: %v", tokenID, err)) From 7928305279788def909b2f7b54a9d394774a138a Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Mon, 15 Apr 2019 13:35:55 -0500 Subject: [PATCH 3/5] making ACLToken.ExpirationTime a *time.Time value instead of time.Time (#5663) This is mainly to avoid having the API return "0001-01-01T00:00:00Z" as a value for the ExpirationTime field when it is not set. Unfortunately time.Time doesn't respect the json marshalling "omitempty" directive. --- agent/consul/acl_endpoint.go | 22 +++++++++++++++++----- agent/consul/acl_endpoint_test.go | 28 +++++++++++++++++----------- agent/consul/state/acl.go | 11 +++++------ agent/consul/state/acl_test.go | 2 +- agent/structs/acl.go | 18 +++++++++++++----- api/acl.go | 4 ++-- command/acl/acl_helpers.go | 8 ++++---- 7 files changed, 59 insertions(+), 34 deletions(-) diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index 65be9b9e8a..cd3abecb1a 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -358,16 +358,16 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. if token.ExpirationTTL < 0 { return fmt.Errorf("Token Expiration TTL '%s' should be > 0", token.ExpirationTTL) } - if !token.ExpirationTime.IsZero() { + if token.HasExpirationTime() { return fmt.Errorf("Token Expiration TTL and Expiration Time cannot both be set") } - token.ExpirationTime = token.CreateTime.Add(token.ExpirationTTL) + token.ExpirationTime = timePointer(token.CreateTime.Add(token.ExpirationTTL)) token.ExpirationTTL = 0 } - if !token.ExpirationTime.IsZero() { - if token.CreateTime.After(token.ExpirationTime) { + if token.HasExpirationTime() { + if token.CreateTime.After(*token.ExpirationTime) { return fmt.Errorf("ExpirationTime cannot be before CreateTime") } @@ -417,7 +417,15 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID) } - if token.ExpirationTTL != 0 || !token.ExpirationTime.Equal(existing.ExpirationTime) { + if token.ExpirationTTL != 0 { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } + + if !token.HasExpirationTime() { + token.ExpirationTime = existing.ExpirationTime + } else if !existing.HasExpirationTime() { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } else if !token.ExpirationTime.Equal(*existing.ExpirationTime) { return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) } @@ -1041,3 +1049,7 @@ func (a *ACL) ReplicationStatus(args *structs.DCSpecificRequest, a.srv.aclReplicationStatusLock.RUnlock() return nil } + +func timePointer(t time.Time) *time.Time { + return &t +} diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index f2d17fc49a..49943c9113 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -1054,7 +1054,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { Description: "foobar", Policies: nil, Local: false, - ExpirationTime: time.Now().Add(test.offset), + ExpirationTime: timePointer(time.Now().Add(test.offset)), }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -1099,7 +1099,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { Description: "foobar", Policies: nil, Local: false, - ExpirationTime: time.Now().Add(4 * time.Second), + ExpirationTime: timePointer(time.Now().Add(4 * time.Second)), ExpirationTTL: 4 * time.Second, }, WriteRequest: structs.WriteRequest{Token: "root"}, @@ -1138,7 +1138,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.NotNil(t, token.AccessorID) require.Equal(t, token.Description, "foobar") require.Equal(t, token.AccessorID, resp.AccessorID) - requireTimeEquals(t, expectExpTime, resp.ExpirationTime) + requireTimeEquals(t, &expectExpTime, resp.ExpirationTime) tokenID = token.AccessorID }) @@ -1152,7 +1152,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { Description: "foobar", Policies: nil, Local: false, - ExpirationTime: expTime, + ExpirationTime: &expTime, }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -1170,7 +1170,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.NotNil(t, token.AccessorID) require.Equal(t, token.Description, "foobar") require.Equal(t, token.AccessorID, resp.AccessorID) - requireTimeEquals(t, expTime, resp.ExpirationTime) + requireTimeEquals(t, &expTime, resp.ExpirationTime) tokenID = token.AccessorID }) @@ -1183,7 +1183,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { ACLToken: structs.ACLToken{ Description: "new-description", AccessorID: tokenID, - ExpirationTime: expTime.Add(-1 * time.Second), + ExpirationTime: timePointer(expTime.Add(-1 * time.Second)), }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -1202,7 +1202,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { ACLToken: structs.ACLToken{ Description: "new-description", AccessorID: tokenID, - ExpirationTime: expTime, + ExpirationTime: &expTime, }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -1220,7 +1220,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.NotNil(t, token.AccessorID) require.Equal(t, token.Description, "new-description") require.Equal(t, token.AccessorID, resp.AccessorID) - requireTimeEquals(t, expTime, resp.ExpirationTime) + requireTimeEquals(t, &expTime, resp.ExpirationTime) }) t.Run("cannot update a token that is past its expiration time", func(t *testing.T) { @@ -2217,10 +2217,16 @@ func retrieveTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter st return &out, nil } -func requireTimeEquals(t *testing.T, expect, got time.Time) { +func requireTimeEquals(t *testing.T, expect, got *time.Time) { t.Helper() - if !expect.Equal(got) { - t.Fatalf("expected=%q != got=%q", expect, got) + if expect == nil && got == nil { + return + } else if expect == nil && got != nil { + t.Fatalf("expected=NIL != got=%q", *got) + } else if expect != nil && got == nil { + t.Fatalf("expected=%q != got=NIL", *expect) + } else if !expect.Equal(*got) { + t.Fatalf("expected=%q != got=%q", *expect, *got) } } diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index abeb1f7e22..1533dcbef7 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -79,14 +79,14 @@ func (s *TokenExpirationIndex) FromObject(obj interface{}) (bool, []byte, error) if s.LocalFilter != token.Local { return false, nil, nil } - if token.ExpirationTime.IsZero() { + if !token.HasExpirationTime() { return false, nil, nil } if token.ExpirationTime.Unix() < 0 { return false, nil, fmt.Errorf("token expiration time cannot be before the unix epoch: %s", token.ExpirationTime) } - buf := s.encodeTime(token.ExpirationTime) + buf := s.encodeTime(*token.ExpirationTime) return true, buf, nil } @@ -669,10 +669,10 @@ func (s *Store) ACLTokenMinExpirationTime(local bool) (time.Time, error) { token := item.(*structs.ACLToken) - return token.ExpirationTime, nil + return *token.ExpirationTime, nil } -// ACLTokenListExpires lists tokens that are expires as of the provided time. +// ACLTokenListExpires lists tokens that are expired as of the provided time. // The returned set will be no larger than the max value provided. func (s *Store) ACLTokenListExpired(local bool, asOf time.Time, max int) (structs.ACLTokens, <-chan struct{}, error) { tx := s.db.Txn(false) @@ -689,8 +689,7 @@ func (s *Store) ACLTokenListExpired(local bool, asOf time.Time, max int) (struct ) for raw := iter.Next(); raw != nil; raw = iter.Next() { token := raw.(*structs.ACLToken) - - if !token.ExpirationTime.Before(asOf) { + if token.ExpirationTime != nil && !token.ExpirationTime.Before(asOf) { return tokens, nil, nil } diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index 9ec57674bc..d39c71d099 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -1741,7 +1741,7 @@ func TestTokenPoliciesIndex(t *testing.T) { SecretID: newUUID(), Description: desc, Local: local, - ExpirationTime: expTime, + ExpirationTime: &expTime, CreateTime: baseTime, RaftIndex: structs.RaftIndex{ CreateIndex: 9, diff --git a/agent/structs/acl.go b/agent/structs/acl.go index fee2cd85ea..50d4df0b52 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -218,7 +218,11 @@ type ACLToken struct { // ExpirationTime represents the point after which a token should be // considered revoked and is eligible for destruction. The zero value // represents NO expiration. - ExpirationTime time.Time `json:",omitempty"` + // + // This is a pointer value so that the zero value is omitted properly + // during json serialization. time.Time does not respect json omitempty + // directives unfortunately. + ExpirationTime *time.Time `json:",omitempty"` // ExpirationTTL is a convenience field for helping set ExpirationTime to a // value of CreateTime+ExpirationTTL. This can only be set during @@ -293,17 +297,21 @@ func (t *ACLToken) ServiceIdentityList() []*ACLServiceIdentity { } func (t *ACLToken) IsExpired(asOf time.Time) bool { - if asOf.IsZero() || t.ExpirationTime.IsZero() { + if asOf.IsZero() || !t.HasExpirationTime() { return false } return t.ExpirationTime.Before(asOf) } +func (t *ACLToken) HasExpirationTime() bool { + return t.ExpirationTime != nil && !t.ExpirationTime.IsZero() +} + func (t *ACLToken) UsesNonLegacyFields() bool { return len(t.Policies) > 0 || len(t.ServiceIdentities) > 0 || t.Type == "" || - !t.ExpirationTime.IsZero() || + t.HasExpirationTime() || t.ExpirationTTL != 0 } @@ -405,8 +413,8 @@ type ACLTokenListStub struct { Policies []ACLTokenPolicyLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool - ExpirationTime time.Time `json:",omitempty"` - CreateTime time.Time `json:",omitempty"` + ExpirationTime *time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` Hash []byte CreateIndex uint64 ModifyIndex uint64 diff --git a/api/acl.go b/api/acl.go index 3cf0227621..e920c46d6e 100644 --- a/api/acl.go +++ b/api/acl.go @@ -31,7 +31,7 @@ type ACLToken struct { ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool ExpirationTTL time.Duration `json:",omitempty"` - ExpirationTime time.Time `json:",omitempty"` + ExpirationTime *time.Time `json:",omitempty"` CreateTime time.Time `json:",omitempty"` Hash []byte `json:",omitempty"` @@ -48,7 +48,7 @@ type ACLTokenListEntry struct { Policies []*ACLTokenPolicyLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool - ExpirationTime time.Time `json:",omitempty"` + ExpirationTime *time.Time `json:",omitempty"` CreateTime time.Time Hash []byte Legacy bool diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index d1e364acd6..843a91487a 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -15,8 +15,8 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf("Description: %s", token.Description)) ui.Info(fmt.Sprintf("Local: %t", token.Local)) ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) - if !token.ExpirationTime.IsZero() { - ui.Info(fmt.Sprintf("Expiration Time: %v", token.ExpirationTime)) + if token.ExpirationTime != nil && !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", *token.ExpirationTime)) } if showMeta { ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) @@ -46,8 +46,8 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) ui.Info(fmt.Sprintf("Description: %s", token.Description)) ui.Info(fmt.Sprintf("Local: %t", token.Local)) ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) - if !token.ExpirationTime.IsZero() { - ui.Info(fmt.Sprintf("Expiration Time: %v", token.ExpirationTime)) + if token.ExpirationTime != nil && !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", *token.ExpirationTime)) } ui.Info(fmt.Sprintf("Legacy: %t", token.Legacy)) if showMeta { From cc1aa3f97384618a9b82227e7b1757a3f9a169fc Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Mon, 15 Apr 2019 15:43:19 -0500 Subject: [PATCH 4/5] acl: adding Roles to Tokens (#5514) Roles are named and can express the same bundle of permissions that can currently be assigned to a Token (lists of Policies and Service Identities). The difference with a Role is that it not itself a bearer token, but just another entity that can be tied to a Token. This lets an operator potentially curate a set of smaller reusable Policies and compose them together into reusable Roles, rather than always exploding that same list of Policies on any Token that needs similar permissions. This also refactors the acl replication code to be semi-generic to avoid 3x copypasta. --- agent/acl_endpoint.go | 153 +++ agent/agent.go | 3 + agent/config/builder.go | 1 + agent/config/config.go | 1 + agent/config/runtime.go | 6 + agent/config/runtime_test.go | 4 + agent/consul/acl.go | 257 +++- agent/consul/acl_client.go | 7 + agent/consul/acl_endpoint.go | 361 +++++- agent/consul/acl_endpoint_legacy.go | 2 +- agent/consul/acl_endpoint_test.go | 631 +++++++++- agent/consul/acl_replication.go | 531 ++++---- agent/consul/acl_replication_legacy.go | 2 +- agent/consul/acl_replication_legacy_test.go | 15 +- agent/consul/acl_replication_test.go | 313 ++++- agent/consul/acl_replication_types.go | 370 ++++++ agent/consul/acl_server.go | 28 +- agent/consul/acl_test.go | 775 +++++++++--- agent/consul/config.go | 6 + agent/consul/fsm/commands_oss.go | 24 + agent/consul/fsm/snapshot_oss.go | 23 + agent/consul/helper_test.go | 17 +- agent/consul/leader.go | 80 +- agent/consul/state/acl.go | 591 ++++++++- agent/consul/state/acl_test.go | 1230 ++++++++++++++++++- agent/consul/state/state_store.go | 14 +- agent/http_oss.go | 4 + agent/structs/acl.go | 284 ++++- agent/structs/acl_cache.go | 51 +- agent/structs/acl_cache_test.go | 19 +- agent/structs/acl_test.go | 2 + agent/structs/structs.go | 2 + api/acl.go | 164 +++ api/api.go | 32 +- command/acl/acl_helpers.go | 101 ++ command/acl/role/create/role_create.go | 134 ++ command/acl/role/create/role_create_test.go | 116 ++ command/acl/role/delete/role_delete.go | 91 ++ command/acl/role/delete/role_delete_test.go | 141 +++ command/acl/role/list/role_list.go | 79 ++ command/acl/role/list/role_list_test.go | 83 ++ command/acl/role/read/role_read.go | 115 ++ command/acl/role/read/role_read_test.go | 194 +++ command/acl/role/role.go | 56 + command/acl/role/update/role_update.go | 225 ++++ command/acl/role/update/role_update_test.go | 398 ++++++ command/acl/token/create/token_create.go | 27 +- command/acl/token/update/token_update.go | 67 +- command/commands_oss.go | 12 + 49 files changed, 7193 insertions(+), 649 deletions(-) create mode 100644 agent/consul/acl_replication_types.go create mode 100644 command/acl/role/create/role_create.go create mode 100644 command/acl/role/create/role_create_test.go create mode 100644 command/acl/role/delete/role_delete.go create mode 100644 command/acl/role/delete/role_delete_test.go create mode 100644 command/acl/role/list/role_list.go create mode 100644 command/acl/role/list/role_list_test.go create mode 100644 command/acl/role/read/role_read.go create mode 100644 command/acl/role/read/role_read_test.go create mode 100644 command/acl/role/role.go create mode 100644 command/acl/role/update/role_update.go create mode 100644 command/acl/role/update/role_update_test.go diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index d34690b328..cafe6e11c3 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -254,6 +254,7 @@ func (s *HTTPServer) ACLPolicyRead(resp http.ResponseWriter, req *http.Request, } if out.Policy == nil { + // TODO(rb): should this return a normal 404? return nil, acl.ErrNotFound } @@ -374,6 +375,7 @@ func (s *HTTPServer) ACLTokenList(resp http.ResponseWriter, req *http.Request) ( } args.Policy = req.URL.Query().Get("policy") + args.Role = req.URL.Query().Get("role") var out structs.ACLTokenListResponse defer setMeta(resp, &out.QueryMeta) @@ -548,3 +550,154 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request, return &out, nil } + +func (s *HTTPServer) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLRoleListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + args.Policy = req.URL.Query().Get("policy") + + var out structs.ACLRoleListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.RoleList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.Roles == nil { + out.Roles = make(structs.ACLRoles, 0) + } + + return out.Roles, nil +} + +func (s *HTTPServer) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLRoleReadByID + + case "PUT": + fn = s.ACLRoleWrite + + case "DELETE": + fn = s.ACLRoleDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + roleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/") + if roleID == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing role ID"} + } + + return fn(resp, req, roleID) +} + +func (s *HTTPServer) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/") + if roleName == "" { + return nil, BadRequestError{Reason: "Missing role Name"} + } + + return s.ACLRoleRead(resp, req, "", roleName) +} + +func (s *HTTPServer) ACLRoleReadByID(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + return s.ACLRoleRead(resp, req, roleID, "") +} + +func (s *HTTPServer) ACLRoleRead(resp http.ResponseWriter, req *http.Request, roleID, roleName string) (interface{}, error) { + args := structs.ACLRoleGetRequest{ + Datacenter: s.agent.config.Datacenter, + RoleID: roleID, + RoleName: roleName, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLRoleResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.RoleRead", &args, &out); err != nil { + return nil, err + } + + if out.Role == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + return out.Role, nil +} + +func (s *HTTPServer) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLRoleWrite(resp, req, "") +} + +func (s *HTTPServer) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + args := structs.ACLRoleSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.Role, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)} + } + + if args.Role.ID != "" && args.Role.ID != roleID { + return nil, BadRequestError{Reason: "Role ID in URL and payload do not match"} + } else if args.Role.ID == "" { + args.Role.ID = roleID + } + + var out structs.ACLRole + if err := s.agent.RPC("ACL.RoleSet", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLRoleDelete(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + args := structs.ACLRoleDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + RoleID: roleID, + } + s.parseToken(req, &args.Token) + + var ignored string + if err := s.agent.RPC("ACL.RoleDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} diff --git a/agent/agent.go b/agent/agent.go index 36ca1a2958..c1455a78d8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1001,6 +1001,9 @@ func (a *Agent) consulConfig() (*consul.Config, error) { if a.config.ACLPolicyTTL != 0 { base.ACLPolicyTTL = a.config.ACLPolicyTTL } + if a.config.ACLRoleTTL != 0 { + base.ACLRoleTTL = a.config.ACLRoleTTL + } if a.config.ACLDefaultPolicy != "" { base.ACLDefaultPolicy = a.config.ACLDefaultPolicy } diff --git a/agent/config/builder.go b/agent/config/builder.go index 5ef7e35044..9cd5956ed0 100644 --- a/agent/config/builder.go +++ b/agent/config/builder.go @@ -702,6 +702,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) { ACLReplicationToken: b.stringValWithDefault(c.ACL.Tokens.Replication, b.stringVal(c.ACLReplicationToken)), ACLTokenTTL: b.durationValWithDefault("acl.token_ttl", c.ACL.TokenTTL, b.durationVal("acl_ttl", c.ACLTTL)), ACLPolicyTTL: b.durationVal("acl.policy_ttl", c.ACL.PolicyTTL), + ACLRoleTTL: b.durationVal("acl.role_ttl", c.ACL.RoleTTL), ACLToken: b.stringValWithDefault(c.ACL.Tokens.Default, b.stringVal(c.ACLToken)), ACLTokenReplication: b.boolValWithDefault(c.ACL.TokenReplication, b.boolValWithDefault(c.EnableACLReplication, enableTokenReplication)), ACLEnableTokenPersistence: b.boolValWithDefault(c.ACL.EnableTokenPersistence, false), diff --git a/agent/config/config.go b/agent/config/config.go index 958f53f26c..1e769118d6 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -635,6 +635,7 @@ type ACL struct { Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"` TokenReplication *bool `json:"enable_token_replication,omitempty" hcl:"enable_token_replication" mapstructure:"enable_token_replication"` PolicyTTL *string `json:"policy_ttl,omitempty" hcl:"policy_ttl" mapstructure:"policy_ttl"` + RoleTTL *string `json:"role_ttl,omitempty" hcl:"role_ttl" mapstructure:"role_ttl"` TokenTTL *string `json:"token_ttl,omitempty" hcl:"token_ttl" mapstructure:"token_ttl"` DownPolicy *string `json:"down_policy,omitempty" hcl:"down_policy" mapstructure:"down_policy"` DefaultPolicy *string `json:"default_policy,omitempty" hcl:"default_policy" mapstructure:"default_policy"` diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 96bbe1e3e4..dc9d567f81 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -155,6 +155,12 @@ type RuntimeConfig struct { // hcl: acl.token_ttl = "duration" ACLPolicyTTL time.Duration + // ACLRoleTTL is used to control the time-to-live of cached ACL roles. This has + // a major impact on performance. By default, it is set to 30 seconds. + // + // hcl: acl.role_ttl = "duration" + ACLRoleTTL time.Duration + // ACLToken is the default token used to make requests if a per-request // token is not provided. If not configured the 'anonymous' token is used. // diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 0e22b95d73..18d9a7e354 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -2901,6 +2901,7 @@ func TestFullConfig(t *testing.T) { "enable_key_list_policy": false, "enable_token_persistence": true, "policy_ttl": "1123s", + "role_ttl": "9876s", "token_ttl": "3321s", "enable_token_replication" : true, "tokens" : { @@ -3464,6 +3465,7 @@ func TestFullConfig(t *testing.T) { enable_key_list_policy = false enable_token_persistence = true policy_ttl = "1123s" + role_ttl = "9876s" token_ttl = "3321s" enable_token_replication = true tokens = { @@ -4145,6 +4147,7 @@ func TestFullConfig(t *testing.T) { ACLReplicationToken: "5795983a", ACLTokenTTL: 3321 * time.Second, ACLPolicyTTL: 1123 * time.Second, + ACLRoleTTL: 9876 * time.Second, ACLToken: "418fdff1", ACLTokenReplication: true, AdvertiseAddrLAN: ipAddr("17.99.29.16"), @@ -4975,6 +4978,7 @@ func TestSanitize(t *testing.T) { "ACLMasterToken": "hidden", "ACLPolicyTTL": "0s", "ACLReplicationToken": "hidden", + "ACLRoleTTL": "0s", "ACLTokenReplication": false, "ACLTokenTTL": "0s", "ACLToken": "hidden", diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 2e89eba2f2..74ebb90385 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -65,6 +65,10 @@ const ( // Maximum number of re-resolution requests to be made if the token is modified between // resolving the token and resolving its policies that would remove one of its policies. tokenPolicyResolutionMaxRetries = 5 + + // Maximum number of re-resolution requests to be made if the token is modified between + // resolving the token and resolving its roles that would remove one of its roles. + tokenRoleResolutionMaxRetries = 5 ) func minTTL(a time.Duration, b time.Duration) time.Duration { @@ -93,15 +97,16 @@ type ACLResolverDelegate interface { UseLegacyACLs() bool ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) + ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) RPC(method string, args interface{}, reply interface{}) error } -type policyTokenError struct { +type policyOrRoleTokenError struct { Err error token string } -func (e policyTokenError) Error() string { +func (e policyOrRoleTokenError) Error() string { return e.Err.Error() } @@ -129,9 +134,11 @@ type ACLResolverConfig struct { // Supports: // - Resolving tokens locally via the ACLResolverDelegate // - Resolving policies locally via the ACLResolverDelegate +// - Resolving roles locally via the ACLResolverDelegate // - Resolving legacy tokens remotely via a ACL.GetPolicy RPC // - Resolving tokens remotely via an ACL.TokenRead RPC // - Resolving policies remotely via an ACL.PolicyResolve RPC +// - Resolving roles remotely via an ACL.RoleResolve RPC // // Remote Resolution: // Remote resolution can be done synchronously or asynchronously depending @@ -141,7 +148,7 @@ type ACLResolverConfig struct { // then go routines will be spawned to perform the RPCs in the background // and then will update the cache with either the positive or negative result. // -// When the down policy is set to extend-cache or the token/policy is not already +// When the down policy is set to extend-cache or the token/policy/role is not already // cached then the same go routines are spawned to do the RPCs in the background. // However in this mode channels are created to receive the results of the RPC // and are registered with the resolver. Those channels are immediately read/blocked @@ -157,6 +164,7 @@ type ACLResolver struct { cache *structs.ACLCaches identityGroup singleflight.Group policyGroup singleflight.Group + roleGroup singleflight.Group legacyGroup singleflight.Group down acl.Authorizer @@ -447,7 +455,7 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent // Do not touch the policy cache. Getting a top level ACL not found error // only indicates that the secret token used in the request // no longer exists - return nil, &policyTokenError{acl.ErrNotFound, identity.SecretToken()} + return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} } if acl.IsErrPermissionDenied(err) { @@ -457,7 +465,7 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent // Do not remove from the policy cache for permission denied // what this does indicate is that our view of the token is out of date - return nil, &policyTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} } // other RPC error - use cache if available @@ -483,6 +491,78 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent return out, nil } +func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity, roleIDs []string, cached map[string]*structs.RoleCacheEntry) (map[string]*structs.ACLRole, error) { + req := structs.ACLRoleBatchGetRequest{ + Datacenter: r.delegate.ACLDatacenter(false), + RoleIDs: roleIDs, + QueryOptions: structs.QueryOptions{ + Token: identity.SecretToken(), + AllowStale: true, + }, + } + + var resp structs.ACLRoleBatchResponse + err := r.delegate.RPC("ACL.RoleResolve", &req, &resp) + if err == nil { + out := make(map[string]*structs.ACLRole) + for _, role := range resp.Roles { + out[role.ID] = role + } + + for _, roleID := range roleIDs { + if role, ok := out[roleID]; ok { + r.cache.PutRole(roleID, role) + } else { + r.cache.PutRole(roleID, nil) + } + } + return out, nil + } + + if acl.IsErrNotFound(err) { + // make sure to indicate that this identity is no longer valid within + // the cache + r.cache.PutIdentity(identity.SecretToken(), nil) + + // Do not touch the cache. Getting a top level ACL not found error + // only indicates that the secret token used in the request + // no longer exists + return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} + } + + if acl.IsErrPermissionDenied(err) { + // invalidate our ID cache so that identity resolution will take place + // again in the future + r.cache.RemoveIdentity(identity.SecretToken()) + + // Do not remove from the cache for permission denied + // what this does indicate is that our view of the token is out of date + return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + } + + // other RPC error - use cache if available + + extendCache := r.config.ACLDownPolicy == "extend-cache" || r.config.ACLDownPolicy == "async-cache" + + out := make(map[string]*structs.ACLRole) + insufficientCache := false + for _, roleID := range roleIDs { + if entry, ok := cached[roleID]; extendCache && ok { + r.cache.PutRole(roleID, entry.Role) + if entry.Role != nil { + out[roleID] = entry.Role + } + } else { + r.cache.PutRole(roleID, nil) + insufficientCache = true + } + } + if insufficientCache { + return nil, ACLRemoteError{Err: err} + } + return out, nil +} + func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) structs.ACLPolicies { var out structs.ACLPolicies for _, policy := range policies { @@ -504,9 +584,10 @@ func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) struct func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) (structs.ACLPolicies, error) { policyIDs := identity.PolicyIDs() + roleIDs := identity.RoleIDs() serviceIdentities := identity.ServiceIdentityList() - if len(policyIDs) == 0 && len(serviceIdentities) == 0 { + if len(policyIDs) == 0 && len(serviceIdentities) == 0 && len(roleIDs) == 0 { policy := identity.EmbeddedPolicy() if policy != nil { return []*structs.ACLPolicy{policy}, nil @@ -516,6 +597,25 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( return nil, nil } + // Collect all of the roles tied to this token. + roles, err := r.collectRolesForIdentity(identity, roleIDs) + if err != nil { + return nil, err + } + + // Merge the policies and service identities across Token and Role fields. + for _, role := range roles { + for _, link := range role.Policies { + policyIDs = append(policyIDs, link.ID) + } + serviceIdentities = append(serviceIdentities, role.ServiceIdentities...) + } + + // Now deduplicate any policies or service identities that occur more than once. + policyIDs = dedupeStringSlice(policyIDs) + serviceIdentities = dedupeServiceIdentities(serviceIdentities) + + // Generate synthetic policies for all service identities in effect. syntheticPolicies := r.synthesizePoliciesForServiceIdentities(serviceIdentities) // For the new ACLs policy replication is mandatory for correct operation on servers. Therefore @@ -535,9 +635,6 @@ func (r *ACLResolver) synthesizePoliciesForServiceIdentities(serviceIdentities [ return nil } - // Collect and dedupe service identities. Prefer increasing datacenter scope. - serviceIdentities = dedupeServiceIdentities(serviceIdentities) - syntheticPolicies := make([]*structs.ACLPolicy, 0, len(serviceIdentities)) for _, s := range serviceIdentities { syntheticPolicies = append(syntheticPolicies, s.SyntheticPolicy()) @@ -590,6 +687,10 @@ func mergeStringSlice(a, b []string) []string { func dedupeStringSlice(in []string) []string { // From: https://github.com/golang/go/wiki/SliceTricks#in-place-deduplicate-comparable + if len(in) <= 1 { + return in + } + sort.Strings(in) j := 0 @@ -635,7 +736,7 @@ func (r *ACLResolver) collectPoliciesForIdentity(identity structs.ACLIdentity, p } if entry.Policy == nil { - // this happens when we cache a negative response for the policies existence + // this happens when we cache a negative response for the policy's existence continue } @@ -689,6 +790,99 @@ func (r *ACLResolver) collectPoliciesForIdentity(identity structs.ACLIdentity, p return policies, nil } +func (r *ACLResolver) resolveRolesForIdentity(identity structs.ACLIdentity) (structs.ACLRoles, error) { + return r.collectRolesForIdentity(identity, identity.RoleIDs()) +} + +func (r *ACLResolver) collectRolesForIdentity(identity structs.ACLIdentity, roleIDs []string) (structs.ACLRoles, error) { + if len(roleIDs) == 0 { + return nil, nil + } + + // For the new ACLs policy & role replication is mandatory for correct operation + // on servers. Therefore we only attempt to resolve roles locally + roles := make([]*structs.ACLRole, 0, len(roleIDs)) + + var missing []string + var expired []*structs.ACLRole + expCacheMap := make(map[string]*structs.RoleCacheEntry) + + for _, roleID := range roleIDs { + if done, role, err := r.delegate.ResolveRoleFromID(roleID); done { + if err != nil && !acl.IsErrNotFound(err) { + return nil, err + } + + if role != nil { + roles = append(roles, role) + } else { + r.logger.Printf("[WARN] acl: role %q not found for identity %q", roleID, identity.ID()) + } + + continue + } + + // create the missing list which we can execute an RPC to get all the missing roles at once + entry := r.cache.GetRole(roleID) + if entry == nil { + missing = append(missing, roleID) + continue + } + + if entry.Role == nil { + // this happens when we cache a negative response for the role's existence + continue + } + + if entry.Age() >= r.config.ACLRoleTTL { + expired = append(expired, entry.Role) + expCacheMap[roleID] = entry + } else { + roles = append(roles, entry.Role) + } + } + + // Hot-path if we have no missing or expired roles + if len(missing)+len(expired) == 0 { + return roles, nil + } + + hasMissing := len(missing) > 0 + + fetchIDs := missing + for _, role := range expired { + fetchIDs = append(fetchIDs, role.ID) + } + + waitChan := r.roleGroup.DoChan(identity.SecretToken(), func() (interface{}, error) { + roles, err := r.fetchAndCacheRolesForIdentity(identity, fetchIDs, expCacheMap) + return roles, err + }) + + waitForResult := hasMissing || r.config.ACLDownPolicy != "async-cache" + if !waitForResult { + // waitForResult being false requires that all the roles were cached already + roles = append(roles, expired...) + return roles, nil + } + + res := <-waitChan + + if res.Err != nil { + return nil, res.Err + } + + if res.Val != nil { + foundRoles := res.Val.(map[string]*structs.ACLRole) + + for _, role := range foundRoles { + roles = append(roles, role) + } + } + + return roles, nil +} + func (r *ACLResolver) resolveTokenToPolicies(token string) (structs.ACLPolicies, error) { _, policies, err := r.resolveTokenToIdentityAndPolicies(token) return policies, err @@ -717,13 +911,52 @@ func (r *ACLResolver) resolveTokenToIdentityAndPolicies(token string) (structs.A } lastErr = err - if tokenErr, ok := err.(*policyTokenError); ok { + if tokenErr, ok := err.(*policyOrRoleTokenError); ok { if acl.IsErrNotFound(err) && tokenErr.token == identity.SecretToken() { // token was deleted while resolving policies return nil, nil, acl.ErrNotFound } - // other types of policyTokenErrors should cause retrying the whole token + // other types of policyOrRoleTokenErrors should cause retrying the whole token + // resolution process + } else { + return identity, nil, err + } + } + + return lastIdentity, nil, lastErr +} + +func (r *ACLResolver) resolveTokenToIdentityAndRoles(token string) (structs.ACLIdentity, structs.ACLRoles, error) { + var lastErr error + var lastIdentity structs.ACLIdentity + + for i := 0; i < tokenRoleResolutionMaxRetries; i++ { + // Resolve the token to an ACLIdentity + identity, err := r.resolveIdentityFromToken(token) + if err != nil { + return nil, nil, err + } else if identity == nil { + return nil, nil, acl.ErrNotFound + } else if identity.IsExpired(time.Now()) { + return nil, nil, acl.ErrNotFound + } + + lastIdentity = identity + + roles, err := r.resolveRolesForIdentity(identity) + if err == nil { + return identity, roles, nil + } + lastErr = err + + if tokenErr, ok := err.(*policyOrRoleTokenError); ok { + if acl.IsErrNotFound(err) && tokenErr.token == identity.SecretToken() { + // token was deleted while resolving roles + return nil, nil, acl.ErrNotFound + } + + // other types of policyOrRoleTokenErrors should cause retrying the whole token // resolution process } else { return identity, nil, err diff --git a/agent/consul/acl_client.go b/agent/consul/acl_client.go index 06b9d78b5e..1951d43412 100644 --- a/agent/consul/acl_client.go +++ b/agent/consul/acl_client.go @@ -25,6 +25,8 @@ var clientACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ ParsedPolicies: 128, // Authorizers - number of compiled multi-policy effective policies that can be cached Authorizers: 256, + // Roles - number of ACL roles that can be cached + Roles: 128, } func (c *Client) UseLegacyACLs() bool { @@ -96,6 +98,11 @@ func (c *Client) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, return false, nil, nil } +func (c *Client) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + // clients do no local role resolution at the moment + return false, nil, nil +} + func (c *Client) ResolveToken(token string) (acl.Authorizer, error) { return c.acls.ResolveToken(token) } diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index cd3abecb1a..62c17496f9 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -28,6 +28,7 @@ var ( validPolicyName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`) serviceIdentityNameMaxLength = 256 + validRoleName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,256}$`) ) // ACL endpoint is used to manipulate ACLs @@ -463,6 +464,33 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. } token.Policies = policies + roleIDs := make(map[string]struct{}) + var roles []structs.ACLTokenRoleLink + + // Validate all the role names and convert them to role IDs + for _, link := range token.Roles { + if link.ID == "" { + _, role, err := state.ACLRoleGetByName(nil, link.Name) + if err != nil { + return fmt.Errorf("Error looking up role for name %q: %v", link.Name, err) + } + if role == nil { + return fmt.Errorf("No such ACL role with name %q", link.Name) + } + link.ID = role.ID + } + + // Do not store the role name within raft/memdb as the role could be renamed in the future. + link.Name = "" + + // dedup role links by id + if _, ok := roleIDs[link.ID]; !ok { + roles = append(roles, link) + roleIDs[link.ID] = struct{}{} + } + } + token.Roles = roles + for _, svcid := range token.ServiceIdentities { if svcid.ServiceName == "" { return fmt.Errorf("Service identity is missing the service name field on this token") @@ -624,7 +652,7 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy) + index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role) if err != nil { return err } @@ -1053,3 +1081,334 @@ func (a *ACL) ReplicationStatus(args *structs.DCSpecificRequest, func timePointer(t time.Time) *time.Time { return &t } + +func (a *ACL) RoleRead(args *structs.ACLRoleGetRequest, reply *structs.ACLRoleResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + var ( + index uint64 + role *structs.ACLRole + err error + ) + if args.RoleID != "" { + index, role, err = state.ACLRoleGetByID(ws, args.RoleID) + } else { + index, role, err = state.ACLRoleGetByName(ws, args.RoleName) + } + + if err != nil { + return err + } + + reply.Index, reply.Role = index, role + return nil + }) +} + +func (a *ACL) RoleBatchRead(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleBatchRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, roles, err := state.ACLRoleBatchGet(ws, args.RoleIDs) + if err != nil { + return err + } + + reply.Index, reply.Roles = index, roles + return nil + }) +} + +func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.InACLDatacenter() { + args.Datacenter = a.srv.config.ACLDatacenter + } + + if done, err := a.srv.forward("ACL.RoleSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "role", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + role := &args.Role + state := a.srv.fsm.State() + + // Almost all of the checks here are also done in the state store. However, + // we want to prevent the raft operations when we know they are going to fail + // so we still do them here. + + // ensure a name is set + if role.Name == "" { + return fmt.Errorf("Invalid Role: no Name is set") + } + + if !validRoleName.MatchString(role.Name) { + return fmt.Errorf("Invalid Role: invalid Name. Only alphanumeric characters, '-' and '_' are allowed") + } + + if role.ID == "" { + // with no role ID one will be generated + var err error + + role.ID, err = lib.GenerateUUID(a.srv.checkRoleUUID) + if err != nil { + return err + } + + // validate the name is unique + if _, existing, err := state.ACLRoleGetByName(nil, role.Name); err != nil { + return fmt.Errorf("acl role lookup by name failed: %v", err) + } else if existing != nil { + return fmt.Errorf("Invalid Role: A Role with Name %q already exists", role.Name) + } + } else { + if _, err := uuid.ParseUUID(role.ID); err != nil { + return fmt.Errorf("Role ID invalid UUID") + } + + // Verify the role exists + _, existing, err := state.ACLRoleGetByID(nil, role.ID) + if err != nil { + return fmt.Errorf("acl role lookup failed: %v", err) + } else if existing == nil { + return fmt.Errorf("cannot find role %s", role.ID) + } + + if existing.Name != role.Name { + if _, nameMatch, err := state.ACLRoleGetByName(nil, role.Name); err != nil { + return fmt.Errorf("acl role lookup by name failed: %v", err) + } else if nameMatch != nil { + return fmt.Errorf("Invalid Role: A role with name %q already exists", role.Name) + } + } + } + + policyIDs := make(map[string]struct{}) + var policies []structs.ACLRolePolicyLink + + // Validate all the policy names and convert them to policy IDs + for _, link := range role.Policies { + if link.ID == "" { + _, policy, err := state.ACLPolicyGetByName(nil, link.Name) + if err != nil { + return fmt.Errorf("Error looking up policy for name %q: %v", link.Name, err) + } + if policy == nil { + return fmt.Errorf("No such ACL policy with name %q", link.Name) + } + link.ID = policy.ID + } + + // Do not store the policy name within raft/memdb as the policy could be renamed in the future. + link.Name = "" + + // dedup policy links by id + if _, ok := policyIDs[link.ID]; !ok { + policies = append(policies, link) + policyIDs[link.ID] = struct{}{} + } + } + role.Policies = policies + + for _, svcid := range role.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Service identity is missing the service name field on this role") + } + // TODO(rb): ugh if a local token gets a role that has a service + // identity that has datacenters set, we won't be anble to enforce this + // next blob here. This makes me lean more towards nuking ServiceIdentity.Datacenters again + // + // if token.Local && len(svcid.Datacenters) > 0 { + // return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName) + // } + if !isValidServiceIdentityName(svcid.ServiceName) { + return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) + } + } + + // calculate the hash for this role + role.SetHash(true) + + req := &structs.ACLRoleBatchSetRequest{ + Roles: structs.ACLRoles{role}, + } + + resp, err := a.srv.raftApply(structs.ACLRoleSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply role upsert request: %v", err) + } + + // Remove from the cache to prevent stale cache usage + a.srv.acls.cache.RemoveRole(role.ID) + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, role, err := a.srv.fsm.State().ACLRoleGetByID(nil, role.ID); err == nil && role != nil { + *reply = *role + } + + return nil +} + +func (a *ACL) RoleDelete(args *structs.ACLRoleDeleteRequest, reply *string) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.InACLDatacenter() { + args.Datacenter = a.srv.config.ACLDatacenter + } + + if done, err := a.srv.forward("ACL.RoleDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "role", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, role, err := a.srv.fsm.State().ACLRoleGetByID(nil, args.RoleID) + if err != nil { + return err + } + + if role == nil { + return nil + } + + req := structs.ACLRoleBatchDeleteRequest{ + RoleIDs: []string{args.RoleID}, + } + + resp, err := a.srv.raftApply(structs.ACLRoleDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply role delete request: %v", err) + } + + a.srv.acls.cache.RemoveRole(role.ID) + + if respErr, ok := resp.(error); ok { + return respErr + } + + if role != nil { + *reply = role.Name + } + + return nil +} + +func (a *ACL) RoleList(args *structs.ACLRoleListRequest, reply *structs.ACLRoleListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, roles, err := state.ACLRoleList(ws, args.Policy) + if err != nil { + return err + } + + reply.Index, reply.Roles = index, roles + return nil + }) +} + +// RoleResolve is used to retrieve a subset of the roles associated with a given token +// The role ids in the args simply act as a filter on the role set assigned to the token +func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleResolve", args, args, reply); done { + return err + } + + // get full list of roles for this token + identity, roles, err := a.srv.acls.resolveTokenToIdentityAndRoles(args.Token) + if err != nil { + return err + } + + idMap := make(map[string]*structs.ACLRole) + for _, roleID := range identity.RoleIDs() { + idMap[roleID] = nil + } + for _, role := range roles { + idMap[role.ID] = role + } + + for _, roleID := range args.RoleIDs { + if role, ok := idMap[roleID]; ok { + // only add non-deleted roles + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } else { + // send a permission denied to indicate that the request included + // role ids not associated with this token + return acl.ErrPermissionDenied + } + } + + a.srv.setQueryMeta(&reply.QueryMeta) + + return nil +} diff --git a/agent/consul/acl_endpoint_legacy.go b/agent/consul/acl_endpoint_legacy.go index ad264e2dcb..3b5ee22c6e 100644 --- a/agent/consul/acl_endpoint_legacy.go +++ b/agent/consul/acl_endpoint_legacy.go @@ -255,7 +255,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, false, true, "") + index, tokens, err := state.ACLTokenList(ws, false, true, "", "") if err != nil { return err } diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index 49943c9113..99de10be32 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -919,6 +919,51 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Len(t, token.Policies, 0) }) + t.Run("Create it using Roles linked by id and name", func(t *testing.T) { + role1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + role2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: role1.ID, + }, + structs.ACLTokenRoleLink{ + Name: role2.Name, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Delete both roles to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestRole(codec, "root", "dc1", role1.ID)) + require.NoError(t, deleteTestRole(codec, "root", "dc1", role2.ID)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + + require.Len(t, token.Roles, 0) + }) + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", @@ -1783,8 +1828,7 @@ func TestACLEndpoint_PolicySet(t *testing.T) { acl := ACL{srv: s1} var policyID string - // Create it - { + t.Run("Create it", func(t *testing.T) { req := structs.ACLPolicySetRequest{ Datacenter: "dc1", Policy: structs.ACLPolicy{ @@ -1811,10 +1855,9 @@ func TestACLEndpoint_PolicySet(t *testing.T) { require.Equal(t, policy.Rules, "service \"\" { policy = \"read\" }") policyID = policy.ID - } + }) - // Update it - { + t.Run("Update it", func(t *testing.T) { req := structs.ACLPolicySetRequest{ Datacenter: "dc1", Policy: structs.ACLPolicy{ @@ -1840,7 +1883,7 @@ func TestACLEndpoint_PolicySet(t *testing.T) { require.Equal(t, policy.Description, "bat") require.Equal(t, policy.Name, "bar") require.Equal(t, policy.Rules, "service \"\" { policy = \"write\" }") - } + }) } func TestACLEndpoint_PolicySet_globalManagement(t *testing.T) { @@ -2080,6 +2123,482 @@ func TestACLEndpoint_PolicyResolve(t *testing.T) { require.EqualValues(t, retrievedPolicies, policies) } +func TestACLEndpoint_RoleRead(t *testing.T) { + t.Parallel() + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + role, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleGetRequest{ + Datacenter: "dc1", + RoleID: role.ID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleResponse{} + + err = acl.RoleRead(&req, &resp) + require.NoError(t, err) + require.Equal(t, role, resp.Role) +} + +func TestACLEndpoint_RoleBatchRead(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + roles := []string{r1.ID, r2.ID} + + req := structs.ACLRoleBatchGetRequest{ + Datacenter: "dc1", + RoleIDs: roles, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleBatchResponse{} + + err = acl.RoleBatchRead(&req, &resp) + require.NoError(t, err) + + var retrievedRoles []string + + for _, v := range resp.Roles { + retrievedRoles = append(retrievedRoles, v.ID) + } + require.EqualValues(t, retrievedRoles, roles) +} + +func TestACLEndpoint_RoleSet(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + var roleID string + + testPolicy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + testPolicy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + t.Run("Create it", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: "baz", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy1.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "foobar") + require.Equal(t, role.Name, "baz") + require.Len(t, role.Policies, 1) + require.Equal(t, testPolicy1.ID, role.Policies[0].ID) + + roleID = role.ID + }) + + t.Run("Update it", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + ID: roleID, + Description: "bat", + Name: "bar", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy2.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "bat") + require.Equal(t, role.Name, "bar") + require.Len(t, role.Policies, 1) + require.Equal(t, testPolicy2.ID, role.Policies[0].ID) + }) + + t.Run("Create it using Policies linked by id and name", func(t *testing.T) { + policy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + policy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: "baz", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: policy1.ID, + }, + structs.ACLRolePolicyLink{ + Name: policy2.Name, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err = acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Delete both policies to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy1.ID)) + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy2.ID)) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "foobar") + require.Equal(t, role.Name, "baz") + + require.Len(t, role.Policies, 0) + }) + + roleNameGen := func(t *testing.T) string { + t.Helper() + name, err := uuid.GenerateUUID() + require.NoError(t, err) + return name + } + + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: ""}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + requireErrorContains(t, err, "Service identity is missing the service name field") + }) + + t.Run("Create it with invalid service identity (too large)", func(t *testing.T) { + long := strings.Repeat("x", serviceIdentityNameMaxLength+1) + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: long}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NotNil(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {"-abc", false}, + {"abc-", false}, + {"a-bc", true}, + {"_abc", false}, + {"abc_", false}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", false}, + {"aBc", false}, + {"abC", false}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create it with valid service identity (by regex): " + test.name + } else { + testName = "Create it with invalid service identity (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: test.name}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + // Get the token directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.ElementsMatch(t, req.Role.ServiceIdentities, role.ServiceIdentities) + } else { + require.NotNil(t, err) + } + }) + } +} + +func TestACLEndpoint_RoleSet_names(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testPolicy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + for _, test := range []struct { + name string + ok bool + }{ + {"", false}, + {"-bad", true}, + {"bad-", true}, + {"bad?bad", false}, + {strings.Repeat("x", 257), false}, + {strings.Repeat("x", 256), true}, + {"-abc", true}, + {"abc-", true}, + {"a-bc", true}, + {"_abc", true}, + {"abc_", true}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", true}, + {"aBc", true}, + {"abC", true}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "create with valid name: " + test.name + } else { + testName = "create with invalid name: " + test.name + } + + t.Run(testName, func(t *testing.T) { + // cleanup from a prior insertion that may have succeeded + require.NoError(t, deleteTestRoleByName(codec, "root", "dc1", test.name)) + + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: test.name, + Description: "foobar", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy1.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Equal(t, test.name, role.Name) + } else { + require.Error(t, err) + } + }) + } +} + +func TestACLEndpoint_RoleDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + existingRole, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleDeleteRequest{ + Datacenter: "dc1", + RoleID: existingRole.ID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.RoleDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the role is gone + roleResp, err := retrieveTestRole(codec, "root", "dc1", existingRole.ID) + require.Nil(t, roleResp.Role) +} + +func TestACLEndpoint_RoleList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleListResponse{} + + err = acl.RoleList(&req, &resp) + require.NoError(t, err) + + roles := []string{r1.ID, r2.ID} + var retrievedRoles []string + + for _, v := range resp.Roles { + retrievedRoles = append(retrievedRoles, v.ID) + } + require.ElementsMatch(t, retrievedRoles, roles) +} + // upsertTestToken creates a token for testing purposes func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter string, tokenModificationFn func(token *structs.ACLToken)) (*structs.ACLToken, error) { @@ -2217,6 +2736,106 @@ func retrieveTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter st return &out, nil } +func deleteTestRole(codec rpc.ClientCodec, masterToken string, datacenter string, roleID string) error { + arg := structs.ACLRoleDeleteRequest{ + Datacenter: datacenter, + RoleID: roleID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleDelete", &arg, &ignored) + return err +} + +func deleteTestRoleByName(codec rpc.ClientCodec, masterToken string, datacenter string, roleName string) error { + resp, err := retrieveTestRoleByName(codec, masterToken, datacenter, roleName) + if err != nil { + return err + } + if resp.Role == nil { + return nil + } + + return deleteTestRole(codec, masterToken, datacenter, resp.Role.ID) +} + +// upsertTestRole creates a role for testing purposes +func upsertTestRole(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLRole, error) { + // Make sure test roles can't collide + roleUnq, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + policyID, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + arg := structs.ACLRoleSetRequest{ + Datacenter: datacenter, + Role: structs.ACLRole{ + Name: fmt.Sprintf("test-role-%s", roleUnq), + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: policyID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLRole + + err = msgpackrpc.CallWithCodec(codec, "ACL.RoleSet", &arg, &out) + + if err != nil { + return nil, err + } + + if out.ID == "" { + return nil, fmt.Errorf("ID is nil: %v", out) + } + + return &out, nil +} + +func retrieveTestRole(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (*structs.ACLRoleResponse, error) { + arg := structs.ACLRoleGetRequest{ + Datacenter: datacenter, + RoleID: id, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLRoleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestRoleByName(codec rpc.ClientCodec, masterToken string, datacenter string, name string) (*structs.ACLRoleResponse, error) { + arg := structs.ACLRoleGetRequest{ + Datacenter: datacenter, + RoleName: name, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLRoleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + func requireTimeEquals(t *testing.T, expect, got *time.Time) { t.Helper() if expect == nil && got == nil { diff --git a/agent/consul/acl_replication.go b/agent/consul/acl_replication.go index 047705a60f..4cec1d81a3 100644 --- a/agent/consul/acl_replication.go +++ b/agent/consul/acl_replication.go @@ -3,6 +3,7 @@ package consul import ( "bytes" "context" + "errors" "fmt" "time" @@ -15,123 +16,108 @@ const ( aclReplicationMaxRetryBackoff = 64 ) -func diffACLPolicies(local structs.ACLPolicies, remote structs.ACLPolicyListStubs, lastRemoteIndex uint64) ([]string, []string) { - local.Sort() - remote.Sort() +// aclTypeReplicator allows the machinery of acl replication to be shared between +// types with minimal code duplication (barring generics magically popping into +// existence). +// +// Concrete implementations of this interface should internally contain a +// pointer to the server so that data lookups can occur, and they should +// maintain the smallest quantity of type-specific state they can. +// +// Implementations of this interface are short-lived and recreated on every +// iteration. +type aclTypeReplicator interface { + // Type is variant of replication in use. Used for updating the replication + // status tracker. + Type() structs.ACLReplicationType - var deletions []string - var updates []string - var localIdx int - var remoteIdx int - for localIdx, remoteIdx = 0, 0; localIdx < len(local) && remoteIdx < len(remote); { - if local[localIdx].ID == remote[remoteIdx].ID { - // policy is in both the local and remote state - need to check raft indices and the Hash - if remote[remoteIdx].ModifyIndex > lastRemoteIndex && !bytes.Equal(remote[remoteIdx].Hash, local[localIdx].Hash) { - updates = append(updates, remote[remoteIdx].ID) - } - // increment both indices when equal - localIdx += 1 - remoteIdx += 1 - } else if local[localIdx].ID < remote[remoteIdx].ID { - // policy no longer in remoted state - needs deleting - deletions = append(deletions, local[localIdx].ID) + // SingularNoun is the singular form of the item being replicated. + SingularNoun() string - // increment just the local index - localIdx += 1 - } else { - // local state doesn't have this policy - needs updating - updates = append(updates, remote[remoteIdx].ID) + // PluralNoun is the plural form of the item being replicated. + PluralNoun() string - // increment just the remote index - remoteIdx += 1 - } - } + // FetchRemote retrieves items newer than the provided index from the + // remote datacenter (for diffing purposes). + FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) - for ; localIdx < len(local); localIdx += 1 { - deletions = append(deletions, local[localIdx].ID) - } + // FetchLocal retrieves items from the current datacenter (for diffing + // purposes). + FetchLocal(srv *Server) (int, uint64, error) - for ; remoteIdx < len(remote); remoteIdx += 1 { - updates = append(updates, remote[remoteIdx].ID) - } + // SortState sorts the internal working state output of FetchRemote and + // FetchLocal so that a sane diff can be performed. + SortState() (lenLocal, lenRemote int) - return deletions, updates + // LocalMeta allows for type-agnostic metadata from the sorted local state + // can be retrieved for the purposes of diffing. + LocalMeta(i int) (id string, modIndex uint64, hash []byte) + + // RemoteMeta allows for type-agnostic metadata from the sorted remote + // state can be retrieved for the purposes of diffing. + RemoteMeta(i int) (id string, modIndex uint64, hash []byte) + + // FetchUpdated retrieves the specific items from the remote (during the + // correction phase). + FetchUpdated(srv *Server, updates []string) (int, error) + + // LenPendingUpdates should be the size of the data retrieved in + // FetchUpdated. + LenPendingUpdates() int + + // PendingUpdateIsRedacted returns true if the update contains redacted + // data. Really only valid for tokens. + PendingUpdateIsRedacted(i int) bool + + // PendingUpdateEstimatedSize is the item's EstimatedSize in the state + // populated by FetchUpdated. + PendingUpdateEstimatedSize(i int) int + + // UpdateLocalBatch applies a portion of the state populated by + // FetchUpdated to the current datacenter. + UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error + + // DeleteLocalBatch removes items from the current datacenter. + DeleteLocalBatch(srv *Server, batch []string) error } -func (s *Server) deleteLocalACLPolicies(deletions []string, ctx context.Context) (bool, error) { - ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) - defer ticker.Stop() +var errContainsRedactedData = errors.New("replication results contain redacted data") - for i := 0; i < len(deletions); i += aclBatchDeleteSize { - req := structs.ACLPolicyBatchDeleteRequest{} - - if i+aclBatchDeleteSize > len(deletions) { - req.PolicyIDs = deletions[i:] - } else { - req.PolicyIDs = deletions[i : i+aclBatchDeleteSize] - } - - resp, err := s.raftApply(structs.ACLPolicyDeleteRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply policy deletions: %v", err) - } - if respErr, ok := resp.(error); ok && err != nil { - return false, fmt.Errorf("Failed to apply policy deletions: %v", respErr) - } - - if i+aclBatchDeleteSize < len(deletions) { - select { - case <-ctx.Done(): - return true, nil - case <-ticker.C: - // do nothing - ready for the next batch - } - } +func (s *Server) fetchACLRolesBatch(roleIDs []string) (*structs.ACLRoleBatchResponse, error) { + req := structs.ACLRoleBatchGetRequest{ + Datacenter: s.config.ACLDatacenter, + RoleIDs: roleIDs, + QueryOptions: structs.QueryOptions{ + AllowStale: true, + Token: s.tokens.ReplicationToken(), + }, } - return false, nil + var response structs.ACLRoleBatchResponse + if err := s.RPC("ACL.RoleBatchRead", &req, &response); err != nil { + return nil, err + } + + return &response, nil } -func (s *Server) updateLocalACLPolicies(policies structs.ACLPolicies, ctx context.Context) (bool, error) { - ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) - defer ticker.Stop() +func (s *Server) fetchACLRoles(lastRemoteIndex uint64) (*structs.ACLRoleListResponse, error) { + defer metrics.MeasureSince([]string{"leader", "replication", "acl", "role", "fetch"}, time.Now()) - // outer loop handles submitting a batch - for batchStart := 0; batchStart < len(policies); { - // inner loop finds the last element to include in this batch. - batchSize := 0 - batchEnd := batchStart - for ; batchEnd < len(policies) && batchSize < aclBatchUpsertSize; batchEnd += 1 { - batchSize += policies[batchEnd].EstimateSize() - } - - req := structs.ACLPolicyBatchSetRequest{ - Policies: policies[batchStart:batchEnd], - } - - resp, err := s.raftApply(structs.ACLPolicySetRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply policy upserts: %v", err) - } - if respErr, ok := resp.(error); ok && respErr != nil { - return false, fmt.Errorf("Failed to apply policy upsert: %v", respErr) - } - s.logger.Printf("[DEBUG] acl: policy replication - upserted 1 batch with %d policies of size %d", batchEnd-batchStart, batchSize) - - // policies[batchEnd] wasn't include as the slicing doesn't include the element at the stop index - batchStart = batchEnd - - // prevent waiting if we are done - if batchEnd < len(policies) { - select { - case <-ctx.Done(): - return true, nil - case <-ticker.C: - // nothing to do - just rate limiting - } - } + req := structs.ACLRoleListRequest{ + Datacenter: s.config.ACLDatacenter, + QueryOptions: structs.QueryOptions{ + AllowStale: true, + MinQueryIndex: lastRemoteIndex, + Token: s.tokens.ReplicationToken(), + }, } - return false, nil + + var response structs.ACLRoleListResponse + if err := s.RPC("ACL.RoleList", &req, &response); err != nil { + return nil, err + } + return &response, nil } func (s *Server) fetchACLPoliciesBatch(policyIDs []string) (*structs.ACLPolicyBatchResponse, error) { @@ -171,66 +157,72 @@ func (s *Server) fetchACLPolicies(lastRemoteIndex uint64) (*structs.ACLPolicyLis return &response, nil } -type tokenDiffResults struct { +type itemDiffResults struct { LocalDeletes []string LocalUpserts []string LocalSkipped int RemoteSkipped int } -func diffACLTokens(local structs.ACLTokens, remote structs.ACLTokenListStubs, lastRemoteIndex uint64) tokenDiffResults { - // Note: items with empty AccessorIDs will bubble up to the top. - local.Sort() - remote.Sort() +func diffACLType(tr aclTypeReplicator, lastRemoteIndex uint64) itemDiffResults { + // Note: items with empty IDs will bubble up to the top (like legacy, unmigrated Tokens) - var res tokenDiffResults + lenLocal, lenRemote := tr.SortState() + + var res itemDiffResults var localIdx int var remoteIdx int - for localIdx, remoteIdx = 0, 0; localIdx < len(local) && remoteIdx < len(remote); { - if local[localIdx].AccessorID == "" { + for localIdx, remoteIdx = 0, 0; localIdx < lenLocal && remoteIdx < lenRemote; { + localID, _, localHash := tr.LocalMeta(localIdx) + remoteID, remoteMod, remoteHash := tr.RemoteMeta(remoteIdx) + + if localID == "" { res.LocalSkipped++ localIdx += 1 continue } - if remote[remoteIdx].AccessorID == "" { + if remoteID == "" { res.RemoteSkipped++ remoteIdx += 1 continue } - if local[localIdx].AccessorID == remote[remoteIdx].AccessorID { - // policy is in both the local and remote state - need to check raft indices and Hash - if remote[remoteIdx].ModifyIndex > lastRemoteIndex && !bytes.Equal(remote[remoteIdx].Hash, local[localIdx].Hash) { - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + + if localID == remoteID { + // item is in both the local and remote state - need to check raft indices and the Hash + if remoteMod > lastRemoteIndex && !bytes.Equal(remoteHash, localHash) { + res.LocalUpserts = append(res.LocalUpserts, remoteID) } // increment both indices when equal localIdx += 1 remoteIdx += 1 - } else if local[localIdx].AccessorID < remote[remoteIdx].AccessorID { - // policy no longer in remoted state - needs deleting - res.LocalDeletes = append(res.LocalDeletes, local[localIdx].AccessorID) + } else if localID < remoteID { + // item no longer in remote state - needs deleting + res.LocalDeletes = append(res.LocalDeletes, localID) // increment just the local index localIdx += 1 } else { - // local state doesn't have this policy - needs updating - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + // local state doesn't have this item - needs updating + res.LocalUpserts = append(res.LocalUpserts, remoteID) // increment just the remote index remoteIdx += 1 } } - for ; localIdx < len(local); localIdx += 1 { - if local[localIdx].AccessorID != "" { - res.LocalDeletes = append(res.LocalDeletes, local[localIdx].AccessorID) + for ; localIdx < lenLocal; localIdx += 1 { + localID, _, _ := tr.LocalMeta(localIdx) + if localID != "" { + res.LocalDeletes = append(res.LocalDeletes, localID) } else { res.LocalSkipped++ } } - for ; remoteIdx < len(remote); remoteIdx += 1 { - if remote[remoteIdx].AccessorID != "" { - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + for ; remoteIdx < lenRemote; remoteIdx += 1 { + remoteID, _, _ := tr.RemoteMeta(remoteIdx) + if remoteID != "" { + res.LocalUpserts = append(res.LocalUpserts, remoteID) } else { res.RemoteSkipped++ } @@ -239,25 +231,21 @@ func diffACLTokens(local structs.ACLTokens, remote structs.ACLTokenListStubs, la return res } -func (s *Server) deleteLocalACLTokens(deletions []string, ctx context.Context) (bool, error) { +func (s *Server) deleteLocalACLType(ctx context.Context, tr aclTypeReplicator, deletions []string) (bool, error) { ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) defer ticker.Stop() for i := 0; i < len(deletions); i += aclBatchDeleteSize { - req := structs.ACLTokenBatchDeleteRequest{} + var batch []string if i+aclBatchDeleteSize > len(deletions) { - req.TokenIDs = deletions[i:] + batch = deletions[i:] } else { - req.TokenIDs = deletions[i : i+aclBatchDeleteSize] + batch = deletions[i : i+aclBatchDeleteSize] } - resp, err := s.raftApply(structs.ACLTokenDeleteRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply token deletions: %v", err) - } - if respErr, ok := resp.(error); ok && err != nil { - return false, fmt.Errorf("Failed to apply token deletions: %v", respErr) + if err := tr.DeleteLocalBatch(s, batch); err != nil { + return false, fmt.Errorf("Failed to apply %s deletions: %v", tr.SingularNoun(), err) } if i+aclBatchDeleteSize < len(deletions) { @@ -273,47 +261,50 @@ func (s *Server) deleteLocalACLTokens(deletions []string, ctx context.Context) ( return false, nil } -func (s *Server) updateLocalACLTokens(tokens structs.ACLTokens, ctx context.Context) (bool, error) { +func (s *Server) updateLocalACLType(ctx context.Context, tr aclTypeReplicator) (bool, error) { ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) defer ticker.Stop() + lenPending := tr.LenPendingUpdates() + // outer loop handles submitting a batch - for batchStart := 0; batchStart < len(tokens); { + for batchStart := 0; batchStart < lenPending; { // inner loop finds the last element to include in this batch. batchSize := 0 batchEnd := batchStart - for ; batchEnd < len(tokens) && batchSize < aclBatchUpsertSize; batchEnd += 1 { - if tokens[batchEnd].SecretID == redactedToken { - return false, fmt.Errorf("Detected redacted token secrets: stopping token update round - verify that the replication token in use has acl:write permissions.") + for ; batchEnd < lenPending && batchSize < aclBatchUpsertSize; batchEnd += 1 { + if tr.PendingUpdateIsRedacted(batchEnd) { + return false, fmt.Errorf( + "Detected redacted %s secrets: stopping %s update round - verify that the replication token in use has acl:write permissions.", + tr.SingularNoun(), + tr.SingularNoun(), + ) } - batchSize += tokens[batchEnd].EstimateSize() + batchSize += tr.PendingUpdateEstimatedSize(batchEnd) } - req := structs.ACLTokenBatchSetRequest{ - Tokens: tokens[batchStart:batchEnd], - CAS: false, - } - - resp, err := s.raftApply(structs.ACLTokenSetRequestType, &req) + err := tr.UpdateLocalBatch(ctx, s, batchStart, batchEnd) if err != nil { - return false, fmt.Errorf("Failed to apply token upserts: %v", err) - } - if respErr, ok := resp.(error); ok && respErr != nil { - return false, fmt.Errorf("Failed to apply token upserts: %v", respErr) + return false, fmt.Errorf("Failed to apply %s upserts: %v", tr.SingularNoun(), err) } + s.logger.Printf( + "[DEBUG] acl: %s replication - upserted 1 batch with %d %s of size %d", + tr.SingularNoun(), + batchEnd-batchStart, + tr.PluralNoun(), + batchSize, + ) - s.logger.Printf("[DEBUG] acl: token replication - upserted 1 batch with %d tokens of size %d", batchEnd-batchStart, batchSize) - - // tokens[batchEnd] wasn't include as the slicing doesn't include the element at the stop index + // items[batchEnd] wasn't include as the slicing doesn't include the element at the stop index batchStart = batchEnd // prevent waiting if we are done - if batchEnd < len(tokens) { + if batchEnd < lenPending { select { case <-ctx.Done(): return true, nil case <-ticker.C: - // nothing to do - just rate limiting here + // nothing to do - just rate limiting } } } @@ -359,95 +350,28 @@ func (s *Server) fetchACLTokens(lastRemoteIndex uint64) (*structs.ACLTokenListRe return &response, nil } -func (s *Server) replicateACLPolicies(lastRemoteIndex uint64, ctx context.Context) (uint64, bool, error) { - remote, err := s.fetchACLPolicies(lastRemoteIndex) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve remote ACL policies: %v", err) - } - - s.logger.Printf("[DEBUG] acl: finished fetching policies tokens: %d", len(remote.Policies)) - - // Need to check if we should be stopping. This will be common as the fetching process is a blocking - // RPC which could have been hanging around for a long time and during that time leadership could - // have been lost. - select { - case <-ctx.Done(): - return 0, true, nil - default: - // do nothing - } - - // Measure everything after the remote query, which can block for long - // periods of time. This metric is a good measure of how expensive the - // replication process is. - defer metrics.MeasureSince([]string{"leader", "replication", "acl", "policy", "apply"}, time.Now()) - - _, local, err := s.fsm.State().ACLPolicyList(nil) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve local ACL policies: %v", err) - } - - // If the remote index ever goes backwards, it's a good indication that - // the remote side was rebuilt and we should do a full sync since we - // can't make any assumptions about what's going on. - if remote.QueryMeta.Index < lastRemoteIndex { - s.logger.Printf("[WARN] consul: ACL policy replication remote index moved backwards (%d to %d), forcing a full ACL policy sync", lastRemoteIndex, remote.QueryMeta.Index) - lastRemoteIndex = 0 - } - - s.logger.Printf("[DEBUG] acl: policy replication - local: %d, remote: %d", len(local), len(remote.Policies)) - // Calculate the changes required to bring the state into sync and then - // apply them. - deletions, updates := diffACLPolicies(local, remote.Policies, lastRemoteIndex) - - s.logger.Printf("[DEBUG] acl: policy replication - deletions: %d, updates: %d", len(deletions), len(updates)) - - var policies *structs.ACLPolicyBatchResponse - if len(updates) > 0 { - policies, err = s.fetchACLPoliciesBatch(updates) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve ACL policy updates: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - downloaded %d policies", len(policies.Policies)) - } - - if len(deletions) > 0 { - s.logger.Printf("[DEBUG] acl: policy replication - performing deletions") - - exit, err := s.deleteLocalACLPolicies(deletions, ctx) - if exit { - return 0, true, nil - } - if err != nil { - return 0, false, fmt.Errorf("failed to delete local ACL policies: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - finished deletions") - } - - if len(updates) > 0 { - s.logger.Printf("[DEBUG] acl: policy replication - performing updates") - exit, err := s.updateLocalACLPolicies(policies.Policies, ctx) - if exit { - return 0, true, nil - } - if err != nil { - return 0, false, fmt.Errorf("failed to update local ACL policies: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - finished updates") - } - - // Return the index we got back from the remote side, since we've synced - // up with the remote state as of that index. - return remote.QueryMeta.Index, false, nil +func (s *Server) replicateACLTokens(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclTokenReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) } -func (s *Server) replicateACLTokens(lastRemoteIndex uint64, ctx context.Context) (uint64, bool, error) { - remote, err := s.fetchACLTokens(lastRemoteIndex) +func (s *Server) replicateACLPolicies(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclPolicyReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) +} + +func (s *Server) replicateACLRoles(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclRoleReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) +} + +func (s *Server) replicateACLType(ctx context.Context, tr aclTypeReplicator, lastRemoteIndex uint64) (uint64, bool, error) { + lenRemote, remoteIndex, err := tr.FetchRemote(s, lastRemoteIndex) if err != nil { - return 0, false, fmt.Errorf("failed to retrieve remote ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to retrieve remote ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: finished fetching remote tokens: %d", len(remote.Tokens)) + s.logger.Printf("[DEBUG] acl: finished fetching %s: %d", tr.PluralNoun(), lenRemote) // Need to check if we should be stopping. This will be common as the fetching process is a blocking // RPC which could have been hanging around for a long time and during that time leadership could @@ -462,74 +386,99 @@ func (s *Server) replicateACLTokens(lastRemoteIndex uint64, ctx context.Context) // Measure everything after the remote query, which can block for long // periods of time. This metric is a good measure of how expensive the // replication process is. - defer metrics.MeasureSince([]string{"leader", "replication", "acl", "token", "apply"}, time.Now()) + defer metrics.MeasureSince([]string{"leader", "replication", "acl", tr.SingularNoun(), "apply"}, time.Now()) - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "") + lenLocal, _, err := tr.FetchLocal(s) if err != nil { - return 0, false, fmt.Errorf("failed to retrieve local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to retrieve local ACL %s: %v", tr.PluralNoun(), err) } - // Do not filter by expiration times. Wait until the tokens are explicitly deleted. // If the remote index ever goes backwards, it's a good indication that // the remote side was rebuilt and we should do a full sync since we // can't make any assumptions about what's going on. - if remote.QueryMeta.Index < lastRemoteIndex { - s.logger.Printf("[WARN] consul: ACL token replication remote index moved backwards (%d to %d), forcing a full ACL token sync", lastRemoteIndex, remote.QueryMeta.Index) + if remoteIndex < lastRemoteIndex { + s.logger.Printf( + "[WARN] consul: ACL %s replication remote index moved backwards (%d to %d), forcing a full ACL %s sync", + tr.SingularNoun(), + lastRemoteIndex, + remoteIndex, + tr.SingularNoun(), + ) lastRemoteIndex = 0 } - s.logger.Printf("[DEBUG] acl: token replication - local: %d, remote: %d", len(local), len(remote.Tokens)) - - // Calculate the changes required to bring the state into sync and then - // apply them. - res := diffACLTokens(local, remote.Tokens, lastRemoteIndex) + s.logger.Printf( + "[DEBUG] acl: %s replication - local: %d, remote: %d", + tr.SingularNoun(), + lenLocal, + lenRemote, + ) + // Calculate the changes required to bring the state into sync and then apply them. + res := diffACLType(tr, lastRemoteIndex) if res.LocalSkipped > 0 || res.RemoteSkipped > 0 { - s.logger.Printf("[DEBUG] acl: token replication - deletions: %d, updates: %d, skipped: %d, skippedRemote: %d", - len(res.LocalDeletes), len(res.LocalUpserts), res.LocalSkipped, res.RemoteSkipped) + s.logger.Printf( + "[DEBUG] acl: %s replication - deletions: %d, updates: %d, skipped: %d, skippedRemote: %d", + tr.SingularNoun(), + len(res.LocalDeletes), + len(res.LocalUpserts), + res.LocalSkipped, + res.RemoteSkipped, + ) } else { - s.logger.Printf("[DEBUG] acl: token replication - deletions: %d, updates: %d", len(res.LocalDeletes), len(res.LocalUpserts)) + s.logger.Printf( + "[DEBUG] acl: %s replication - deletions: %d, updates: %d", + tr.SingularNoun(), + len(res.LocalDeletes), + len(res.LocalUpserts), + ) } - var tokens *structs.ACLTokenBatchResponse if len(res.LocalUpserts) > 0 { - tokens, err = s.fetchACLTokensBatch(res.LocalUpserts) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve ACL token updates: %v", err) - } else if tokens.Redacted { - return 0, false, fmt.Errorf("failed to retrieve unredacted tokens - replication token in use does not grant acl:write") + lenUpdated, err := tr.FetchUpdated(s, res.LocalUpserts) + if err == errContainsRedactedData { + return 0, false, fmt.Errorf("failed to retrieve unredacted %s - replication token in use does not grant acl:write", tr.PluralNoun()) + } else if err != nil { + return 0, false, fmt.Errorf("failed to retrieve ACL %s updates: %v", tr.SingularNoun(), err) } - - s.logger.Printf("[DEBUG] acl: token replication - downloaded %d tokens", len(tokens.Tokens)) + s.logger.Printf( + "[DEBUG] acl: %s replication - downloaded %d %s", + tr.SingularNoun(), + lenUpdated, + tr.PluralNoun(), + ) } if len(res.LocalDeletes) > 0 { - s.logger.Printf("[DEBUG] acl: token replication - performing deletions") + s.logger.Printf( + "[DEBUG] acl: %s replication - performing deletions", + tr.SingularNoun(), + ) - exit, err := s.deleteLocalACLTokens(res.LocalDeletes, ctx) + exit, err := s.deleteLocalACLType(ctx, tr, res.LocalDeletes) if exit { return 0, true, nil } if err != nil { - return 0, false, fmt.Errorf("failed to delete local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to delete local ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: token replication - finished deletions") + s.logger.Printf("[DEBUG] acl: %s replication - finished deletions", tr.SingularNoun()) } if len(res.LocalUpserts) > 0 { - s.logger.Printf("[DEBUG] acl: token replication - performing updates") - exit, err := s.updateLocalACLTokens(tokens.Tokens, ctx) + s.logger.Printf("[DEBUG] acl: %s replication - performing updates", tr.SingularNoun()) + exit, err := s.updateLocalACLType(ctx, tr) if exit { return 0, true, nil } if err != nil { - return 0, false, fmt.Errorf("failed to update local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to update local ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: token replication - finished updates") + s.logger.Printf("[DEBUG] acl: %s replication - finished updates", tr.SingularNoun()) } // Return the index we got back from the remote side, since we've synced // up with the remote state as of that index. - return remote.QueryMeta.Index, false, nil + return remoteIndex, false, nil } // IsACLReplicationEnabled returns true if ACL replication is enabled. @@ -547,20 +496,23 @@ func (s *Server) updateACLReplicationStatusError() { s.aclReplicationStatus.LastError = time.Now().Round(time.Second).UTC() } -func (s *Server) updateACLReplicationStatusIndex(index uint64) { +func (s *Server) updateACLReplicationStatusIndex(replicationType structs.ACLReplicationType, index uint64) { s.aclReplicationStatusLock.Lock() defer s.aclReplicationStatusLock.Unlock() s.aclReplicationStatus.LastSuccess = time.Now().Round(time.Second).UTC() - s.aclReplicationStatus.ReplicatedIndex = index -} - -func (s *Server) updateACLReplicationStatusTokenIndex(index uint64) { - s.aclReplicationStatusLock.Lock() - defer s.aclReplicationStatusLock.Unlock() - - s.aclReplicationStatus.LastSuccess = time.Now().Round(time.Second).UTC() - s.aclReplicationStatus.ReplicatedTokenIndex = index + switch replicationType { + case structs.ACLReplicateLegacy: + s.aclReplicationStatus.ReplicatedIndex = index + case structs.ACLReplicateTokens: + s.aclReplicationStatus.ReplicatedTokenIndex = index + case structs.ACLReplicatePolicies: + s.aclReplicationStatus.ReplicatedIndex = index + case structs.ACLReplicateRoles: + s.aclReplicationStatus.ReplicatedRoleIndex = index + default: + panic("unknown replication type: " + replicationType.SingularNoun()) + } } func (s *Server) initReplicationStatus() { @@ -583,6 +535,21 @@ func (s *Server) updateACLReplicationStatusRunning(replicationType structs.ACLRe s.aclReplicationStatusLock.Lock() defer s.aclReplicationStatusLock.Unlock() + // The running state represents which type of overall replication has been + // configured. Though there are various types of internal plumbing for acl + // replication, to the end user there are only 3 distinctly configurable + // variants: legacy, policy, token. Roles replicate with policies so we + // round that up here. + if replicationType == structs.ACLReplicateRoles { + replicationType = structs.ACLReplicatePolicies + } + s.aclReplicationStatus.Running = true s.aclReplicationStatus.ReplicationType = replicationType } + +func (s *Server) getACLReplicationStatusRunningType() (structs.ACLReplicationType, bool) { + s.aclReplicationStatusLock.RLock() + defer s.aclReplicationStatusLock.RUnlock() + return s.aclReplicationStatus.ReplicationType, s.aclReplicationStatus.Running +} diff --git a/agent/consul/acl_replication_legacy.go b/agent/consul/acl_replication_legacy.go index 3fc2ce5eb3..010c220a96 100644 --- a/agent/consul/acl_replication_legacy.go +++ b/agent/consul/acl_replication_legacy.go @@ -138,7 +138,7 @@ func reconcileLegacyACLs(local, remote structs.ACLs, lastRemoteIndex uint64) str // FetchLocalACLs returns the ACLs in the local state store. func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) { - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "") + _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "") if err != nil { return nil, err } diff --git a/agent/consul/acl_replication_legacy_test.go b/agent/consul/acl_replication_legacy_test.go index a1eea646f2..f5a2601d54 100644 --- a/agent/consul/acl_replication_legacy_test.go +++ b/agent/consul/acl_replication_legacy_test.go @@ -335,6 +335,10 @@ func TestACLReplication_IsACLReplicationEnabled(t *testing.T) { } } +// Note that this test is testing that legacy token data is replicated, NOT +// directly testing the legacy acl replication goroutine code. +// +// Actually testing legacy replication is difficult to do without old binaries. func TestACLReplication_LegacyTokens(t *testing.T) { t.Parallel() dir1, s1 := testServerWithConfig(t, func(c *Config) { @@ -367,6 +371,12 @@ func TestACLReplication_LegacyTokens(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + // Create a bunch of new tokens. var id string for i := 0; i < 50; i++ { @@ -386,14 +396,15 @@ func TestACLReplication_LegacyTokens(t *testing.T) { } checkSame := func() error { - index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "") if err != nil { return err } - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") if err != nil { return err } + if got, want := len(remote), len(local); got != want { return fmt.Errorf("got %d remote ACLs want %d", got, want) } diff --git a/agent/consul/acl_replication_test.go b/agent/consul/acl_replication_test.go index 86c939c6ae..730527eedc 100644 --- a/agent/consul/acl_replication_test.go +++ b/agent/consul/acl_replication_test.go @@ -3,6 +3,7 @@ package consul import ( "fmt" "os" + "strconv" "testing" "time" @@ -15,6 +16,11 @@ import ( ) func TestACLReplication_diffACLPolicies(t *testing.T) { + diffACLPolicies := func(local structs.ACLPolicies, remote structs.ACLPolicyListStubs, lastRemoteIndex uint64) ([]string, []string) { + tr := &aclPolicyReplicator{local: local, remote: remote} + res := diffACLType(tr, lastRemoteIndex) + return res.LocalDeletes, res.LocalUpserts + } local := structs.ACLPolicies{ &structs.ACLPolicy{ ID: "44ef9aec-7654-4401-901b-4d4a8b3c80fc", @@ -127,6 +133,15 @@ func TestACLReplication_diffACLPolicies(t *testing.T) { } func TestACLReplication_diffACLTokens(t *testing.T) { + diffACLTokens := func( + local structs.ACLTokens, + remote structs.ACLTokenListStubs, + lastRemoteIndex uint64, + ) itemDiffResults { + tr := &aclTokenReplicator{local: local, remote: remote} + return diffACLType(tr, lastRemoteIndex) + } + local := structs.ACLTokens{ // When a just-upgraded (1.3->1.4+) secondary DC is replicating from an // upgraded primary DC (1.4+), the local state for tokens predating the @@ -307,6 +322,12 @@ func TestACLReplication_Tokens(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + // Create a bunch of new tokens and policies var tokens structs.ACLTokens for i := 0; i < 50; i++ { @@ -328,11 +349,11 @@ func TestACLReplication_Tokens(t *testing.T) { tokens = append(tokens, &token) } - checkSame := func(t *retry.R) error { + checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") require.NoError(t, err) - _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "") require.NoError(t, err) require.Len(t, local, len(remote)) @@ -340,18 +361,15 @@ func TestACLReplication_Tokens(t *testing.T) { require.Equal(t, token.Hash, local[i].Hash) } - var status structs.ACLReplicationStatus s2.aclReplicationStatusLock.RLock() - status = s2.aclReplicationStatus + status := s2.aclReplicationStatus s2.aclReplicationStatusLock.RUnlock() - if !status.Enabled || !status.Running || - status.ReplicationType != structs.ACLReplicateTokens || - status.ReplicatedTokenIndex != index || - status.SourceDatacenter != "dc1" { - return fmt.Errorf("ACL replication status differs") - } - return nil + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedTokenIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") } // Wait for the replica to converge. retry.Run(t, func(r *retry.R) { @@ -426,7 +444,7 @@ func TestACLReplication_Tokens(t *testing.T) { }) // verify dc2 local tokens didn't get blown away - _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "") require.NoError(t, err) require.Len(t, local, 50) @@ -479,6 +497,12 @@ func TestACLReplication_Policies(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicatePolicies) + // Create a bunch of new policies var policies structs.ACLPolicies for i := 0; i < 50; i++ { @@ -496,7 +520,7 @@ func TestACLReplication_Policies(t *testing.T) { policies = append(policies, &policy) } - checkSame := func(t *retry.R) error { + checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated index, remote, err := s1.fsm.State().ACLPolicyList(nil) require.NoError(t, err) @@ -508,18 +532,15 @@ func TestACLReplication_Policies(t *testing.T) { require.Equal(t, policy.Hash, local[i].Hash) } - var status structs.ACLReplicationStatus s2.aclReplicationStatusLock.RLock() - status = s2.aclReplicationStatus + status := s2.aclReplicationStatus s2.aclReplicationStatusLock.RUnlock() - if !status.Enabled || !status.Running || - status.ReplicationType != structs.ACLReplicatePolicies || - status.ReplicatedIndex != index || - status.SourceDatacenter != "dc1" { - return fmt.Errorf("ACL replication status differs") - } - return nil + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicatePolicies) + require.Equal(t, status.ReplicatedIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") } // Wait for the replica to converge. retry.Run(t, func(r *retry.R) { @@ -709,3 +730,249 @@ func TestACLReplication_TokensRedacted(t *testing.T) { require.True(r, status.LastError.After(minErrorTime), "Replication LastError not after the minErrorTime") }) } + +func TestACLReplication_AllTypes(t *testing.T) { + t.Parallel() + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + testrpc.WaitForLeader(t, s1.RPC, "dc1") + client := rpcClient(t, s1) + defer client.Close() + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLTokenReplication = true + c.ACLReplicationRate = 100 + c.ACLReplicationBurst = 100 + c.ACLReplicationApplyLimit = 1000000 + }) + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + testrpc.WaitForLeader(t, s2.RPC, "dc2") + defer os.RemoveAll(dir2) + defer s2.Shutdown() + + // Try to join. + joinWAN(t, s2, s1) + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s1.RPC, "dc2") + + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + + const ( + numItems = 50 + numItemsThatAreLocal = 10 + ) + + // Create some data. + policyIDs, roleIDs, tokenIDs := createACLTestData(t, s1, "b1", numItems, numItemsThatAreLocal) + + checkSameTokens := func(t *retry.R) { + // only account for global tokens - local tokens shouldn't be replicated + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") + require.NoError(t, err) + // Query for all of them, so that we can prove that no globals snuck in. + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, token := range remote { + require.Equal(t, token.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedTokenIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSamePolicies := func(t *retry.R) { + index, remote, err := s1.fsm.State().ACLPolicyList(nil) + require.NoError(t, err) + _, local, err := s2.fsm.State().ACLPolicyList(nil) + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, policy := range remote { + require.Equal(t, policy.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSameRoles := func(t *retry.R) { + index, remote, err := s1.fsm.State().ACLRoleList(nil, "") + require.NoError(t, err) + _, local, err := s2.fsm.State().ACLRoleList(nil, "") + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, role := range remote { + require.Equal(t, role.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedRoleIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSame := func(t *retry.R) { + checkSameTokens(t) + checkSamePolicies(t) + checkSameRoles(t) + } + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) + + // Create additional data to replicate. + _, _, _ = createACLTestData(t, s1, "b2", numItems, numItemsThatAreLocal) + + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) + + // Delete one piece of each type of data from batch 1. + const itemToDelete = numItems - 1 + { + id := tokenIDs[itemToDelete] + + arg := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.TokenDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + { + id := roleIDs[itemToDelete] + + arg := structs.ACLRoleDeleteRequest{ + Datacenter: "dc1", + RoleID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.RoleDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + { + id := policyIDs[itemToDelete] + + arg := structs.ACLPolicyDeleteRequest{ + Datacenter: "dc1", + PolicyID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.PolicyDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) +} + +func createACLTestData(t *testing.T, srv *Server, namePrefix string, numObjects, numItemsThatAreLocal int) (policyIDs, roleIDs, tokenIDs []string) { + require.True(t, numItemsThatAreLocal <= numObjects, 0, "numItemsThatAreLocal <= numObjects") + + // Create some policies. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLPolicySetRequest{ + Datacenter: "dc1", + Policy: structs.ACLPolicy{ + Name: namePrefix + "-policy-" + str, + Description: namePrefix + "-policy " + str, + Rules: testACLPolicyNew, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLPolicy + if err := srv.RPC("ACL.PolicySet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + policyIDs = append(policyIDs, out.ID) + } + + // Create some roles. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: namePrefix + "-role-" + str, + Description: namePrefix + "-role " + str, + Policies: []structs.ACLRolePolicyLink{ + {ID: policyIDs[i]}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLRole + if err := srv.RPC("ACL.RoleSet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + roleIDs = append(roleIDs, out.ID) + } + + // Create a bunch of new tokens. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: namePrefix + "-token " + str, + Policies: []structs.ACLTokenPolicyLink{ + {ID: policyIDs[i]}, + }, + Roles: []structs.ACLTokenRoleLink{ + {ID: roleIDs[i]}, + }, + Local: (i < numItemsThatAreLocal), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLToken + if err := srv.RPC("ACL.TokenSet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + tokenIDs = append(tokenIDs, out.AccessorID) + } + + return policyIDs, roleIDs, tokenIDs +} diff --git a/agent/consul/acl_replication_types.go b/agent/consul/acl_replication_types.go new file mode 100644 index 0000000000..7044442fdf --- /dev/null +++ b/agent/consul/acl_replication_types.go @@ -0,0 +1,370 @@ +package consul + +import ( + "context" + "fmt" + + "github.com/hashicorp/consul/agent/structs" +) + +type aclTokenReplicator struct { + local structs.ACLTokens + remote structs.ACLTokenListStubs + updated []*structs.ACLToken +} + +var _ aclTypeReplicator = (*aclTokenReplicator)(nil) + +func (r *aclTokenReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicateTokens } +func (r *aclTokenReplicator) SingularNoun() string { return "token" } +func (r *aclTokenReplicator) PluralNoun() string { return "tokens" } + +func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLTokens(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Tokens + return len(remote.Tokens), remote.QueryMeta.Index, nil +} + +func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "") + if err != nil { + return 0, 0, err + } + + // Do not filter by expiration times. Wait until the tokens are explicitly + // deleted. + + r.local = local + return len(local), idx, nil +} + +func (r *aclTokenReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclTokenReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.AccessorID, v.ModifyIndex, v.Hash +} +func (r *aclTokenReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.AccessorID, v.ModifyIndex, v.Hash +} + +func (r *aclTokenReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + tokens, err := srv.fetchACLTokensBatch(updates) + if err != nil { + return 0, err + } else if tokens.Redacted { + return 0, errContainsRedactedData + } + + // Do not filter by expiration times. Wait until the tokens are + // explicitly deleted. + + r.updated = tokens.Tokens + } + + return len(r.updated), nil +} + +func (r *aclTokenReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLTokenBatchDeleteRequest{ + TokenIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLTokenDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclTokenReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclTokenReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclTokenReplicator) PendingUpdateIsRedacted(i int) bool { + return r.updated[i].SecretID == redactedToken +} + +func (r *aclTokenReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLTokenBatchSetRequest{ + Tokens: r.updated[start:end], + CAS: false, + } + + resp, err := srv.raftApply(structs.ACLTokenSetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} + +/////////////////////// + +type aclPolicyReplicator struct { + local structs.ACLPolicies + remote structs.ACLPolicyListStubs + updated []*structs.ACLPolicy +} + +var _ aclTypeReplicator = (*aclPolicyReplicator)(nil) + +func (r *aclPolicyReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicatePolicies } +func (r *aclPolicyReplicator) SingularNoun() string { return "policy" } +func (r *aclPolicyReplicator) PluralNoun() string { return "policies" } + +func (r *aclPolicyReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLPolicies(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Policies + return len(remote.Policies), remote.QueryMeta.Index, nil +} + +func (r *aclPolicyReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLPolicyList(nil) + if err != nil { + return 0, 0, err + } + + r.local = local + return len(local), idx, nil +} + +func (r *aclPolicyReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclPolicyReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.ID, v.ModifyIndex, v.Hash +} +func (r *aclPolicyReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.ID, v.ModifyIndex, v.Hash +} + +func (r *aclPolicyReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + policies, err := srv.fetchACLPoliciesBatch(updates) + if err != nil { + return 0, err + } + r.updated = policies.Policies + } + + return len(r.updated), nil +} + +func (r *aclPolicyReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLPolicyBatchDeleteRequest{ + PolicyIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLPolicyDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclPolicyReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclPolicyReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclPolicyReplicator) PendingUpdateIsRedacted(i int) bool { + return false +} + +func (r *aclPolicyReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLPolicyBatchSetRequest{ + Policies: r.updated[start:end], + } + + resp, err := srv.raftApply(structs.ACLPolicySetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} + +//////////////////////////////// + +type aclRoleReplicator struct { + local structs.ACLRoles + remote structs.ACLRoles + updated []*structs.ACLRole +} + +var _ aclTypeReplicator = (*aclRoleReplicator)(nil) + +func (r *aclRoleReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicateRoles } +func (r *aclRoleReplicator) SingularNoun() string { return "role" } +func (r *aclRoleReplicator) PluralNoun() string { return "roles" } + +func (r *aclRoleReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLRoles(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Roles + return len(remote.Roles), remote.QueryMeta.Index, nil +} + +func (r *aclRoleReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLRoleList(nil, "") + if err != nil { + return 0, 0, err + } + + r.local = local + return len(local), idx, nil +} + +func (r *aclRoleReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclRoleReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.ID, v.ModifyIndex, v.Hash +} +func (r *aclRoleReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.ID, v.ModifyIndex, v.Hash +} + +func (r *aclRoleReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + // Since ACLRoles do not have a "list entry" variation, all of the data + // to replicate a role is already present in the "r.remote" list. + // + // We avoid a second query by just repurposing the data we already have + // access to in a way that is compatible with the generic ACL type + // replicator. + keep := make(map[string]struct{}) + for _, id := range updates { + keep[id] = struct{}{} + } + + subset := make([]*structs.ACLRole, 0, len(updates)) + for _, role := range r.remote { + if _, ok := keep[role.ID]; ok { + subset = append(subset, role) + } + } + + if len(subset) != len(keep) { // only possible via programming bug + for _, role := range subset { + delete(keep, role.ID) + } + missing := make([]string, 0, len(keep)) + for id, _ := range keep { + missing = append(missing, id) + } + return 0, fmt.Errorf("role replication trying to replicated uncached roles with IDs: %v", missing) + } + r.updated = subset + } + + return len(r.updated), nil +} + +func (r *aclRoleReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLRoleBatchDeleteRequest{ + RoleIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLRoleDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclRoleReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclRoleReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclRoleReplicator) PendingUpdateIsRedacted(i int) bool { + return false +} + +func (r *aclRoleReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLRoleBatchSetRequest{ + Roles: r.updated[start:end], + } + + resp, err := srv.raftApply(structs.ACLRoleSetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} diff --git a/agent/consul/acl_server.go b/agent/consul/acl_server.go index e8213af4fc..d895d922a2 100644 --- a/agent/consul/acl_server.go +++ b/agent/consul/acl_server.go @@ -13,7 +13,7 @@ var serverACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ // The server's ACL caching has a few underlying assumptions: // // 1 - All policies can be resolved locally. Hence we do not cache any - // unparsed policies as we have memdb for that. + // unparsed policies/roles as we have memdb for that. // 2 - While there could be many identities being used within a DC the // number of distinct policies and combined multi-policy authorizers // will be much less. @@ -26,6 +26,7 @@ var serverACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ Policies: 0, ParsedPolicies: 512, Authorizers: 1024, + Roles: 0, } func (s *Server) checkTokenUUID(id string) (bool, error) { @@ -61,6 +62,17 @@ func (s *Server) checkPolicyUUID(id string) (bool, error) { return !structs.ACLIDReserved(id), nil } +func (s *Server) checkRoleUUID(id string) (bool, error) { + state := s.fsm.State() + if _, role, err := state.ACLRoleGetByID(nil, id); err != nil { + return false, err + } else if role != nil { + return false, nil + } + + return !structs.ACLIDReserved(id), nil +} + func (s *Server) updateACLAdvertisement() { // One thing to note is that once in new ACL mode the server will // never transition to legacy ACL mode. This is not currently a @@ -172,6 +184,20 @@ func (s *Server) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, return s.InACLDatacenter() || index > 0, policy, acl.ErrNotFound } +func (s *Server) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + index, role, err := s.fsm.State().ACLRoleGetByID(nil, roleID) + if err != nil { + return true, nil, err + } else if role != nil { + return true, role, nil + } + + // If the max index of the roles table is non-zero then we have acls, until then + // we may need to allow remote resolution. This is particularly useful to allow updating + // the replication token via the API in a non-primary dc. + return s.InACLDatacenter() || index > 0, role, acl.ErrNotFound +} + func (s *Server) ResolveToken(token string) (acl.Authorizer, error) { return s.acls.ResolveToken(token) } diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 9ff773ca41..65f05d9875 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -60,6 +60,29 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "missing-role": + return true, &structs.ACLToken{ + AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", + SecretID: "b1b6be70-ed2e-4c80-8495-bdb3db110b1e", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "not-found", + }, + structs.ACLTokenRoleLink{ + ID: "acl-ro", + }, + }, + }, nil + case "missing-policy-on-role": + return true, &structs.ACLToken{ + AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", + SecretID: "b1b6be70-ed2e-4c80-8495-bdb3db110b1e", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "missing-policy", + }, + }, + }, nil case "legacy-management": return true, &structs.ACLToken{ AccessorID: "d109a033-99d1-47e2-a711-d6593373a973", @@ -86,6 +109,36 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "found-role": + // This should be permission-wise identical to "found", except it + // gets it's policies indirectly by way of a Role. + return true, &structs.ACLToken{ + AccessorID: "5f57c1f6-6a89-4186-9445-531b316e01df", + SecretID: "a1a54629-5050-4d17-8a4e-560d2423f835", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "found", + }, + }, + }, nil + case "found-policy-and-role": + return true, &structs.ACLToken{ + AccessorID: "5f57c1f6-6a89-4186-9445-531b316e01df", + SecretID: "a1a54629-5050-4d17-8a4e-560d2423f835", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: "node-wr", + }, + structs.ACLTokenPolicyLink{ + ID: "dc2-key-wr", + }, + }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "service-ro", + }, + }, + }, nil case "acl-ro": return true, &structs.ACLToken{ AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", @@ -177,6 +230,24 @@ func testPolicyForID(policyID string) (bool, *structs.ACLPolicy, error) { Syntax: acl.SyntaxCurrent, RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, }, nil + case "service-ro": + return true, &structs.ACLPolicy{ + ID: "service-ro", + Name: "service-ro", + Description: "service-ro", + Rules: `service_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "service-wr": + return true, &structs.ACLPolicy{ + ID: "service-wr", + Name: "service-wr", + Description: "service-wr", + Rules: `service_prefix "" { policy = "write" }`, + Syntax: acl.SyntaxCurrent, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil case "node-wr": return true, &structs.ACLPolicy{ ID: "node-wr", @@ -202,6 +273,141 @@ func testPolicyForID(policyID string) (bool, *structs.ACLPolicy, error) { } } +func testRoleForID(roleID string) (bool, *structs.ACLRole, error) { + switch roleID { + case "service-ro": + return true, &structs.ACLRole{ + ID: "service-ro", + Name: "service-ro", + Description: "service-ro", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "service-ro", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "service-wr": + return true, &structs.ACLRole{ + ID: "service-wr", + Name: "service-wr", + Description: "service-wr", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "service-wr", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "missing-policy": + return true, &structs.ACLRole{ + ID: "missing-policy", + Name: "missing-policy", + Description: "missing-policy", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "not-found", + }, + structs.ACLRolePolicyLink{ + ID: "acl-ro", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "found": + return true, &structs.ACLRole{ + ID: "found", + Name: "found", + Description: "found", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "dc2-key-wr", + }, + }, + }, nil + case "acl-ro": + return true, &structs.ACLRole{ + ID: "acl-ro", + Name: "acl-ro", + Description: "acl-ro", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "acl-ro", + }, + }, + }, nil + case "acl-wr": + return true, &structs.ACLRole{ + ID: "acl-rw", + Name: "acl-rw", + Description: "acl-rw", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "racey-unmodified": + return true, &structs.ACLRole{ + ID: "racey-unmodified", + Name: "racey-unmodified", + Description: "racey-unmodified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "racey-modified": + return true, &structs.ACLRole{ + ID: "racey-modified", + Name: "racey-modified", + Description: "racey-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + }, + }, nil + case "concurrent-resolve-1": + return true, &structs.ACLRole{ + ID: "concurrent-resolve-1", + Name: "concurrent-resolve-1", + Description: "concurrent-resolve-1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "concurrent-resolve-2": + return true, &structs.ACLRole{ + ID: "concurrent-resolve-2", + Name: "concurrent-resolve-2", + Description: "concurrent-resolve-2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + default: + return true, nil, acl.ErrNotFound + } +} + // ACLResolverTestDelegate is used to test // the ACLResolver without running Agents type ACLResolverTestDelegate struct { @@ -210,9 +416,69 @@ type ACLResolverTestDelegate struct { legacy bool localTokens bool localPolicies bool + localRoles bool getPolicyFn func(*structs.ACLPolicyResolveLegacyRequest, *structs.ACLPolicyResolveLegacyResponse) error tokenReadFn func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error policyResolveFn func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error + roleResolveFn func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error + + // state for the optional default resolver function defaultTokenReadFn + tokenCached bool + // state for the optional default resolver function defaultPolicyResolveFn + policyCached bool + // state for the optional default resolver function defaultRoleResolveFn + roleCached bool +} + +var errRPC = fmt.Errorf("Induced RPC Error") + +func (d *ACLResolverTestDelegate) defaultTokenReadFn(errAfterCached error) func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { + return func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { + if !d.tokenCached { + _, token, _ := testIdentityForToken(args.TokenID) + reply.Token = token.(*structs.ACLToken) + + d.tokenCached = true + return nil + } + return errAfterCached + } +} + +func (d *ACLResolverTestDelegate) defaultPolicyResolveFn(errAfterCached error) func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error { + return func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { + if !d.policyCached { + for _, policyID := range args.PolicyIDs { + _, policy, _ := testPolicyForID(policyID) + if policy != nil { + reply.Policies = append(reply.Policies, policy) + } + } + + d.policyCached = true + return nil + } + + return errAfterCached + } +} + +func (d *ACLResolverTestDelegate) defaultRoleResolveFn(errAfterCached error) func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error { + return func(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if !d.roleCached { + for _, roleID := range args.RoleIDs { + _, role, _ := testRoleForID(roleID) + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } + + d.roleCached = true + return nil + } + + return errAfterCached + } } func (d *ACLResolverTestDelegate) ACLsEnabled() bool { @@ -243,23 +509,36 @@ func (d *ACLResolverTestDelegate) ResolvePolicyFromID(policyID string) (bool, *s return testPolicyForID(policyID) } +func (d *ACLResolverTestDelegate) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + if !d.localRoles { + return false, nil, nil + } + + return testRoleForID(roleID) +} + func (d *ACLResolverTestDelegate) RPC(method string, args interface{}, reply interface{}) error { switch method { case "ACL.GetPolicy": if d.getPolicyFn != nil { return d.getPolicyFn(args.(*structs.ACLPolicyResolveLegacyRequest), reply.(*structs.ACLPolicyResolveLegacyResponse)) } - panic("Bad Test Implmentation: should provide a getPolicyFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a getPolicyFn to the ACLResolverTestDelegate") case "ACL.TokenRead": if d.tokenReadFn != nil { return d.tokenReadFn(args.(*structs.ACLTokenGetRequest), reply.(*structs.ACLTokenResponse)) } - panic("Bad Test Implmentation: should provide a tokenReadFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a tokenReadFn to the ACLResolverTestDelegate") case "ACL.PolicyResolve": if d.policyResolveFn != nil { return d.policyResolveFn(args.(*structs.ACLPolicyBatchGetRequest), reply.(*structs.ACLPolicyBatchResponse)) } - panic("Bad Test Implmentation: should provide a policyResolveFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a policyResolveFn to the ACLResolverTestDelegate") + case "ACL.RoleResolve": + if d.roleResolveFn != nil { + return d.roleResolveFn(args.(*structs.ACLRoleBatchGetRequest), reply.(*structs.ACLRoleBatchResponse)) + } + panic("Bad Test Implementation: should provide a roleResolveFn to the ACLResolverTestDelegate") } panic("Bad Test Implementation: Was the ACLResolver updated to use new RPC methods") } @@ -276,6 +555,7 @@ func newTestACLResolver(t *testing.T, delegate ACLResolverDelegate, cb func(*ACL Policies: 4, ParsedPolicies: 4, Authorizers: 4, + Roles: 4, }, AutoDisable: true, Delegate: delegate, @@ -371,8 +651,9 @@ func TestACLResolver_DownPolicy(t *testing.T) { legacy: false, localTokens: false, localPolicies: true, + localRoles: true, tokenReadFn: func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -395,8 +676,9 @@ func TestACLResolver_DownPolicy(t *testing.T) { legacy: false, localTokens: false, localPolicies: true, + localRoles: true, tokenReadFn: func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -413,32 +695,20 @@ func TestACLResolver_DownPolicy(t *testing.T) { t.Run("Expired-Policy", func(t *testing.T) { t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "deny" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -460,75 +730,118 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", false, "expired") // from "found" token }) - t.Run("Extend-Cache", func(t *testing.T) { + t.Run("Expired-Role", func(t *testing.T) { t.Parallel() - cached := false - delegate := &ACLResolverTestDelegate{ - enabled: true, - datacenter: "dc1", - legacy: false, - localTokens: false, - localPolicies: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !cached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - cached = true - return nil - } - return fmt.Errorf("Induced RPC Error") - }, - } - r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { - config.Config.ACLDownPolicy = "extend-cache" - config.Config.ACLTokenTTL = 0 - }) - - authz, err := r.ResolveToken("foo") - require.NoError(t, err) - require.NotNil(t, authz) - require.True(t, authz.NodeWrite("foo", nil)) - - requireIdentityCached(t, r, "foo", true, "cached") - - authz2, err := r.ResolveToken("foo") - require.NoError(t, err) - require.NotNil(t, authz2) - // testing pointer equality - these will be the same object because it is cached. - require.True(t, authz == authz2) - require.True(t, authz.NodeWrite("foo", nil)) - - requireIdentityCached(t, r, "foo", true, "still cached") - }) - - t.Run("Extend-Cache-Expired-Policy", func(t *testing.T) { - t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "deny" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // role cache expired - so we will fail to resolve that role and use the default policy only + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + require.False(t, authz == authz2) + require.False(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: true, + localRoles: true, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + }) + + authz, err := r.ResolveToken("found") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + requireIdentityCached(t, r, "found", true, "cached") + + authz2, err := r.ResolveToken("found") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: true, + localRoles: true, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + requireIdentityCached(t, r, "found-role", true, "still cached") + + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Expired-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "extend-cache" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -550,36 +863,56 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", true, "still cached") // from "found" token }) - t.Run("Async-Cache-Expired-Policy", func(t *testing.T) { + t.Run("Extend-Cache-Expired-Role", func(t *testing.T) { t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any - // policies not in the response and emit an ACL not found for any not-found within the result set. - return nil - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // Will just use the policy cache + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + require.True(t, authz == authz2) + require.True(t, authz.NodeWrite("foo", nil)) + }) + + t.Run("Async-Cache-Expired-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any + // policies not in the response and emit an ACL not found for any not-found within the result set. + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(nil) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "async-cache" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -613,45 +946,67 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", false, "no longer cached") // from "found" token }) - t.Run("Extend-Cache-Client", func(t *testing.T) { + t.Run("Async-Cache-Expired-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any + // policies not in the response and emit an ACL not found for any not-found within the result set. + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(nil) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(nil) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "async-cache" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // The identity should have been cached so this should still be valid + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz.NodeWrite("foo", nil)) + + // the go routine spawned will eventually return with a authz that doesn't have the policy + retry.Run(t, func(t *retry.R) { + authz3, err := r.ResolveToken("found-role") + assert.NoError(t, err) + assert.NotNil(t, authz3) + assert.False(t, authz3.NodeWrite("foo", nil)) + }) + }) + + t.Run("Extend-Cache-Client-Policy", func(t *testing.T) { t.Parallel() - tokenCached := false - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: false, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !tokenCached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - tokenCached = true - return nil - } - return fmt.Errorf("Induced RPC Error") - }, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "extend-cache" config.Config.ACLTokenTTL = 0 config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -667,62 +1022,89 @@ func TestACLResolver_DownPolicy(t *testing.T) { require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Client-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: false, + localRoles: false, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) require.True(t, authz.NodeWrite("foo", nil)) requirePolicyCached(t, r, "node-wr", true, "still cached") // from "found" token requirePolicyCached(t, r, "dc2-key-wr", true, "still cached") // from "found" token + + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) }) t.Run("Async-Cache", func(t *testing.T) { t.Parallel() - cached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !cached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - cached = true - return nil - } - return acl.ErrNotFound - }, + localRoles: true, } + delegate.tokenReadFn = delegate.defaultTokenReadFn(acl.ErrNotFound) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "async-cache" config.Config.ACLTokenTTL = 0 }) - authz, err := r.ResolveToken("foo") + authz, err := r.ResolveToken("found") require.NoError(t, err) require.NotNil(t, authz) require.True(t, authz.NodeWrite("foo", nil)) - requireIdentityCached(t, r, "foo", true, "cached") + requireIdentityCached(t, r, "found", true, "cached") // The identity should have been cached so this should still be valid - authz2, err := r.ResolveToken("foo") + authz2, err := r.ResolveToken("found") require.NoError(t, err) require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. require.True(t, authz == authz2) - require.True(t, authz.NodeWrite("foo", nil)) + require.True(t, authz2.NodeWrite("foo", nil)) - requireIdentityCached(t, r, "foo", true, "cached") + requireIdentityCached(t, r, "found", true, "cached") // the go routine spawned will eventually return and this will be a not found error retry.Run(t, func(t *retry.R) { - authz3, err := r.ResolveToken("foo") + authz3, err := r.ResolveToken("found") assert.Error(t, err) assert.True(t, acl.IsErrNotFound(err)) assert.Nil(t, authz3) }) - requireIdentityCached(t, r, "foo", false, "no longer cached") + requireIdentityCached(t, r, "found", false, "no longer cached") }) t.Run("PolicyResolve-TokenNotFound", func(t *testing.T) { @@ -864,6 +1246,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } r := newTestACLResolver(t, delegate, nil) @@ -883,6 +1266,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -898,6 +1282,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { }) } +// TODO(rb): replicate this sort of test but for roles func TestACLResolver_Client(t *testing.T) { t.Parallel() @@ -951,6 +1336,7 @@ func TestACLResolver_Client(t *testing.T) { r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLTokenTTL = 600 * time.Second config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond config.Config.ACLDownPolicy = "extend-cache" }) @@ -1039,6 +1425,7 @@ func TestACLResolver_Client(t *testing.T) { // being resolved concurrently config.Config.ACLTokenTTL = 0 * time.Second config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond config.Config.ACLDownPolicy = "extend-cache" }) @@ -1058,7 +1445,7 @@ func TestACLResolver_Client(t *testing.T) { }) } -func TestACLResolver_LocalTokensAndPolicies(t *testing.T) { +func TestACLResolver_LocalTokensPoliciesAndRoles(t *testing.T) { t.Parallel() delegate := &ACLResolverTestDelegate{ enabled: true, @@ -1066,66 +1453,23 @@ func TestACLResolver_LocalTokensAndPolicies(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } - r := newTestACLResolver(t, delegate, nil) - t.Run("Missing Identity", func(t *testing.T) { - authz, err := r.ResolveToken("doesn't exist") - require.Nil(t, authz) - require.Error(t, err) - require.True(t, acl.IsErrNotFound(err)) - }) - - t.Run("Missing Policy", func(t *testing.T) { - authz, err := r.ResolveToken("missing-policy") - require.NoError(t, err) - require.NotNil(t, authz) - require.True(t, authz.ACLRead()) - require.False(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("Normal", func(t *testing.T) { - authz, err := r.ResolveToken("found") - require.NotNil(t, authz) - require.NoError(t, err) - require.False(t, authz.ACLRead()) - require.True(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("Anonymous", func(t *testing.T) { - authz, err := r.ResolveToken("") - require.NotNil(t, authz) - require.NoError(t, err) - require.False(t, authz.ACLRead()) - require.True(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("legacy-management", func(t *testing.T) { - authz, err := r.ResolveToken("legacy-management") - require.NotNil(t, authz) - require.NoError(t, err) - require.True(t, authz.ACLWrite()) - require.True(t, authz.KeyRead("foo")) - }) - - t.Run("legacy-client", func(t *testing.T) { - authz, err := r.ResolveToken("legacy-client") - require.NoError(t, err) - require.NotNil(t, authz) - require.False(t, authz.OperatorRead()) - require.True(t, authz.ServiceRead("foo")) - }) + testACLResolver_variousTokens(t, delegate) } -func TestACLResolver_LocalPolicies(t *testing.T) { +func TestACLResolver_LocalPoliciesAndRoles(t *testing.T) { t.Parallel() + delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: true, + localRoles: true, tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { _, token, err := testIdentityForToken(args.TokenID) @@ -1135,6 +1479,12 @@ func TestACLResolver_LocalPolicies(t *testing.T) { return err }, } + + testACLResolver_variousTokens(t, delegate) +} + +func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelegate) { + t.Helper() r := newTestACLResolver(t, delegate, nil) t.Run("Missing Identity", func(t *testing.T) { @@ -1152,7 +1502,23 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal", func(t *testing.T) { + t.Run("Missing Role", func(t *testing.T) { + authz, err := r.ResolveToken("missing-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + }) + + t.Run("Missing Policy on Role", func(t *testing.T) { + authz, err := r.ResolveToken("missing-policy-on-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + }) + + t.Run("Normal with Policy", func(t *testing.T) { authz, err := r.ResolveToken("found") require.NotNil(t, authz) require.NoError(t, err) @@ -1160,6 +1526,23 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.True(t, authz.NodeWrite("foo", nil)) }) + t.Run("Normal with Role", func(t *testing.T) { + authz, err := r.ResolveToken("found-role") + require.NotNil(t, authz) + require.NoError(t, err) + require.False(t, authz.ACLRead()) + require.True(t, authz.NodeWrite("foo", nil)) + }) + + t.Run("Normal with Policy and Role", func(t *testing.T) { + authz, err := r.ResolveToken("found-policy-and-role") + require.NotNil(t, authz) + require.NoError(t, err) + require.False(t, authz.ACLRead()) + require.True(t, authz.NodeWrite("foo", nil)) + require.True(t, authz.ServiceRead("bar")) + }) + t.Run("Anonymous", func(t *testing.T) { authz, err := r.ResolveToken("") require.NotNil(t, authz) @@ -1214,7 +1597,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, nil) @@ -1263,7 +1646,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -1314,7 +1697,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -1366,7 +1749,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { diff --git a/agent/consul/config.go b/agent/consul/config.go index cce70f392b..d7e92943bb 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -243,6 +243,11 @@ type Config struct { // a substantial cost. ACLPolicyTTL time.Duration + // ACLRoleTTL controls the time-to-live of cached ACL roles. + // It can be set to zero to disable caching, but this adds + // a substantial cost. + ACLRoleTTL time.Duration + // ACLDisabledTTL is the time between checking if ACLs should be // enabled. This ACLDisabledTTL time.Duration @@ -470,6 +475,7 @@ func DefaultConfig() *Config { SerfFloodInterval: 60 * time.Second, ReconcileInterval: 60 * time.Second, ProtocolVersion: ProtocolVersion2Compatible, + ACLRoleTTL: 30 * time.Second, ACLPolicyTTL: 30 * time.Second, ACLTokenTTL: 30 * time.Second, ACLDefaultPolicy: "allow", diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index f9d75e83c8..36a09174df 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -30,6 +30,8 @@ func init() { registerCommand(structs.ACLPolicyDeleteRequestType, (*FSM).applyACLPolicyDeleteOperation) registerCommand(structs.ConnectCALeafRequestType, (*FSM).applyConnectCALeafOperation) registerCommand(structs.ConfigEntryRequestType, (*FSM).applyConfigEntryOperation) + registerCommand(structs.ACLRoleSetRequestType, (*FSM).applyACLRoleSetOperation) + registerCommand(structs.ACLRoleDeleteRequestType, (*FSM).applyACLRoleDeleteOperation) } func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { @@ -452,3 +454,25 @@ func (c *FSM) applyConfigEntryOperation(buf []byte, index uint64) interface{} { return fmt.Errorf("invalid config entry operation type: %v", req.Op) } } + +func (c *FSM) applyACLRoleSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRoleBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "role"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLRoleBatchSet(index, req.Roles) +} + +func (c *FSM) applyACLRoleDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRoleBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "role"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLRoleBatchDelete(index, req.RoleIDs) +} diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 4195b8c422..3ad281434b 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -28,6 +28,7 @@ func init() { registerRestorer(structs.ACLTokenSetRequestType, restoreToken) registerRestorer(structs.ACLPolicySetRequestType, restorePolicy) registerRestorer(structs.ConfigEntryRequestType, restoreConfigEntry) + registerRestorer(structs.ACLRoleSetRequestType, restoreRole) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -203,6 +204,20 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, } } + roles, err := s.state.ACLRoles() + if err != nil { + return err + } + + for role := roles.Next(); role != nil; role = roles.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLRoleSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(role.(*structs.ACLRole)); err != nil { + return err + } + } + return nil } @@ -603,3 +618,11 @@ func restoreConfigEntry(header *snapshotHeader, restore *state.Restore, decoder } return restore.ConfigEntry(req.Entry) } + +func restoreRole(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLRole + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLRole(&req) +} diff --git a/agent/consul/helper_test.go b/agent/consul/helper_test.go index bd84fce5c5..678ef033fd 100644 --- a/agent/consul/helper_test.go +++ b/agent/consul/helper_test.go @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/types" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/raft" "github.com/hashicorp/serf/serf" "github.com/stretchr/testify/require" @@ -166,6 +166,21 @@ func waitForNewACLs(t *testing.T, server *Server) { require.False(t, server.UseLegacyACLs(), "Server cannot use new ACLs") } +func waitForNewACLReplication(t *testing.T, server *Server, expectedReplicationType structs.ACLReplicationType) { + var ( + replTyp structs.ACLReplicationType + running bool + ) + retry.Run(t, func(r *retry.R) { + replTyp, running = server.getACLReplicationStatusRunningType() + require.Equal(r, expectedReplicationType, replTyp, "Server not running new replicator yet") + require.True(r, running, "Server not running new replicator yet") + }) + + require.Equal(t, expectedReplicationType, replTyp, "Server not running new replicator yet") + require.True(t, running, "Server not running new replicator yet") +} + func seeEachOther(a, b []serf.Member, addra, addrb string) bool { return serfMembersContains(a, addrb) && serfMembersContains(b, addra) } diff --git a/agent/consul/leader.go b/agent/consul/leader.go index fe58c3ad7e..cbef94da47 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -660,6 +660,7 @@ func (s *Server) startACLUpgrade() { // Assign the global-management policy to legacy management tokens if len(newToken.Policies) == 0 && len(newToken.ServiceIdentities) == 0 && + len(newToken.Roles) == 0 && newToken.Type == structs.ACLTokenTypeManagement { newToken.Policies = append(newToken.Policies, structs.ACLTokenPolicyLink{ID: structs.ACLPolicyGlobalManagementID}) } @@ -738,7 +739,7 @@ func (s *Server) startLegacyACLReplication() { s.logger.Printf("[WARN] consul: Legacy ACL replication error (will retry if still leader): %v", err) } else { lastRemoteIndex = index - s.updateACLReplicationStatusIndex(index) + s.updateACLReplicationStatusIndex(structs.ACLReplicateLegacy, index) s.logger.Printf("[DEBUG] consul: Legacy ACL replication completed through remote index %d", index) } } @@ -760,8 +761,22 @@ func (s *Server) startACLReplication() { ctx, cancel := context.WithCancel(context.Background()) s.aclReplicationCancel = cancel - replicationType := structs.ACLReplicatePolicies + s.startACLReplicator(ctx, structs.ACLReplicatePolicies, s.replicateACLPolicies) + s.startACLReplicator(ctx, structs.ACLReplicateRoles, s.replicateACLRoles) + if s.config.ACLTokenReplication { + s.startACLReplicator(ctx, structs.ACLReplicateTokens, s.replicateACLTokens) + s.updateACLReplicationStatusRunning(structs.ACLReplicateTokens) + } else { + s.updateACLReplicationStatusRunning(structs.ACLReplicatePolicies) + } + + s.aclReplicationEnabled = true +} + +type replicateFunc func(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) + +func (s *Server) startACLReplicator(ctx context.Context, replicationType structs.ACLReplicationType, replicateFunc replicateFunc) { go func() { var failedAttempts uint limiter := rate.NewLimiter(rate.Limit(s.config.ACLReplicationRate), s.config.ACLReplicationBurst) @@ -776,7 +791,7 @@ func (s *Server) startACLReplication() { continue } - index, exit, err := s.replicateACLPolicies(lastRemoteIndex, ctx) + index, exit, err := replicateFunc(ctx, lastRemoteIndex) if exit { return } @@ -784,7 +799,7 @@ func (s *Server) startACLReplication() { if err != nil { lastRemoteIndex = 0 s.updateACLReplicationStatusError() - s.logger.Printf("[WARN] consul: ACL policy replication error (will retry if still leader): %v", err) + s.logger.Printf("[WARN] consul: ACL %s replication error (will retry if still leader): %v", replicationType.SingularNoun(), err) if (1 << failedAttempts) < aclReplicationMaxRetryBackoff { failedAttempts++ } @@ -797,65 +812,14 @@ func (s *Server) startACLReplication() { } } else { lastRemoteIndex = index - s.updateACLReplicationStatusIndex(index) - s.logger.Printf("[DEBUG] consul: ACL policy replication completed through remote index %d", index) + s.updateACLReplicationStatusIndex(replicationType, index) + s.logger.Printf("[DEBUG] consul: ACL %s replication completed through remote index %d", replicationType.SingularNoun(), index) failedAttempts = 0 } } }() - s.logger.Printf("[INFO] acl: started ACL Policy replication") - - if s.config.ACLTokenReplication { - replicationType = structs.ACLReplicateTokens - go func() { - var failedAttempts uint - limiter := rate.NewLimiter(rate.Limit(s.config.ACLReplicationRate), s.config.ACLReplicationBurst) - - var lastRemoteIndex uint64 - for { - if err := limiter.Wait(ctx); err != nil { - return - } - - if s.tokens.ReplicationToken() == "" { - continue - } - - index, exit, err := s.replicateACLTokens(lastRemoteIndex, ctx) - if exit { - return - } - - if err != nil { - lastRemoteIndex = 0 - s.updateACLReplicationStatusError() - s.logger.Printf("[WARN] consul: ACL token replication error (will retry if still leader): %v", err) - if (1 << failedAttempts) < aclReplicationMaxRetryBackoff { - failedAttempts++ - } - - select { - case <-ctx.Done(): - return - case <-time.After((1 << failedAttempts) * time.Second): - // do nothing - } - } else { - lastRemoteIndex = index - s.updateACLReplicationStatusTokenIndex(index) - s.logger.Printf("[DEBUG] consul: ACL token replication completed through remote index %d", index) - failedAttempts = 0 - } - } - }() - - s.logger.Printf("[INFO] acl: started ACL Token replication") - } - - s.updateACLReplicationStatusRunning(replicationType) - - s.aclReplicationEnabled = true + s.logger.Printf("[INFO] acl: started ACL %s replication", replicationType.SingularNoun()) } func (s *Server) stopACLReplication() { diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 1533dcbef7..8ff538922a 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -60,6 +60,108 @@ func (s *TokenPoliciesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) return val, nil } +type TokenRolesIndex struct { +} + +func (s *TokenRolesIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + token, ok := obj.(*structs.ACLToken) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLToken") + } + + links := token.Roles + + numLinks := len(links) + if numLinks == 0 { + return false, nil, nil + } + + vals := make([][]byte, 0, numLinks) + for _, link := range links { + vals = append(vals, []byte(link.ID+"\x00")) + } + + return true, vals, nil +} + +func (s *TokenRolesIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *TokenRolesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +type RolePoliciesIndex struct { +} + +func (s *RolePoliciesIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + role, ok := obj.(*structs.ACLRole) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLRole") + } + + links := role.Policies + + numLinks := len(links) + if numLinks == 0 { + return false, nil, nil + } + + vals := make([][]byte, 0, numLinks) + for _, link := range links { + vals = append(vals, []byte(link.ID+"\x00")) + } + + return true, vals, nil +} + +func (s *RolePoliciesIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *RolePoliciesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + type TokenExpirationIndex struct { LocalFilter bool } @@ -137,6 +239,13 @@ func tokensTableSchema() *memdb.TableSchema { Unique: false, Indexer: &TokenPoliciesIndex{}, }, + "roles": &memdb.IndexSchema{ + Name: "roles", + // Need to allow missing for the anonymous token + AllowMissing: true, + Unique: false, + Indexer: &TokenRolesIndex{}, + }, "local": &memdb.IndexSchema{ Name: "local", AllowMissing: false, @@ -208,9 +317,42 @@ func policiesTableSchema() *memdb.TableSchema { } } +func rolesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-roles", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "name": &memdb.IndexSchema{ + Name: "name", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + "policies": &memdb.IndexSchema{ + Name: "policies", + // Need to allow missing for the anonymous token + AllowMissing: true, + Unique: false, + Indexer: &RolePoliciesIndex{}, + }, + }, + } +} + func init() { registerSchema(tokensTableSchema) registerSchema(policiesTableSchema) + registerSchema(rolesTableSchema) } // ACLTokens is used when saving a snapshot @@ -255,6 +397,26 @@ func (s *Restore) ACLPolicy(policy *structs.ACLPolicy) error { return nil } +// ACLRoles is used when saving a snapshot +func (s *Snapshot) ACLRoles() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-roles", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLRole(role *structs.ACLRole) error { + if err := s.tx.Insert("acl-roles", role); err != nil { + return fmt.Errorf("failed restoring acl role: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, role.ModifyIndex, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + // ACLBootstrap is used to perform a one-time ACL bootstrap operation on a // cluster to get the first management token. func (s *Store) ACLBootstrap(idx, resetIndex uint64, token *structs.ACLToken, legacy bool) error { @@ -369,6 +531,7 @@ func (s *Store) fixupTokenPolicyLinks(tx *memdb.Txn, original *structs.ACLToken) // append the corrected policy token.Policies = append(token.Policies, structs.ACLTokenPolicyLink{ID: link.ID, Name: policy.Name}) + } else if owned { token.Policies = append(token.Policies, link) } @@ -377,6 +540,150 @@ func (s *Store) fixupTokenPolicyLinks(tx *memdb.Txn, original *structs.ACLToken) return token, nil } +func (s *Store) resolveTokenRoleLinks(tx *memdb.Txn, token *structs.ACLToken, allowMissing bool) error { + for linkIndex, link := range token.Roles { + if link.ID != "" { + role, err := s.getRoleWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return err + } + + if role != nil { + // the name doesn't matter here + token.Roles[linkIndex].Name = role.Name + } else if !allowMissing { + return fmt.Errorf("No such role with ID: %s", link.ID) + } + } else { + return fmt.Errorf("Encountered a Token with roles linked by Name in the state store") + } + } + return nil +} + +// fixupTokenRoleLinks is to be used when retrieving tokens from memdb. The role links could have gotten +// stale when a linked role was deleted or renamed. This will correct them and generate a newly allocated +// token only when fixes are needed. If the role links are still accurate then we just return the original +// token. +func (s *Store) fixupTokenRoleLinks(tx *memdb.Txn, original *structs.ACLToken) (*structs.ACLToken, error) { + owned := false + token := original + + cloneToken := func(t *structs.ACLToken, copyNumLinks int) *structs.ACLToken { + clone := *t + clone.Roles = make([]structs.ACLTokenRoleLink, copyNumLinks) + copy(clone.Roles, t.Roles[:copyNumLinks]) + return &clone + } + + for linkIndex, link := range original.Roles { + if link.ID == "" { + return nil, fmt.Errorf("Detected corrupted token within the state store - missing role link ID") + } + + role, err := s.getRoleWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return nil, err + } + + if role == nil { + if !owned { + // clone the token as we cannot touch the original + token = cloneToken(original, linkIndex) + owned = true + } + // if already owned then we just don't append it. + } else if role.Name != link.Name { + if !owned { + token = cloneToken(original, linkIndex) + owned = true + } + + // append the corrected policy + token.Roles = append(token.Roles, structs.ACLTokenRoleLink{ID: link.ID, Name: role.Name}) + + } else if owned { + token.Roles = append(token.Roles, link) + } + } + + return token, nil +} + +func (s *Store) resolveRolePolicyLinks(tx *memdb.Txn, role *structs.ACLRole, allowMissing bool) error { + for linkIndex, link := range role.Policies { + if link.ID != "" { + policy, err := s.getPolicyWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return err + } + + if policy != nil { + // the name doesn't matter here + role.Policies[linkIndex].Name = policy.Name + } else if !allowMissing { + return fmt.Errorf("No such policy with ID: %s", link.ID) + } + } else { + return fmt.Errorf("Encountered a Role with policies linked by Name in the state store") + } + } + return nil +} + +// fixupRolePolicyLinks is to be used when retrieving roles from memdb. The policy links could have gotten +// stale when a linked policy was deleted or renamed. This will correct them and generate a newly allocated +// role only when fixes are needed. If the policy links are still accurate then we just return the original +// role. +func (s *Store) fixupRolePolicyLinks(tx *memdb.Txn, original *structs.ACLRole) (*structs.ACLRole, error) { + owned := false + role := original + + cloneRole := func(t *structs.ACLRole, copyNumLinks int) *structs.ACLRole { + clone := *t + clone.Policies = make([]structs.ACLRolePolicyLink, copyNumLinks) + copy(clone.Policies, t.Policies[:copyNumLinks]) + return &clone + } + + for linkIndex, link := range original.Policies { + if link.ID == "" { + return nil, fmt.Errorf("Detected corrupted role within the state store - missing policy link ID") + } + + policy, err := s.getPolicyWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return nil, err + } + + if policy == nil { + if !owned { + // clone the token as we cannot touch the original + role = cloneRole(original, linkIndex) + owned = true + } + // if already owned then we just don't append it. + } else if policy.Name != link.Name { + if !owned { + role = cloneRole(original, linkIndex) + owned = true + } + + // append the corrected policy + role.Policies = append(role.Policies, structs.ACLRolePolicyLink{ID: link.ID, Name: policy.Name}) + + } else if owned { + role.Policies = append(role.Policies, link) + } + } + + return role, nil +} + // ACLTokenSet is used to insert an ACL rule into the state store. func (s *Store) ACLTokenSet(idx uint64, token *structs.ACLToken, legacy bool) error { tx := s.db.Txn(true) @@ -417,7 +724,7 @@ func (s *Store) ACLTokenBatchSet(idx uint64, tokens structs.ACLTokens, cas bool) // aclTokenSetTxn is the inner method used to insert an ACL token with the // proper indexes into the state store. -func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToken, cas, allowMissingPolicyIDs, legacy bool) error { +func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToken, cas, allowMissingPolicyAndRoleIDs, legacy bool) error { // Check that the ID is set if token.SecretID == "" { return ErrMissingACLTokenSecret @@ -474,7 +781,11 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke token.AccessorID = original.AccessorID } - if err := s.resolveTokenPolicyLinks(tx, token, allowMissingPolicyIDs); err != nil { + if err := s.resolveTokenPolicyLinks(tx, token, allowMissingPolicyAndRoleIDs); err != nil { + return err + } + + if err := s.resolveTokenRoleLinks(tx, token, allowMissingPolicyAndRoleIDs); err != nil { return err } @@ -563,7 +874,12 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st ws.Add(watchCh) if rawToken != nil { - token, err := s.fixupTokenPolicyLinks(tx, rawToken.(*structs.ACLToken)) + token := rawToken.(*structs.ACLToken) + token, err := s.fixupTokenPolicyLinks(tx, token) + if err != nil { + return nil, err + } + token, err = s.fixupTokenRoleLinks(tx, token) if err != nil { return nil, err } @@ -574,7 +890,7 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st } // ACLTokenList is used to list out all of the ACLs in the state store. -func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy string) (uint64, structs.ACLTokens, error) { +func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role string) (uint64, structs.ACLTokens, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -585,6 +901,10 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy strin // to false but for defaulted structs (zero values for both) we want it to list out // all tokens so our checks just ensure that global == local + if policy != "" && role != "" { + return 0, nil, fmt.Errorf("cannot filter by role and policy at the same time") + } + if policy != "" { iter, err = tx.Get("acl-tokens", "policies", policy) if err == nil && global != local { @@ -600,6 +920,24 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy strin return false } + return true + }) + } + } else if role != "" { + iter, err = tx.Get("acl-tokens", "roles", role) + if err == nil && global != local { + iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { + token, ok := raw.(*structs.ACLToken) + if !ok { + return true + } + + if global && !token.Local { + return false + } else if local && token.Local { + return false + } + return true }) } @@ -618,8 +956,12 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy strin var result structs.ACLTokens for raw := iter.Next(); raw != nil; raw = iter.Next() { - token, err := s.fixupTokenPolicyLinks(tx, raw.(*structs.ACLToken)) - + token := raw.(*structs.ACLToken) + token, err := s.fixupTokenPolicyLinks(tx, token) + if err != nil { + return 0, nil, err + } + token, err = s.fixupTokenRoleLinks(tx, token) if err != nil { return 0, nil, err } @@ -1000,3 +1342,240 @@ func (s *Store) aclPolicyDeleteTxn(tx *memdb.Txn, idx uint64, value, index strin } return nil } + +func (s *Store) ACLRoleBatchSet(idx uint64, roles structs.ACLRoles) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, role := range roles { + if err := s.aclRoleSetTxn(tx, idx, role, true); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLRoleSet(idx uint64, role *structs.ACLRole) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclRoleSetTxn(tx, idx, role, false); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclRoleSetTxn(tx *memdb.Txn, idx uint64, role *structs.ACLRole, allowMissing bool) error { + // Check that the ID is set + if role.ID == "" { + return ErrMissingACLRoleID + } + + if role.Name == "" { + return ErrMissingACLRoleName + } + + existing, err := tx.First("acl-roles", "id", role.ID) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + + // ensure the name is unique (cannot conflict with another role with a different ID) + nameMatch, err := tx.First("acl-roles", "name", role.Name) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + if nameMatch != nil && role.ID != nameMatch.(*structs.ACLRole).ID { + return fmt.Errorf("A role with name %q already exists", role.Name) + } + + if err := s.resolveRolePolicyLinks(tx, role, allowMissing); err != nil { + return err + } + + for _, svcid := range role.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Encountered a Role with an empty service identity name in the state store") + } + } + + // Set the indexes + if existing != nil { + role.CreateIndex = existing.(*structs.ACLRole).CreateIndex + role.ModifyIndex = idx + } else { + role.CreateIndex = idx + role.ModifyIndex = idx + } + + if err := tx.Insert("acl-roles", role); err != nil { + return fmt.Errorf("failed inserting acl role: %v", err) + } + return nil +} + +func (s *Store) ACLRoleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLRole, error) { + return s.aclRoleGet(ws, id, "id") +} + +func (s *Store) ACLRoleGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLRole, error) { + return s.aclRoleGet(ws, name, "name") +} + +func (s *Store) ACLRoleBatchGet(ws memdb.WatchSet, ids []string) (uint64, structs.ACLRoles, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + roles := make(structs.ACLRoles, 0) + for _, rid := range ids { + role, err := s.getRoleWithTxn(tx, ws, rid, "id") + if err != nil { + return 0, nil, err + } + + if role != nil { + roles = append(roles, role) + } + } + + idx := maxIndexTxn(tx, "acl-roles") + + return idx, roles, nil +} + +func (s *Store) getRoleWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLRole, error) { + watchCh, rawRole, err := tx.FirstWatch("acl-roles", index, value) + if err != nil { + return nil, fmt.Errorf("failed acl role lookup: %v", err) + } + ws.Add(watchCh) + + if rawRole != nil { + role := rawRole.(*structs.ACLRole) + role, err := s.fixupRolePolicyLinks(tx, role) + if err != nil { + return nil, err + } + return role, nil + } + + return nil, nil +} + +func (s *Store) aclRoleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLRole, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + role, err := s.getRoleWithTxn(tx, ws, value, index) + if err != nil { + return 0, nil, err + } + + idx := maxIndexTxn(tx, "acl-roles") + + return idx, role, nil +} + +func (s *Store) ACLRoleList(ws memdb.WatchSet, policy string) (uint64, structs.ACLRoles, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var iter memdb.ResultIterator + var err error + + if policy != "" { + iter, err = tx.Get("acl-roles", "policies", policy) + } else { + iter, err = tx.Get("acl-roles", "id") + } + + if err != nil { + return 0, nil, fmt.Errorf("failed acl role lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLRoles + for raw := iter.Next(); raw != nil; raw = iter.Next() { + role := raw.(*structs.ACLRole) + role, err := s.fixupRolePolicyLinks(tx, role) + if err != nil { + return 0, nil, err + } + result = append(result, role) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-roles") + + return idx, result, nil +} + +func (s *Store) ACLRoleDeleteByID(idx uint64, id string) error { + return s.aclRoleDelete(idx, id, "id") +} + +func (s *Store) ACLRoleDeleteByName(idx uint64, name string) error { + return s.aclRoleDelete(idx, name, "name") +} + +func (s *Store) ACLRoleBatchDelete(idx uint64, roleIDs []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, roleID := range roleIDs { + if err := s.aclRoleDeleteTxn(tx, idx, roleID, "id"); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclRoleDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclRoleDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclRoleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing role + rawRole, err := tx.First("acl-roles", index, value) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + + if rawRole == nil { + return nil + } + + role := rawRole.(*structs.ACLRole) + + if err := tx.Delete("acl-roles", role); err != nil { + return fmt.Errorf("failed deleting acl role: %v", err) + } + return nil +} diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index d39c71d099..b561e1e54c 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -14,6 +14,16 @@ import ( "github.com/stretchr/testify/require" ) +const ( + testRoleID_A = "2c74a9b8-271c-4a21-b727-200db397c01c" // from:setupExtraPoliciesAndRoles + testRoleID_B = "aeab6b63-08d1-455a-b85b-3458b462b426" // from:setupExtraPoliciesAndRoles + testPolicyID_A = "a0625e95-9b3e-42de-a8d6-ceef5b6f3286" // from:setupExtraPolicies + testPolicyID_B = "9386ecae-6677-4686-bcd4-5ab9d86cca1d" // from:setupExtraPolicies + testPolicyID_C = "2bf7359d-cfde-4769-a9fa-54ff1bb2ae4c" // from:setupExtraPolicies + testPolicyID_D = "ff807410-2b82-48ae-9a63-6626a90789d0" // from:setupExtraPolicies + testPolicyID_E = "b4635d48-90aa-4a77-8e1b-9004f68bb3df" // from:setupExtraPolicies +) + func setupGlobalManagement(t *testing.T, s *Store) { policy := structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, @@ -46,19 +56,40 @@ func testACLStateStore(t *testing.T) *Store { func setupExtraPolicies(t *testing.T, s *Store) { policies := structs.ACLPolicies{ &structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, Syntax: acl.SyntaxCurrent, }, &structs.ACLPolicy{ - ID: "9386ecae-6677-4686-bcd4-5ab9d86cca1d", + ID: testPolicyID_B, Name: "agent-read", Description: "Allows reading all node information", Rules: `agent_prefix "" { policy = "read" }`, Syntax: acl.SyntaxCurrent, }, + &structs.ACLPolicy{ + ID: testPolicyID_C, + Name: "acl-read", + Description: "Allows acl read", + Rules: `acl = "read"`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: testPolicyID_D, + Name: "acl-write", + Description: "Allows acl write", + Rules: `acl = "write"`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: testPolicyID_E, + Name: "kv-read", + Description: "Allows kv read", + Rules: `key_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, } for _, policy := range policies { @@ -68,7 +99,46 @@ func setupExtraPolicies(t *testing.T, s *Store) { require.NoError(t, s.ACLPolicyBatchSet(2, policies)) } +func setupExtraPoliciesAndRoles(t *testing.T, s *Store) { + setupExtraPolicies(t, s) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "agent-read-role", + Description: "Allows reading all agent information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + for _, role := range roles { + role.SetHash(true) + } + + require.NoError(t, s.ACLRoleBatchSet(3, roles)) +} + func testACLTokensStateStore(t *testing.T) *Store { + s := testACLStateStore(t) + setupExtraPoliciesAndRoles(t, s) + return s +} + +func testACLRolesStateStore(t *testing.T) *Store { s := testACLStateStore(t) setupExtraPolicies(t, s) return s @@ -135,7 +205,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { require.Equal(t, uint64(3), index) // Make sure the ACLs are in an expected state. - _, tokens, err := s.ACLTokenList(nil, true, true, "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "") require.NoError(t, err) require.Len(t, tokens, 1) compareTokens(t, token1, tokens[0]) @@ -149,7 +219,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { err = s.ACLBootstrap(32, index, token2.Clone(), false) require.NoError(t, err) - _, tokens, err = s.ACLTokenList(nil, true, true, "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "") require.NoError(t, err) require.Len(t, tokens, 2) } @@ -165,7 +235,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { SecretID: "6d48ce91-2558-4098-bdab-8737e4e57d5f", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, } @@ -326,6 +396,23 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Error(t, err) }) + t.Run("Missing Role ID", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + Name: "no-id", + }, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("Unresolvable Policy ID", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -343,6 +430,23 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Error(t, err) }) + t.Run("Unresolvable Role ID", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "9b2349b6-55d3-4901-b287-347ae725af2f", + }, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("New", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -351,7 +455,12 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, + }, + }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, }, }, ServiceIdentities: []*structs.ACLServiceIdentity{ @@ -371,6 +480,8 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Equal(t, uint64(2), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) require.Equal(t, "node-read", rtoken.Policies[0].Name) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) require.Len(t, rtoken.ServiceIdentities, 1) require.Equal(t, "web", rtoken.ServiceIdentities[0].ServiceName) }) @@ -383,7 +494,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, ServiceIdentities: []*structs.ACLServiceIdentity{ @@ -403,6 +514,11 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { ID: structs.ACLPolicyGlobalManagementID, }, }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, ServiceIdentities: []*structs.ACLServiceIdentity{ &structs.ACLServiceIdentity{ ServiceName: "db", @@ -421,6 +537,9 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Len(t, rtoken.Policies, 1) require.Equal(t, structs.ACLPolicyGlobalManagementID, rtoken.Policies[0].ID) require.Equal(t, "global-management", rtoken.Policies[0].Name) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, testRoleID_A, rtoken.Roles[0].ID) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) require.Len(t, rtoken.ServiceIdentities, 1) require.Equal(t, "db", rtoken.ServiceIdentities[0].ServiceName) }) @@ -568,7 +687,7 @@ func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { Description: "first token", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, }, @@ -607,7 +726,7 @@ func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { require.Equal(t, "00ff4564-dd96-4d1b-8ad6-578a08279f79", rtokens[1].SecretID) require.Equal(t, "first token", rtokens[1].Description) require.Len(t, rtokens[1].Policies, 1) - require.Equal(t, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", rtokens[1].Policies[0].ID) + require.Equal(t, testPolicyID_A, rtokens[1].Policies[0].ID) require.Equal(t, "node-read", rtokens[1].Policies[0].Name) require.Equal(t, uint64(2), rtokens[1].CreateIndex) require.Equal(t, uint64(3), rtokens[1].ModifyIndex) @@ -738,7 +857,7 @@ func TestStateStore_ACLToken_List(t *testing.T) { SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, }, @@ -748,7 +867,28 @@ func TestStateStore_ACLToken_List(t *testing.T) { SecretID: "f6998577-fd9b-4e6c-b202-cc3820513d32", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, + }, + }, + Local: true, + }, + // the role specific token + &structs.ACLToken{ + AccessorID: "a7715fde-8954-4c92-afbc-d84c6ecdc582", + SecretID: "77a2da3a-b479-4025-a83e-bd6b859f0cfe", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + }, + // the role specific token and local + &structs.ACLToken{ + AccessorID: "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", + SecretID: "c432d12b-3c86-4628-b74f-94ddfc7fb3ba", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, }, }, Local: true, @@ -762,6 +902,7 @@ func TestStateStore_ACLToken_List(t *testing.T) { local bool global bool policy string + role string accessors []string } @@ -771,10 +912,12 @@ func TestStateStore_ACLToken_List(t *testing.T) { local: false, global: true, policy: "", + role: "", accessors: []string{ structs.ACLTokenAnonymousID, - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "54866514-3cf2-4fec-8a8a-710583831834", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { @@ -782,37 +925,73 @@ func TestStateStore_ACLToken_List(t *testing.T) { local: true, global: false, policy: "", + role: "", accessors: []string{ - "4915fc9d-3726-4171-b588-6c271f45eecd", - "f1093997-b6c7-496d-bfb8-6b1b1895641b", + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, { name: "Policy", local: true, global: true, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + policy: testPolicyID_A, + role: "", accessors: []string{ - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "4915fc9d-3726-4171-b588-6c271f45eecd", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { name: "Policy - Local", local: true, global: false, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + policy: testPolicyID_A, + role: "", accessors: []string{ - "4915fc9d-3726-4171-b588-6c271f45eecd", + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { name: "Policy - Global", local: false, global: true, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + policy: testPolicyID_A, + role: "", accessors: []string{ - "47eea4da-bda1-48a6-901c-3e36d2d9262f", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + }, + }, + { + name: "Role", + local: true, + global: true, + policy: "", + role: testRoleID_A, + accessors: []string{ + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + }, + }, + { + name: "Role - Local", + local: true, + global: false, + policy: "", + role: testRoleID_A, + accessors: []string{ + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + }, + }, + { + name: "Role - Global", + local: false, + global: true, + policy: "", + role: testRoleID_A, + accessors: []string{ + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { @@ -820,21 +999,29 @@ func TestStateStore_ACLToken_List(t *testing.T) { local: true, global: true, policy: "", + role: "", accessors: []string{ structs.ACLTokenAnonymousID, - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "4915fc9d-3726-4171-b588-6c271f45eecd", - "54866514-3cf2-4fec-8a8a-710583831834", - "f1093997-b6c7-496d-bfb8-6b1b1895641b", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local + "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, } + t.Run("can't filter on both", func(t *testing.T) { + _, _, err := s.ACLTokenList(nil, false, false, testPolicyID_A, testRoleID_A) + require.Error(t, err) + }) + for _, tc := range cases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy) + _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role) require.NoError(t, err) require.Len(t, tokens, len(tc.accessors)) tokens.Sort() @@ -861,7 +1048,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, } @@ -877,7 +1064,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { // rename the policy renamed := &structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read-renamed", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, @@ -895,7 +1082,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "") require.NoError(t, err) found := false @@ -929,7 +1116,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.True(t, found) // delete the policy - require.NoError(t, s.ACLPolicyDeleteByID(4, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286")) + require.NoError(t, s.ACLPolicyDeleteByID(4, testPolicyID_A)) // retrieve the token again _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) @@ -939,7 +1126,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Len(t, retrieved.Policies, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "") require.NoError(t, err) found = false @@ -971,6 +1158,135 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.True(t, found) } +func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { + // This test wants to ensure a couple of things. + // + // 1. Doing a token list/get should never modify the data + // tracked by memdb + // 2. Token list/get operations should return an accurate set + // of role links + t.Parallel() + s := testACLTokensStateStore(t) + + // the role specific token + token := &structs.ACLToken{ + AccessorID: "47eea4da-bda1-48a6-901c-3e36d2d9262f", + SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token, false)) + + _, retrieved, err := s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be identical + require.True(t, token == retrieved) + require.Len(t, retrieved.Roles, 1) + require.Equal(t, "node-read-role", retrieved.Roles[0].Name) + + // rename the role + renamed := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role-renamed", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + renamed.SetHash(true) + require.NoError(t, s.ACLRoleSet(3, renamed)) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Roles, 1) + require.Equal(t, "node-read-role-renamed", retrieved.Roles[0].Name) + + // list tokens without stale links + _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + require.NoError(t, err) + + found := false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 1) + require.Equal(t, "node-read-role-renamed", tok.Roles[0].Name) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 1) + require.Equal(t, "node-read-role-renamed", tok.Roles[0].Name) + found = true + break + } + } + require.True(t, found) + + // delete the role + require.NoError(t, s.ACLRoleDeleteByID(4, testRoleID_A)) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Roles, 0) + + // list tokens without stale links + _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 0) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 0) + found = true + break + } + } + require.True(t, found) +} + func TestStateStore_ACLToken_Delete(t *testing.T) { t.Parallel() @@ -1117,7 +1433,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { s := testACLStateStore(t) policy := structs.ACLPolicy{ - ID: "2c74a9b8-271c-4a21-b727-200db397c01c", + ID: testRoleID_A, Description: "test", Rules: `keyring = "write"`, } @@ -1185,7 +1501,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { s := testACLStateStore(t) policy := structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, @@ -1195,7 +1511,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { require.NoError(t, s.ACLPolicySet(3, &policy)) - idx, rpolicy, err := s.ACLPolicyGetByID(nil, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286") + idx, rpolicy, err := s.ACLPolicyGetByID(nil, testPolicyID_A) require.Equal(t, uint64(3), idx) require.NoError(t, err) require.NotNil(t, rpolicy) @@ -1228,7 +1544,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { s := testACLTokensStateStore(t) update := &structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read-modified", Description: "Modified", Rules: `node_prefix "" { policy = "read" } node "secret" { policy = "deny" }`, @@ -1243,7 +1559,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { expect.ModifyIndex = 3 // policy found via id - idx, rpolicy, err := s.ACLPolicyGetByID(nil, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286") + idx, rpolicy, err := s.ACLPolicyGetByID(nil, testPolicyID_A) require.NoError(t, err) require.Equal(t, uint64(3), idx) require.Equal(t, expect, rpolicy) @@ -1521,6 +1837,673 @@ func TestStateStore_ACLPolicy_Delete(t *testing.T) { }) } +func TestStateStore_ACLRole_SetGet(t *testing.T) { + t.Parallel() + + t.Run("Missing ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + Name: "test-role", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Service Identity Fields", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{}, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Service Identity Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + Datacenters: []string{"dc1"}, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Policy ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + Name: "no-id", + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Unresolvable Policy ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "4f20e379-b496-4b99-9599-19a197126490", + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Name: "my-new-role", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(3, &role)) + + verify := func(idx uint64, rrole *structs.ACLRole, err error) { + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.NotNil(t, rrole) + require.Equal(t, "my-new-role", rrole.Name) + require.Equal(t, "test", rrole.Description) + require.Equal(t, uint64(3), rrole.CreateIndex) + require.Equal(t, uint64(3), rrole.ModifyIndex) + require.Len(t, rrole.ServiceIdentities, 0) + // require.ElementsMatch(t, role.Policies, rrole.Policies) + require.Len(t, rrole.Policies, 1) + require.Equal(t, "node-read", rrole.Policies[0].Name) + } + + idx, rpolicy, err := s.ACLRoleGetByID(nil, testRoleID_A) + verify(idx, rpolicy, err) + + idx, rpolicy, err = s.ACLRoleGetByName(nil, "my-new-role") + verify(idx, rpolicy, err) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + // Create the initial role + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + role.SetHash(true) + + require.NoError(t, s.ACLRoleSet(2, role)) + + // Now make sure we can update it + update := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role-modified", + Description: "Modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + update.SetHash(true) + + require.NoError(t, s.ACLRoleSet(3, update)) + + verify := func(idx uint64, rrole *structs.ACLRole, err error) { + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.NotNil(t, rrole) + require.Equal(t, "node-read-role-modified", rrole.Name) + require.Equal(t, "Modified", rrole.Description) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(3), rrole.ModifyIndex) + require.Len(t, rrole.ServiceIdentities, 0) + require.Len(t, rrole.Policies, 1) + require.Equal(t, structs.ACLPolicyGlobalManagementID, rrole.Policies[0].ID) + require.Equal(t, "global-management", rrole.Policies[0].Name) + } + + // role found via id + idx, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + verify(idx, rrole, err) + + // role no longer found via old name + idx, rrole, err = s.ACLRoleGetByName(nil, "node-read-role") + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.Nil(t, rrole) + + // role is found via new name + idx, rrole, err = s.ACLRoleGetByName(nil, "node-read-role-modified") + verify(idx, rrole, err) + }) +} + +func TestStateStore_ACLRoles_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + idx, rroles, err := s.ACLRoleBatchGet(nil, []string{testRoleID_A, testRoleID_B}) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rroles, 2) + rroles.Sort() + require.ElementsMatch(t, roles, rroles) + require.Equal(t, uint64(2), rroles[0].CreateIndex) + require.Equal(t, uint64(2), rroles[0].ModifyIndex) + require.Equal(t, uint64(2), rroles[1].CreateIndex) + require.Equal(t, uint64(2), rroles[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + // Seed initial data. + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + // Update two roles at the same time. + updates := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1-modified", + Description: "test-role1-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_C, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_D, + }, + structs.ACLRolePolicyLink{ + ID: testPolicyID_E, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(3, updates)) + + idx, rroles, err := s.ACLRoleBatchGet(nil, []string{testRoleID_A, testRoleID_B}) + + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rroles, 2) + rroles.Sort() + require.Equal(t, testRoleID_A, rroles[0].ID) + require.Equal(t, "role1-modified", rroles[0].Name) + require.Equal(t, "test-role1-modified", rroles[0].Description) + require.ElementsMatch(t, updates[0].Policies, rroles[0].Policies) + require.Equal(t, uint64(2), rroles[0].CreateIndex) + require.Equal(t, uint64(3), rroles[0].ModifyIndex) + + require.Equal(t, testRoleID_B, rroles[1].ID) + require.Equal(t, "role2-modified", rroles[1].Name) + require.Equal(t, "", rroles[1].Description) + require.ElementsMatch(t, updates[1].Policies, rroles[1].Policies) + require.Equal(t, uint64(2), rroles[1].CreateIndex) + require.Equal(t, uint64(3), rroles[1].ModifyIndex) + }) +} + +func TestStateStore_ACLRole_List(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + type testCase struct { + name string + policy string + ids []string + } + + cases := []testCase{ + { + name: "Any", + policy: "", + ids: []string{ + testRoleID_A, + testRoleID_B, + }, + }, + { + name: "Policy A", + policy: testPolicyID_A, + ids: []string{ + testRoleID_A, + }, + }, + { + name: "Policy B", + policy: testPolicyID_B, + ids: []string{ + testRoleID_B, + }, + }, + } + + for _, tc := range cases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + // t.Parallel() + _, rroles, err := s.ACLRoleList(nil, tc.policy) + require.NoError(t, err) + + require.Len(t, rroles, len(tc.ids)) + rroles.Sort() + for i, rrole := range rroles { + expectID := tc.ids[i] + require.Equal(t, expectID, rrole.ID) + switch expectID { + case testRoleID_A: + require.Equal(t, testRoleID_A, rrole.ID) + require.Equal(t, "role1", rrole.Name) + require.Equal(t, "test-role1", rrole.Description) + require.ElementsMatch(t, roles[0].Policies, rrole.Policies) + require.Nil(t, rrole.Hash) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(2), rrole.ModifyIndex) + case testRoleID_B: + require.Equal(t, testRoleID_B, rrole.ID) + require.Equal(t, "role2", rrole.Name) + require.Equal(t, "test-role2", rrole.Description) + require.ElementsMatch(t, roles[1].Policies, rrole.Policies) + require.Nil(t, rrole.Hash) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(2), rrole.ModifyIndex) + } + } + }) + } +} + +func TestStateStore_ACLRole_FixupPolicyLinks(t *testing.T) { + // This test wants to ensure a couple of things. + // + // 1. Doing a role list/get should never modify the data + // tracked by memdb + // 2. Role list/get operations should return an accurate set + // of policy links + t.Parallel() + s := testACLRolesStateStore(t) + + // the policy specific role + role := &structs.ACLRole{ + ID: "672537b1-35cb-48fc-a2cd-a1863c301b70", + Name: "test-role", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, retrieved, err := s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be identical + require.True(t, role == retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read", retrieved.Policies[0].Name) + + // rename the policy + renamed := &structs.ACLPolicy{ + ID: testPolicyID_A, + Name: "node-read-renamed", + Description: "Allows reading all node information", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + } + renamed.SetHash(true) + require.NoError(t, s.ACLPolicySet(3, renamed)) + + // retrieve the role again + _, retrieved, err = s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, role != retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) + + // list roles without stale links + _, roles, err := s.ACLRoleList(nil, "") + require.NoError(t, err) + + found := false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 1) + require.Equal(t, "node-read-renamed", r.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, roles, err = s.ACLRoleBatchGet(nil, []string{role.ID}) + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 1) + require.Equal(t, "node-read-renamed", r.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // delete the policy + require.NoError(t, s.ACLPolicyDeleteByID(4, testPolicyID_A)) + + // retrieve the role again + _, retrieved, err = s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, role != retrieved) + require.Len(t, retrieved.Policies, 0) + + // list roles without stale links + _, roles, err = s.ACLRoleList(nil, "") + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 0) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, roles, err = s.ACLRoleBatchGet(nil, []string{role.ID}) + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 0) + found = true + break + } + } + require.True(t, found) +} + +func TestStateStore_ACLRole_Delete(t *testing.T) { + t.Parallel() + + t.Run("ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleDeleteByID(3, testRoleID_A)) + require.NoError(t, err) + + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, rrole, err := s.ACLRoleGetByName(nil, "role1") + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleDeleteByName(3, "role1")) + require.NoError(t, err) + + _, rrole, err = s.ACLRoleGetByName(nil, "role1") + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + _, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.NotNil(t, rrole) + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_B) + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleBatchDelete(3, []string{testRoleID_A, testRoleID_B})) + + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.Nil(t, rrole) + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_B) + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant roles is not an error + require.NoError(t, s.ACLRoleDeleteByName(3, "not-found")) + require.NoError(t, s.ACLRoleDeleteByID(3, testRoleID_A)) + }) +} + func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { s := testStateStore(t) @@ -1547,6 +2530,35 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + Description: "role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + }, + }, + }, + &structs.ACLRole{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + Description: "role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + }, + }, + }, + } + + for _, role := range roles { + role.SetHash(true) + } + + require.NoError(t, s.ACLRoleBatchSet(3, roles)) + tokens := structs.ACLTokens{ &structs.ACLToken{ AccessorID: "68016c3d-835b-450c-a6f9-75db9ba740be", @@ -1562,8 +2574,16 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { Name: "policy2", }, }, - Hash: []byte{1, 2, 3, 4}, - RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + }, + structs.ACLTokenRoleLink{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + }, + }, }, &structs.ACLToken{ AccessorID: "b2125a1b-2a52-41d4-88f3-c58761998a46", @@ -1579,12 +2599,22 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { Name: "policy2", }, }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + }, + structs.ACLTokenRoleLink{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + }, + }, Hash: []byte{1, 2, 3, 4}, RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, }, } - require.NoError(t, s.ACLTokenBatchSet(2, tokens, false)) + require.NoError(t, s.ACLTokenBatchSet(4, tokens, false)) // Snapshot the ACLs. snap := s.Snapshot() @@ -1594,7 +2624,7 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLTokenDeleteByAccessor(3, tokens[0].AccessorID)) // Verify the snapshot. - require.Equal(t, uint64(2), snap.LastIndex()) + require.Equal(t, uint64(4), snap.LastIndex()) iter, err := snap.ACLTokens() require.NoError(t, err) @@ -1617,12 +2647,15 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { // need to ensure we have the policies or else the links will be removed require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + // need to ensure we have the roles or else the links will be removed + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLTokenList(nil, true, true, "") + idx, res, err := s.ACLTokenList(nil, true, true, "", "") require.NoError(t, err) - require.Equal(t, uint64(2), idx) + require.Equal(t, uint64(4), idx) require.ElementsMatch(t, tokens, res) - require.Equal(t, uint64(2), s.maxIndex("acl-tokens")) + require.Equal(t, uint64(4), s.maxIndex("acl-tokens")) }() } @@ -1840,6 +2873,10 @@ func stripIrrelevantTokenFields(token *structs.ACLToken) *structs.ACLToken { for i, _ := range tokenCopy.Policies { tokenCopy.Policies[i].Name = "" } + // Also do the same for Role links. + for i, _ := range tokenCopy.Roles { + tokenCopy.Roles[i].Name = "" + } // The raft indexes won't match either because the requester will not // have access to that. tokenCopy.RaftIndex = structs.RaftIndex{} @@ -1849,3 +2886,108 @@ func stripIrrelevantTokenFields(token *structs.ACLToken) *structs.ACLToken { func compareTokens(t *testing.T, expected, actual *structs.ACLToken) { require.Equal(t, stripIrrelevantTokenFields(expected), stripIrrelevantTokenFields(actual)) } + +func TestStateStore_ACLRoles_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + policies := structs.ACLPolicies{ + &structs.ACLPolicy{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + Description: "policy1", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + Description: "policy2", + Rules: `acl = "read"`, + Syntax: acl.SyntaxCurrent, + }, + } + + for _, policy := range policies { + policy.SetHash(true) + } + + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: "68016c3d-835b-450c-a6f9-75db9ba740be", + Name: "838f72b5-5c15-4a9e-aa6d-31734c3a0286", + Description: "policy1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + }, + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + }, + }, + Hash: []byte{1, 2, 3, 4}, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLRole{ + ID: "b2125a1b-2a52-41d4-88f3-c58761998a46", + Name: "ba5d9239-a4ab-49b9-ae09-1f19eed92204", + Description: "policy2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + }, + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + }, + }, + Hash: []byte{1, 2, 3, 4}, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLRoleDeleteByID(3, roles[0].ID)) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLRoles() + require.NoError(t, err) + + var dump structs.ACLRoles + for role := iter.Next(); role != nil; role = iter.Next() { + dump = append(dump, role.(*structs.ACLRole)) + } + require.ElementsMatch(t, dump, roles) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, role := range dump { + require.NoError(t, restore.ACLRole(role)) + } + restore.Commit() + + // need to ensure we have the policies or else the links will be removed + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + + // Read the restored ACLs back out and verify that they match. + idx, res, err := s.ACLRoleList(nil, "") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, roles, res) + require.Equal(t, uint64(2), s.maxIndex("acl-roles")) + }() +} diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index 7b42c7ad4f..4dcc74ddde 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/hashicorp/consul/types" - "github.com/hashicorp/go-memdb" + memdb "github.com/hashicorp/go-memdb" ) var ( @@ -37,6 +37,18 @@ var ( // policy with an empty Name. ErrMissingACLPolicyName = errors.New("Missing ACL Policy Name") + // ErrMissingACLRoleID is returned when an role set is called on + // a role with an empty ID. + ErrMissingACLRoleID = errors.New("Missing ACL Role ID") + + // ErrMissingACLRoleName is returned when an role set is called on + // a role with an empty Name. + ErrMissingACLRoleName = errors.New("Missing ACL Role Name") + + // ErrInvalidACLRoleName is returned when an role set is called on + // a role with an invalid Name. + ErrInvalidACLRoleName = errors.New("Invalid ACL Role Name") + // ErrMissingQueryID is returned when a Query set is called on // a Query with an empty ID. ErrMissingQueryID = errors.New("Missing Query ID") diff --git a/agent/http_oss.go b/agent/http_oss.go index e524e450a9..6a0d5917bc 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -14,6 +14,10 @@ func init() { registerEndpoint("/v1/acl/policies", []string{"GET"}, (*HTTPServer).ACLPolicyList) registerEndpoint("/v1/acl/policy", []string{"PUT"}, (*HTTPServer).ACLPolicyCreate) registerEndpoint("/v1/acl/policy/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLPolicyCRUD) + registerEndpoint("/v1/acl/roles", []string{"GET"}, (*HTTPServer).ACLRoleList) + registerEndpoint("/v1/acl/role", []string{"PUT"}, (*HTTPServer).ACLRoleCreate) + registerEndpoint("/v1/acl/role/name/", []string{"GET"}, (*HTTPServer).ACLRoleReadByName) + registerEndpoint("/v1/acl/role/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLRoleCRUD) registerEndpoint("/v1/acl/rules/translate", []string{"POST"}, (*HTTPServer).ACLRulesTranslate) registerEndpoint("/v1/acl/rules/translate/", []string{"GET"}, (*HTTPServer).ACLRulesTranslateLegacyToken) registerEndpoint("/v1/acl/tokens", []string{"GET"}, (*HTTPServer).ACLTokenList) diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 50d4df0b52..0b46b21e31 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -129,6 +129,7 @@ type ACLIdentity interface { ID() string SecretToken() string PolicyIDs() []string + RoleIDs() []string EmbeddedPolicy() *ACLPolicy ServiceIdentityList() []*ACLServiceIdentity IsExpired(asOf time.Time) bool @@ -139,6 +140,11 @@ type ACLTokenPolicyLink struct { Name string `hash:"ignore"` } +type ACLTokenRoleLink struct { + ID string + Name string `hash:"ignore"` +} + // ACLServiceIdentity represents a high-level grant of all necessary privileges // to assume the identity of the named Service in the Catalog and within // Connect. @@ -166,6 +172,14 @@ func (s *ACLServiceIdentity) AddToHash(h hash.Hash) { } } +func (s *ACLServiceIdentity) EstimateSize() int { + size := len(s.ServiceName) + for _, dc := range s.Datacenters { + size += len(dc) + } + return size +} + func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy { // Given that we validate this string name before persisting, we do not // have to escape it before doing the following interpolation. @@ -197,6 +211,11 @@ type ACLToken struct { // the list of policy names gets validated and the policy IDs get stored herein Policies []ACLTokenPolicyLink `json:",omitempty"` + // List of role links. Note this is the list of IDs and not the names. + // Prior to token creation the list of role names gets validated and the + // role IDs get stored herein + Roles []ACLTokenRoleLink `json:",omitempty"` + // List of services to generate synthetic policies for. ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` @@ -249,12 +268,17 @@ type ACLToken struct { func (t *ACLToken) Clone() *ACLToken { t2 := *t t2.Policies = nil + t2.Roles = nil t2.ServiceIdentities = nil if len(t.Policies) > 0 { t2.Policies = make([]ACLTokenPolicyLink, len(t.Policies)) copy(t2.Policies, t.Policies) } + if len(t.Roles) > 0 { + t2.Roles = make([]ACLTokenRoleLink, len(t.Roles)) + copy(t2.Roles, t.Roles) + } if len(t.ServiceIdentities) > 0 { t2.ServiceIdentities = make([]*ACLServiceIdentity, len(t.ServiceIdentities)) for i, s := range t.ServiceIdentities { @@ -284,6 +308,14 @@ func (t *ACLToken) PolicyIDs() []string { return ids } +func (t *ACLToken) RoleIDs() []string { + var ids []string + for _, link := range t.Roles { + ids = append(ids, link.ID) + } + return ids +} + func (t *ACLToken) ServiceIdentityList() []*ACLServiceIdentity { if len(t.ServiceIdentities) == 0 { return nil @@ -310,6 +342,7 @@ func (t *ACLToken) HasExpirationTime() bool { func (t *ACLToken) UsesNonLegacyFields() bool { return len(t.Policies) > 0 || len(t.ServiceIdentities) > 0 || + len(t.Roles) > 0 || t.Type == "" || t.HasExpirationTime() || t.ExpirationTTL != 0 @@ -376,6 +409,10 @@ func (t *ACLToken) SetHash(force bool) []byte { hash.Write([]byte(link.ID)) } + for _, link := range t.Roles { + hash.Write([]byte(link.ID)) + } + for _, srvid := range t.ServiceIdentities { srvid.AddToHash(hash) } @@ -395,11 +432,11 @@ func (t *ACLToken) EstimateSize() int { for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } + for _, link := range t.Roles { + size += len(link.ID) + len(link.Name) + } for _, srvid := range t.ServiceIdentities { - size += len(srvid.ServiceName) - for _, dc := range srvid.Datacenters { - size += len(dc) - } + size += srvid.EstimateSize() } return size } @@ -411,6 +448,7 @@ type ACLTokenListStub struct { AccessorID string Description string Policies []ACLTokenPolicyLink `json:",omitempty"` + Roles []ACLTokenRoleLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool ExpirationTime *time.Time `json:",omitempty"` @@ -428,6 +466,7 @@ func (token *ACLToken) Stub() *ACLTokenListStub { AccessorID: token.AccessorID, Description: token.Description, Policies: token.Policies, + Roles: token.Roles, ServiceIdentities: token.ServiceIdentities, Local: token.Local, ExpirationTime: token.ExpirationTime, @@ -650,14 +689,160 @@ func (policies ACLPolicies) Merge(cache *ACLCaches, sentinel sentinel.Evaluator) return acl.MergePolicies(parsed), nil } +type ACLRoles []*ACLRole + +// HashKey returns a consistent hash for a set of roles. +func (roles ACLRoles) HashKey() string { + cacheKeyHash, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + for _, role := range roles { + cacheKeyHash.Write([]byte(role.ID)) + // including the modify index prevents a role set from being + // cached if one of the roles has changed + binary.Write(cacheKeyHash, binary.BigEndian, role.ModifyIndex) + } + return fmt.Sprintf("%x", cacheKeyHash.Sum(nil)) +} + +func (roles ACLRoles) Sort() { + sort.Slice(roles, func(i, j int) bool { + return roles[i].ID < roles[j].ID + }) +} + +type ACLRolePolicyLink struct { + ID string + Name string `hash:"ignore"` +} + +type ACLRole struct { + // ID is the internal UUID associated with the role + ID string + + // Name is the unique name to reference the role by. + // + // Validated with structs.isValidRoleName() + Name string + + // Description is a human readable description (Optional) + Description string + + // List of policy links. + // Note this is the list of IDs and not the names. Prior to role creation + // the list of policy names gets validated and the policy IDs get stored herein + Policies []ACLRolePolicyLink `json:",omitempty"` + + // List of services to generate synthetic policies for. + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + + // Hash of the contents of the role + // This does not take into account the ID (which is immutable) + // nor the raft metadata. + // + // This is needed mainly for replication purposes. When replicating from + // one DC to another keeping the content Hash will allow us to avoid + // unnecessary calls to the authoritative DC + Hash []byte + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + +func (r *ACLRole) Clone() *ACLRole { + r2 := *r + r2.Policies = nil + r2.ServiceIdentities = nil + + if len(r.Policies) > 0 { + r2.Policies = make([]ACLRolePolicyLink, len(r.Policies)) + copy(r2.Policies, r.Policies) + } + if len(r.ServiceIdentities) > 0 { + r2.ServiceIdentities = make([]*ACLServiceIdentity, len(r.ServiceIdentities)) + for i, s := range r.ServiceIdentities { + r2.ServiceIdentities[i] = s.Clone() + } + } + return &r2 +} + +func (r *ACLRole) SetHash(force bool) []byte { + if force || r.Hash == nil { + // Initialize a 256bit Blake2 hash (32 bytes) + hash, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + + // Any non-immutable "content" fields should be involved with the + // overall hash. The ID is immutable which is why it isn't here. The + // raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a role + // has changed and should be updated locally. + + // Write all the user set fields + hash.Write([]byte(r.Name)) + hash.Write([]byte(r.Description)) + for _, link := range r.Policies { + hash.Write([]byte(link.ID)) + } + for _, srvid := range r.ServiceIdentities { + srvid.AddToHash(hash) + } + + // Finalize the hash + hashVal := hash.Sum(nil) + + // Set and return the hash + r.Hash = hashVal + } + return r.Hash +} + +func (r *ACLRole) EstimateSize() int { + // This is just an estimate. There is other data structure overhead + // pointers etc that this does not account for. + + // 60 = 36 (uuid) + 16 (RaftIndex) + 8 (Hash) + size := 60 + len(r.Name) + len(r.Description) + for _, link := range r.Policies { + size += len(link.ID) + len(link.Name) + } + for _, srvid := range r.ServiceIdentities { + size += srvid.EstimateSize() + } + + return size +} + type ACLReplicationType string const ( ACLReplicateLegacy ACLReplicationType = "legacy" ACLReplicatePolicies ACLReplicationType = "policies" + ACLReplicateRoles ACLReplicationType = "roles" ACLReplicateTokens ACLReplicationType = "tokens" ) +func (t ACLReplicationType) SingularNoun() string { + switch t { + case ACLReplicateLegacy: + return "legacy" + case ACLReplicatePolicies: + return "policy" + case ACLReplicateRoles: + return "role" + case ACLReplicateTokens: + return "token" + default: + return "" + } +} + // ACLReplicationStatus provides information about the health of the ACL // replication system. type ACLReplicationStatus struct { @@ -666,6 +851,7 @@ type ACLReplicationStatus struct { SourceDatacenter string ReplicationType ACLReplicationType ReplicatedIndex uint64 + ReplicatedRoleIndex uint64 ReplicatedTokenIndex uint64 LastSuccess time.Time LastError time.Time @@ -711,6 +897,7 @@ type ACLTokenListRequest struct { IncludeLocal bool // Whether local tokens should be included IncludeGlobal bool // Whether global tokens should be included Policy string // Policy filter + Role string // Role filter Datacenter string // The datacenter to perform the request within QueryOptions } @@ -878,3 +1065,92 @@ func cloneStringSlice(s []string) []string { copy(out, s) return out } + +// ACLRoleSetRequest is used at the RPC layer for creation and update requests +type ACLRoleSetRequest struct { + Role ACLRole // The policy to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLRoleSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleDeleteRequest is used at the RPC layer deletion requests +type ACLRoleDeleteRequest struct { + RoleID string // id of the role to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLRoleDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleGetRequest is used at the RPC layer to perform role read operations +type ACLRoleGetRequest struct { + RoleID string // id used for the role lookup (one of RoleID or RoleName is allowed) + RoleName string // name used for the role lookup (one of RoleID or RoleName is allowed) + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleListRequest is used at the RPC layer to request a listing of roles +type ACLRoleListRequest struct { + Policy string // Policy filter + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLRoleListResponse struct { + Roles ACLRoles + QueryMeta +} + +// ACLRoleBatchGetRequest is used at the RPC layer to request a subset of +// the roles associated with the token used for retrieval +type ACLRoleBatchGetRequest struct { + RoleIDs []string // List of role ids to fetch + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleBatchGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleResponse returns a single role + metadata +type ACLRoleResponse struct { + Role *ACLRole + QueryMeta +} + +type ACLRoleBatchResponse struct { + Roles []*ACLRole + QueryMeta +} + +// ACLRoleBatchSetRequest is used at the Raft layer for batching +// multiple role creations and updates +// +// This is particularly useful during replication +type ACLRoleBatchSetRequest struct { + Roles ACLRoles +} + +// ACLRoleBatchDeleteRequest is used at the Raft layer for batching +// multiple role deletions +// +// This is particularly useful during replication +type ACLRoleBatchDeleteRequest struct { + RoleIDs []string +} diff --git a/agent/structs/acl_cache.go b/agent/structs/acl_cache.go index 9e7df64053..8a4f494194 100644 --- a/agent/structs/acl_cache.go +++ b/agent/structs/acl_cache.go @@ -4,7 +4,7 @@ import ( "time" "github.com/hashicorp/consul/acl" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" ) type ACLCachesConfig struct { @@ -12,6 +12,7 @@ type ACLCachesConfig struct { Policies int ParsedPolicies int Authorizers int + Roles int } type ACLCaches struct { @@ -19,6 +20,7 @@ type ACLCaches struct { parsedPolicies *lru.TwoQueueCache // policy content hash -> acl.Policy policies *lru.TwoQueueCache // policy ID -> ACLPolicy authorizers *lru.TwoQueueCache // token secret -> acl.Authorizer + roles *lru.TwoQueueCache // role ID -> ACLRole } type IdentityCacheEntry struct { @@ -58,6 +60,16 @@ func (e *AuthorizerCacheEntry) Age() time.Duration { return time.Since(e.CacheTime) } +// RoleCacheEntry is the payload for by by-id and by-name caches. +type RoleCacheEntry struct { + Role *ACLRole + CacheTime time.Time +} + +func (e *RoleCacheEntry) Age() time.Duration { + return time.Since(e.CacheTime) +} + func NewACLCaches(config *ACLCachesConfig) (*ACLCaches, error) { cache := &ACLCaches{} @@ -97,6 +109,15 @@ func NewACLCaches(config *ACLCachesConfig) (*ACLCaches, error) { cache.authorizers = authCache } + if config != nil && config.Roles > 0 { + roleCache, err := lru.New2Q(config.Roles) + if err != nil { + return nil, err + } + + cache.roles = roleCache + } + return cache, nil } @@ -152,6 +173,19 @@ func (c *ACLCaches) GetAuthorizer(id string) *AuthorizerCacheEntry { return nil } +// GetRoleByID fetches a role from the cache by id and returns it +func (c *ACLCaches) GetRole(roleID string) *RoleCacheEntry { + if c == nil || c.roles == nil { + return nil + } + + if raw, ok := c.roles.Get(roleID); ok { + return raw.(*RoleCacheEntry) + } + + return nil +} + // PutIdentity adds a new identity to the cache func (c *ACLCaches) PutIdentity(id string, ident ACLIdentity) { if c == nil || c.identities == nil { @@ -193,6 +227,12 @@ func (c *ACLCaches) PutAuthorizerWithTTL(id string, authorizer acl.Authorizer, t c.authorizers.Add(id, &AuthorizerCacheEntry{Authorizer: authorizer, CacheTime: time.Now(), TTL: ttl}) } +func (c *ACLCaches) PutRole(roleID string, role *ACLRole) { + if c != nil && c.roles != nil { + c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) + } +} + func (c *ACLCaches) RemoveIdentity(id string) { if c != nil && c.identities != nil { c.identities.Remove(id) @@ -205,6 +245,12 @@ func (c *ACLCaches) RemovePolicy(policyID string) { } } +func (c *ACLCaches) RemoveRole(roleID string) { + if c != nil && c.roles != nil && roleID != "" { + c.roles.Remove(roleID) + } +} + func (c *ACLCaches) Purge() { if c != nil { if c.identities != nil { @@ -219,5 +265,8 @@ func (c *ACLCaches) Purge() { if c.authorizers != nil { c.authorizers.Purge() } + if c.roles != nil { + c.roles.Purge() + } } } diff --git a/agent/structs/acl_cache_test.go b/agent/structs/acl_cache_test.go index 471dc408a2..dbbf717c88 100644 --- a/agent/structs/acl_cache_test.go +++ b/agent/structs/acl_cache_test.go @@ -16,7 +16,7 @@ func TestStructs_ACLCaches(t *testing.T) { t.Run("Valid Sizes", func(t *testing.T) { t.Parallel() // 1 isn't valid due to a bug in golang-lru library - config := ACLCachesConfig{2, 2, 2, 2} + config := ACLCachesConfig{2, 2, 2, 2, 2} cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -30,7 +30,7 @@ func TestStructs_ACLCaches(t *testing.T) { t.Run("Zero Sizes", func(t *testing.T) { t.Parallel() // 1 isn't valid due to a bug in golang-lru library - config := ACLCachesConfig{0, 0, 0, 0} + config := ACLCachesConfig{0, 0, 0, 0, 0} cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -102,4 +102,19 @@ func TestStructs_ACLCaches(t *testing.T) { require.NotNil(t, entry.Authorizer) require.True(t, entry.Authorizer == acl.DenyAll()) }) + + t.Run("Roles", func(t *testing.T) { + t.Parallel() + // 1 isn't valid due to a bug in golang-lru library + config := ACLCachesConfig{Roles: 4} + + cache, err := NewACLCaches(&config) + require.NoError(t, err) + require.NotNil(t, cache) + + cache.PutRole("foo", &ACLRole{}) + entry := cache.GetRole("foo") + require.NotNil(t, entry) + require.NotNil(t, entry.Role) + }) } diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index bfc585a55d..0d69c9886d 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -514,6 +514,7 @@ func TestStructs_ACLPolicies_resolveWithCache(t *testing.T) { Policies: 0, ParsedPolicies: 4, Authorizers: 0, + Roles: 0, } cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -606,6 +607,7 @@ func TestStructs_ACLPolicies_Compile(t *testing.T) { Policies: 0, ParsedPolicies: 4, Authorizers: 2, + Roles: 0, } cache, err := NewACLCaches(&config) require.NoError(t, err) diff --git a/agent/structs/structs.go b/agent/structs/structs.go index c1acf7a5b4..2ad0a4e97c 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -56,6 +56,8 @@ const ( ACLPolicyDeleteRequestType = 20 ConnectCALeafRequestType = 21 ConfigEntryRequestType = 22 + ACLRoleSetRequestType = 23 + ACLRoleDeleteRequestType = 24 ) const ( diff --git a/api/acl.go b/api/acl.go index e920c46d6e..2713d0ddc9 100644 --- a/api/acl.go +++ b/api/acl.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "io/ioutil" + "net/url" "time" ) @@ -19,6 +20,10 @@ type ACLTokenPolicyLink struct { ID string Name string } +type ACLTokenRoleLink struct { + ID string + Name string +} // ACLToken represents an ACL Token type ACLToken struct { @@ -28,6 +33,7 @@ type ACLToken struct { SecretID string Description string Policies []*ACLTokenPolicyLink `json:",omitempty"` + Roles []*ACLTokenRoleLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool ExpirationTTL time.Duration `json:",omitempty"` @@ -46,6 +52,7 @@ type ACLTokenListEntry struct { AccessorID string Description string Policies []*ACLTokenPolicyLink `json:",omitempty"` + Roles []*ACLTokenRoleLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool ExpirationTime *time.Time `json:",omitempty"` @@ -72,6 +79,7 @@ type ACLReplicationStatus struct { SourceDatacenter string ReplicationType string ReplicatedIndex uint64 + ReplicatedRoleIndex uint64 ReplicatedTokenIndex uint64 LastSuccess time.Time LastError time.Time @@ -107,6 +115,23 @@ type ACLPolicyListEntry struct { ModifyIndex uint64 } +type ACLRolePolicyLink struct { + ID string + Name string +} + +// ACLRole represents an ACL Role. +type ACLRole struct { + ID string + Name string + Description string + Policies []*ACLRolePolicyLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Hash []byte + CreateIndex uint64 + ModifyIndex uint64 +} + // ACL can be used to query the ACL endpoints type ACL struct { c *Client @@ -599,3 +624,142 @@ func (a *ACL) RulesTranslateToken(tokenID string) (string, error) { return string(ruleBytes), nil } + +// RoleCreate will create a new role. It is not allowed for the role parameters +// ID field to be set as this will be generated by Consul while processing the request. +func (a *ACL) RoleCreate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { + if role.ID != "" { + return nil, nil, fmt.Errorf("Cannot specify an ID in Role Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/role") + r.setWriteOptions(q) + r.obj = role + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// RoleUpdate updates a role. The ID field of the role parameter must be set to an +// existing role ID +func (a *ACL) RoleUpdate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { + if role.ID == "" { + return nil, nil, fmt.Errorf("Must specify an ID in Role Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/role/"+role.ID) + r.setWriteOptions(q) + r.obj = role + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// RoleDelete deletes a role given its ID. +func (a *ACL) RoleDelete(roleID string, q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("DELETE", "/v1/acl/role/"+roleID) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// RoleRead retrieves the role details (by ID). Returns nil if not found. +func (a *ACL) RoleRead(roleID string, q *QueryOptions) (*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/role/"+roleID) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// RoleReadByName retrieves the role details (by name). Returns nil if not found. +func (a *ACL) RoleReadByName(roleName string, q *QueryOptions) (*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/role/name/"+url.QueryEscape(roleName)) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// RoleList retrieves a listing of all roles. The listing does not include some +// metadata for the role as those should be retrieved by subsequent calls to +// RoleRead. +func (a *ACL) RoleList(q *QueryOptions) ([]*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/roles") + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLRole + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} diff --git a/api/api.go b/api/api.go index ffa2ce24df..e8370d0441 100644 --- a/api/api.go +++ b/api/api.go @@ -897,10 +897,7 @@ func requireOK(d time.Duration, resp *http.Response, e error) (time.Duration, *h return d, nil, e } if resp.StatusCode != 200 { - var buf bytes.Buffer - io.Copy(&buf, resp.Body) - resp.Body.Close() - return d, nil, fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes()) + return d, nil, generateUnexpectedResponseCodeError(resp) } return d, resp, nil } @@ -912,3 +909,30 @@ func (req *request) filterQuery(filter string) { req.params.Set("filter", filter) } + +// generateUnexpectedResponseCodeError consumes the rest of the body, closes +// the body stream and generates an error indicating the status code was +// unexpected. +func generateUnexpectedResponseCodeError(resp *http.Response) error { + var buf bytes.Buffer + io.Copy(&buf, resp.Body) + resp.Body.Close() + return fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes()) +} + +func requireNotFoundOrOK(d time.Duration, resp *http.Response, e error) (bool, time.Duration, *http.Response, error) { + if e != nil { + if resp != nil { + resp.Body.Close() + } + return false, d, nil, e + } + switch resp.StatusCode { + case 200: + return true, d, resp, nil + case 404: + return false, d, resp, nil + default: + return false, d, nil, generateUnexpectedResponseCodeError(resp) + } +} diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 843a91487a..928f8beb55 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -27,6 +27,10 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { for _, policy := range token.Policies { ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } ui.Info(fmt.Sprintf("Service Identities:")) for _, svcid := range token.ServiceIdentities { if len(svcid.Datacenters) > 0 { @@ -59,6 +63,10 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) for _, policy := range token.Policies { ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } ui.Info(fmt.Sprintf("Service Identities:")) for _, svcid := range token.ServiceIdentities { if len(svcid.Datacenters) > 0 { @@ -95,6 +103,52 @@ func PrintPolicyListEntry(policy *api.ACLPolicyListEntry, ui cli.Ui, showMeta bo } } +func PrintRole(role *api.ACLRole, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("ID: %s", role.ID)) + ui.Info(fmt.Sprintf("Name: %s", role.Name)) + ui.Info(fmt.Sprintf("Description: %s", role.Description)) + if showMeta { + ui.Info(fmt.Sprintf("Hash: %x", role.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", role.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", role.ModifyIndex)) + } + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } +} + +func PrintRoleListEntry(role *api.ACLRole, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", role.Name)) + ui.Info(fmt.Sprintf(" ID: %s", role.ID)) + ui.Info(fmt.Sprintf(" Description: %s", role.Description)) + if showMeta { + ui.Info(fmt.Sprintf(" Hash: %x", role.Hash)) + ui.Info(fmt.Sprintf(" Create Index: %d", role.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", role.ModifyIndex)) + } + ui.Info(fmt.Sprintf(" Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + ui.Info(fmt.Sprintf(" Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } +} + func GetTokenIDFromPartial(client *api.Client, partialID string) (string, error) { if partialID == "anonymous" { return structs.ACLTokenAnonymousID, nil @@ -208,6 +262,53 @@ func GetRulesFromLegacyToken(client *api.Client, tokenID string, isSecret bool) return token.Rules, nil } +func GetRoleIDFromPartial(client *api.Client, partialID string) (string, error) { + // the full UUID string was given + if len(partialID) == 36 { + return partialID, nil + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + return "", err + } + + roleID := "" + for _, role := range roles { + if strings.HasPrefix(role.ID, partialID) { + if roleID != "" { + return "", fmt.Errorf("Partial role ID is not unique") + } + roleID = role.ID + } + } + + if roleID == "" { + return "", fmt.Errorf("No such role ID with prefix: %s", partialID) + } + + return roleID, nil +} + +func GetRoleIDByName(client *api.Client, name string) (string, error) { + if name == "" { + return "", fmt.Errorf("No name specified") + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + return "", err + } + + for _, role := range roles { + if role.Name == name { + return role.ID, nil + } + } + + return "", fmt.Errorf("No such role with name %s", name) +} + func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) { var out []*api.ACLServiceIdentity for _, svcidRaw := range serviceIdents { diff --git a/command/acl/role/create/role_create.go b/command/acl/role/create/role_create.go new file mode 100644 index 0000000000..8ab869c00c --- /dev/null +++ b/command/acl/role/create/role_create.go @@ -0,0 +1,134 @@ +package rolecreate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + aclhelpers "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + description string + policyIDs []string + policyNames []string + serviceIdents []string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.name, "name", "", "The new role's name. This flag is required.") + c.flags.StringVar(&c.description, "description", "", "A description of the role") + c.flags.Var((*flags.AppendSliceValue)(&c.policyIDs), "policy-id", "ID of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this role. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Missing require '-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + if len(c.policyNames) == 0 && len(c.policyIDs) == 0 && len(c.serviceIdents) == 0 { + c.UI.Error(fmt.Sprintf("Cannot create a role without specifying -policy-name, -policy-id, or -service-identity at least once")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + newRole := &api.ACLRole{ + Name: c.name, + Description: c.description, + } + + for _, policyName := range c.policyNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + newRole.Policies = append(newRole.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + newRole.Policies = append(newRole.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + newRole.ServiceIdentities = parsedServiceIdents + + role, _, err := client.ACL().RoleCreate(newRole, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new role: %v", err)) + return 1 + } + + aclhelpers.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Role" + +const help = ` +Usage: consul acl role create -name NAME [options] + + Create a new role: + + $ consul acl role create -name "new-role" \ + -description "This is an example role" \ + -policy-id b52fc3de-5 \ + -policy-name "acl-replication" \ + -service-identity "web" \ + -service-identity "db:east,west" +` diff --git a/command/acl/role/create/role_create_test.go b/command/acl/role/create/role_create_test.go new file mode 100644 index 0000000000..d592aba99c --- /dev/null +++ b/command/acl/role/create/role_create_test.go @@ -0,0 +1,116 @@ +package rolecreate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + ui := cli.NewMockUi() + cmd := New(ui) + + // Create a policy + client := a.Client() + + policy, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create with policy by name + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-policy-by-name", + "-description=test-role", + "-policy-name=" + policy.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with policy by id + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-policy-by-id", + "-description=test-role", + "-policy-id=" + policy.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with service identity + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-service-identity", + "-description=test-role", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with service identity scoped to 2 DCs + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-service-identity-in-2-dcs", + "-description=test-role", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } +} diff --git a/command/acl/role/delete/role_delete.go b/command/acl/role/delete/role_delete.go new file mode 100644 index 0000000000..543133ae1b --- /dev/null +++ b/command/acl/role/delete/role_delete.go @@ -0,0 +1,91 @@ +package roledelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to delete. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple role IDs") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + + if _, err := client.ACL().RoleDelete(roleID, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting role %q: %v", roleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Role %q deleted successfully", roleID)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Role" +const help = ` +Usage: consul acl role delete [options] -id ROLE + + Deletes an ACL role by providing the ID or a unique ID prefix. + + Delete by prefix: + + $ consul acl role delete -id b6b85 + + Delete by full ID: + + $ consul acl role delete -id b6b856da-5193-4e78-845a-7d61ca8371ba + +` diff --git a/command/acl/role/delete/role_delete_test.go b/command/acl/role/delete/role_delete_test.go new file mode 100644 index 0000000000..e6523941f6 --- /dev/null +++ b/command/acl/role/delete/role_delete_test.go @@ -0,0 +1,141 @@ +package roledelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("delete works", func(t *testing.T) { + // Create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-for-id-delete", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, role.ID) + + role, _, err = client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, role) + }) + + t.Run("delete works via prefixes", func(t *testing.T) { + // Create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-for-id-prefix-delete", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, role.ID) + + role, _, err = client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, role) + }) +} diff --git a/command/acl/role/list/role_list.go b/command/acl/role/list/role_list.go new file mode 100644 index 0000000000..95a3741890 --- /dev/null +++ b/command/acl/role/list/role_list.go @@ -0,0 +1,79 @@ +package rolelist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that policy metadata such "+ + "as the content hash and raft indices should be shown for each entry") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the role list: %v", err)) + return 1 + } + + for _, role := range roles { + acl.PrintRoleListEntry(role, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Roles" +const help = ` +Usage: consul acl role list [options] + + Lists all the ACL roles. + + Example: + + $ consul acl role list +` diff --git a/command/acl/role/list/role_list_test.go b/command/acl/role/list/role_list_test.go new file mode 100644 index 0000000000..5da280f3d3 --- /dev/null +++ b/command/acl/role/list/role_list_test.go @@ -0,0 +1,83 @@ +package rolelist + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleListCommand(t *testing.T) { + t.Parallel() + require := require.New(t) + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + ui := cli.NewMockUi() + cmd := New(ui) + + var roleIDs []string + + // Create a couple roles to list + client := a.Client() + svcids := []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ServiceName: "fake"}, + } + for i := 0; i < 5; i++ { + name := fmt.Sprintf("test-role-%d", i) + + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{Name: name, ServiceIdentities: svcids}, + &api.WriteOptions{Token: "root"}, + ) + roleIDs = append(roleIDs, role.ID) + + require.NoError(err) + } + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(code, 0) + require.Empty(ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range roleIDs { + require.Contains(output, fmt.Sprintf("test-role-%d", i)) + require.Contains(output, v) + } +} diff --git a/command/acl/role/read/role_read.go b/command/acl/role/read/role_read.go new file mode 100644 index 0000000000..fb51b8d099 --- /dev/null +++ b/command/acl/role/read/role_read.go @@ -0,0 +1,115 @@ +package roleread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string + roleName string + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to read. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple policy IDs") + c.flags.StringVar(&c.roleName, "name", "", "The name of the role to read.") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" && c.roleName == "" { + c.UI.Error(fmt.Sprintf("Must specify either the -id or -name parameters")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + var role *api.ACLRole + + if c.roleID != "" { + roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + role, _, err = client.ACL().RoleRead(roleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading role %q: %v", roleID, err)) + return 1 + } else if role == nil { + c.UI.Error(fmt.Sprintf("Role not found with ID %q", roleID)) + return 1 + } + + } else { + role, _, err = client.ACL().RoleReadByName(c.roleName, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading role %q: %v", c.roleName, err)) + return 1 + } else if role == nil { + c.UI.Error(fmt.Sprintf("Role not found with name %q", c.roleName)) + return 1 + } + } + + acl.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Role" +const help = ` +Usage: consul acl role read [options] ROLE + + This command will retrieve and print out the details + of a single role. + + Read: + + $ consul acl role read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba + + Read by name: + + $ consul acl role read -name my-policy + +` diff --git a/command/acl/role/read/role_read_test.go b/command/acl/role/read/role_read_test.go new file mode 100644 index 0000000000..f0f7b45dd7 --- /dev/null +++ b/command/acl/role/read/role_read_test.go @@ -0,0 +1,194 @@ +package roleread + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("id or name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify either the -id or -name parameters") + }) + + t.Run("read by id not found", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("read by name not found", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=blah", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with name") + }) + + t.Run("read by id", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-id", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) + + t.Run("read by id prefix", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-id-prefix", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) + + t.Run("read by name", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-name", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + role.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) +} diff --git a/command/acl/role/role.go b/command/acl/role/role.go new file mode 100644 index 0000000000..87bf01b3d7 --- /dev/null +++ b/command/acl/role/role.go @@ -0,0 +1,56 @@ +package role + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Roles" +const help = ` +Usage: consul acl role [options] [args] + + This command has subcommands for managing Consul's ACL Roles. + Here are some simple examples, and more detailed examples are available + in the subcommands or the documentation. + + Create a new ACL Role: + + $ consul acl role create -name "new-role" \ + -description "This is an example role" \ + -policy-id 06acc965 + List all roles: + + $ consul acl role list + + Update a role: + + $ consul acl role update -name "other-role" -datacenter "dc1" + + Read a role: + + $ consul acl role read -id 0479e93e-091c-4475-9b06-79a004765c24 + + Delete a role + + $ consul acl role delete -name "my-role" + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/role/update/role_update.go b/command/acl/role/update/role_update.go new file mode 100644 index 0000000000..6327755ccb --- /dev/null +++ b/command/acl/role/update/role_update.go @@ -0,0 +1,225 @@ +package roleupdate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string + name string + description string + policyIDs []string + policyNames []string + serviceIdents []string + + noMerge bool + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to update. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple role IDs") + c.flags.StringVar(&c.name, "name", "", "The role name.") + c.flags.StringVar(&c.description, "description", "", "A description of the role") + c.flags.Var((*flags.AppendSliceValue)(&c.policyIDs), "policy-id", "ID of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this role. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") + c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current role "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the role ID which is immutable.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" { + c.UI.Error(fmt.Sprintf("Cannot update a role without specifying the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Read the current role in both cases so we can fail better if not found. + currentRole, _, err := client.ACL().RoleRead(roleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current role: %v", err)) + return 1 + } else if currentRole == nil { + c.UI.Error(fmt.Sprintf("Role not found with ID %q", roleID)) + return 1 + } + + var role *api.ACLRole + if c.noMerge { + role = &api.ACLRole{ + ID: c.roleID, + Name: c.name, + Description: c.description, + ServiceIdentities: parsedServiceIdents, + } + + for _, policyName := range c.policyNames { + // We could resolve names to IDs here but there isn't any reason + // why its would be better than allowing the agent to do it. + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + } else { + role = currentRole + + if c.name != "" { + role.Name = c.name + } + if c.description != "" { + role.Description = c.description + } + + for _, policyName := range c.policyNames { + found := false + for _, link := range role.Policies { + if link.Name == policyName { + found = true + break + } + } + + if !found { + // We could resolve names to IDs here but there isn't any + // reason why its would be better than allowing the agent to do + // it. + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + found := false + + for _, link := range role.Policies { + if link.ID == policyID { + found = true + break + } + } + + if !found { + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + } + + for _, svcid := range parsedServiceIdents { + found := -1 + for i, link := range role.ServiceIdentities { + if link.ServiceName == svcid.ServiceName { + found = i + break + } + } + + if found != -1 { + role.ServiceIdentities[found] = svcid + } else { + role.ServiceIdentities = append(role.ServiceIdentities, svcid) + } + } + } + + role, _, err = client.ACL().RoleUpdate(role, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating role %q: %v", roleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Role updated successfully")) + acl.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Update an ACL Role" +const help = ` +Usage: consul acl role update [options] + + Updates a role. By default it will merge the role information with its + current state so that you do not have to provide all parameters. This + behavior can be disabled by passing -no-merge. + + Rename the Role: + + $ consul acl role update -id abcd -name "better-name" + + Update all editable fields of the role: + + $ consul acl role update -id abcd \ + -name "better-name" \ + -description "replication" \ + -policy-name "token-replication" \ + -service-identity "web" +` diff --git a/command/acl/role/update/role_update_test.go b/command/acl/role/update/role_update_test.go new file mode 100644 index 0000000000..c9094e7286 --- /dev/null +++ b/command/acl/role/update/role_update_test.go @@ -0,0 +1,398 @@ +package roleupdate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + uuid "github.com/hashicorp/go-uuid" +) + +func TestRoleUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // Create 2 policies + policy1, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy1"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy2, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy2"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("update a role that does not exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-policy-name=" + policy1.Name, + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("update with policy by name", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-policy-name=" + policy1.Name, + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with policy by id", func(t *testing.T) { + // also update with no description shouldn't delete the current + // description + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-policy-id=" + policy2.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with service identity", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 2) + }) + + t.Run("update with service identity scoped to 2 DCs", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 3) + }) +} + +func TestRoleUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // Create 3 policies + policy1, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy1"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy2, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy2"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy3, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy3"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create a role + createRole := func(t *testing.T) *api.ACLRole { + roleUnq, err := uuid.GenerateUUID() + require.NoError(t, err) + + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-" + roleUnq, + Description: "original description", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + Policies: []*api.ACLRolePolicyLink{ + &api.ACLRolePolicyLink{ + ID: policy3.ID, + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return role + } + + t.Run("update a role that does not exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-policy-name=" + policy1.Name, + "-no-merge", + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("update with policy by name", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-policy-name=" + policy1.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 0) + }) + + t.Run("update with policy by id", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-policy-id=" + policy2.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 0) + }) + + t.Run("update with service identity", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 0) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with service identity scoped to 2 DCs", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 0) + require.Len(t, role.ServiceIdentities, 1) + }) +} diff --git a/command/acl/token/create/token_create.go b/command/acl/token/create/token_create.go index 0760c18f07..5d55563919 100644 --- a/command/acl/token/create/token_create.go +++ b/command/acl/token/create/token_create.go @@ -25,6 +25,8 @@ type cmd struct { policyIDs []string policyNames []string + roleIDs []string + roleNames []string serviceIdents []string expirationTTL time.Duration description string @@ -42,6 +44,10 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleIDs), "role-id", "ID of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleNames), "role-name", "Name of a "+ + "role to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ "service identity to use for this token. May be specified multiple times. Format is "+ "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") @@ -59,8 +65,9 @@ func (c *cmd) Run(args []string) int { } if len(c.policyNames) == 0 && len(c.policyIDs) == 0 && + len(c.roleNames) == 0 && len(c.roleIDs) == 0 && len(c.serviceIdents) == 0 { - c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name, -policy-id, or -service-identity at least once")) + c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name, -policy-id, -role-name, -role-id, or -service-identity at least once")) return 1 } @@ -100,6 +107,21 @@ func (c *cmd) Run(args []string) int { newToken.Policies = append(newToken.Policies, &api.ACLTokenPolicyLink{ID: policyID}) } + for _, roleName := range c.roleNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + newToken.Roles = append(newToken.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + newToken.Roles = append(newToken.Roles, &api.ACLTokenRoleLink{ID: roleID}) + } + token, _, err := client.ACL().TokenCreate(newToken, nil) if err != nil { c.UI.Error(fmt.Sprintf("Failed to create new token: %v", err)) @@ -130,8 +152,9 @@ Usage: consul acl token create [options] $ consul acl token create -description "Replication token" \ -policy-id b52fc3de-5 \ - -policy-name "acl-replication" -policy-name "acl-replication" \ + -role-id c630d4ef-6 \ + -role-name "db-updater" \ -service-identity "web" \ -service-identity "db:east,west" ` diff --git a/command/acl/token/update/token_update.go b/command/acl/token/update/token_update.go index c2ad084688..72663f8193 100644 --- a/command/acl/token/update/token_update.go +++ b/command/acl/token/update/token_update.go @@ -25,9 +25,12 @@ type cmd struct { tokenID string policyIDs []string policyNames []string + roleIDs []string + roleNames []string serviceIdents []string description string mergePolicies bool + mergeRoles bool mergeServiceIdents bool showMeta bool upgradeLegacy bool @@ -39,6 +42,8 @@ func (c *cmd) init() { "as the content hash and raft indices should be shown for each entry") c.flags.BoolVar(&c.mergePolicies, "merge-policies", false, "Merge the new policies "+ "with the existing policies") + c.flags.BoolVar(&c.mergeRoles, "merge-roles", false, "Merge the new roles "+ + "with the existing roles") c.flags.BoolVar(&c.mergeServiceIdents, "merge-service-identities", false, "Merge the new service identities "+ "with the existing service identities") c.flags.StringVar(&c.tokenID, "id", "", "The Accessor ID of the token to read. "+ @@ -49,6 +54,10 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleIDs), "role-id", "ID of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleNames), "role-name", "Name of a "+ + "role to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ "service identity to use for this token. May be specified multiple times. Format is "+ "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") @@ -175,6 +184,61 @@ func (c *cmd) Run(args []string) int { } } + if c.mergeRoles { + for _, roleName := range c.roleNames { + found := false + for _, link := range token.Roles { + if link.Name == roleName { + found = true + break + } + } + + if !found { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + found := false + + for _, link := range token.Roles { + if link.ID == roleID { + found = true + break + } + } + + if !found { + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleID}) + } + } + } else { + token.Roles = nil + + for _, roleName := range c.roleNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{ID: roleID}) + } + } + if c.mergeServiceIdents { for _, svcid := range parsedServiceIdents { found := -1 @@ -229,5 +293,6 @@ Usage: consul acl token update [options] $ consul acl token update -id abcd \ -description "replication" \ - -policy-name "token-replication" + -policy-name "token-replication" \ + -role-name "db-updater" ` diff --git a/command/commands_oss.go b/command/commands_oss.go index 91fee14397..f92fcf8bba 100644 --- a/command/commands_oss.go +++ b/command/commands_oss.go @@ -10,6 +10,12 @@ import ( aclplist "github.com/hashicorp/consul/command/acl/policy/list" aclpread "github.com/hashicorp/consul/command/acl/policy/read" aclpupdate "github.com/hashicorp/consul/command/acl/policy/update" + aclrole "github.com/hashicorp/consul/command/acl/role" + aclrcreate "github.com/hashicorp/consul/command/acl/role/create" + aclrdelete "github.com/hashicorp/consul/command/acl/role/delete" + aclrlist "github.com/hashicorp/consul/command/acl/role/list" + aclrread "github.com/hashicorp/consul/command/acl/role/read" + aclrupdate "github.com/hashicorp/consul/command/acl/role/update" aclrules "github.com/hashicorp/consul/command/acl/rules" acltoken "github.com/hashicorp/consul/command/acl/token" acltclone "github.com/hashicorp/consul/command/acl/token/clone" @@ -106,6 +112,12 @@ func init() { Register("acl token read", func(ui cli.Ui) (cli.Command, error) { return acltread.New(ui), nil }) Register("acl token update", func(ui cli.Ui) (cli.Command, error) { return acltupdate.New(ui), nil }) Register("acl token delete", func(ui cli.Ui) (cli.Command, error) { return acltdelete.New(ui), nil }) + Register("acl role", func(cli.Ui) (cli.Command, error) { return aclrole.New(), nil }) + Register("acl role create", func(ui cli.Ui) (cli.Command, error) { return aclrcreate.New(ui), nil }) + Register("acl role list", func(ui cli.Ui) (cli.Command, error) { return aclrlist.New(ui), nil }) + Register("acl role read", func(ui cli.Ui) (cli.Command, error) { return aclrread.New(ui), nil }) + Register("acl role update", func(ui cli.Ui) (cli.Command, error) { return aclrupdate.New(ui), nil }) + Register("acl role delete", func(ui cli.Ui) (cli.Command, error) { return aclrdelete.New(ui), nil }) Register("agent", func(ui cli.Ui) (cli.Command, error) { return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil }) From e47d7eeddb93bac09b7e5e94231ef5fa17cc112b Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Fri, 26 Apr 2019 12:49:28 -0500 Subject: [PATCH 5/5] acl: adding support for kubernetes auth provider login (#5600) * auth providers * binding rules * auth provider for kubernetes * login/logout --- agent/acl_endpoint.go | 327 +++ agent/acl_endpoint_test.go | 733 ++++- agent/consul/acl.go | 69 +- agent/consul/acl_authmethod.go | 169 ++ agent/consul/acl_authmethod_test.go | 48 + agent/consul/acl_endpoint.go | 668 ++++- agent/consul/acl_endpoint_legacy.go | 2 +- agent/consul/acl_endpoint_test.go | 2499 ++++++++++++++++- agent/consul/acl_replication_legacy.go | 2 +- agent/consul/acl_replication_legacy_test.go | 4 +- agent/consul/acl_replication_test.go | 10 +- agent/consul/acl_replication_types.go | 2 +- agent/consul/acl_server.go | 11 + agent/consul/acl_test.go | 197 +- agent/consul/acl_token_exp_test.go | 2 +- agent/consul/authmethod/authmethods.go | 112 + agent/consul/authmethod/kubeauth/k8s.go | 202 ++ agent/consul/authmethod/kubeauth/k8s_test.go | 144 + agent/consul/authmethod/kubeauth/testing.go | 532 ++++ agent/consul/authmethod/testauth/testing.go | 166 ++ agent/consul/fsm/commands_oss.go | 48 + agent/consul/fsm/snapshot_oss.go | 46 + agent/consul/fsm/snapshot_oss_test.go | 66 +- agent/consul/leader.go | 4 + agent/consul/server.go | 3 + agent/consul/state/acl.go | 593 +++- agent/consul/state/acl_test.go | 1080 ++++++- agent/consul/state/state_store.go | 22 +- agent/consul/util.go | 42 + agent/consul/util_test.go | 131 + agent/http_oss.go | 8 + agent/structs/acl.go | 316 ++- agent/structs/acl_cache.go | 11 +- agent/structs/acl_cache_test.go | 1 + agent/structs/acl_test.go | 3 +- agent/structs/structs.go | 54 +- api/acl.go | 364 ++- api/api.go | 28 + api/api_test.go | 16 +- command/acl/acl_helpers.go | 211 +- command/acl/authmethod/authmethod.go | 64 + .../authmethod/create/authmethod_create.go | 186 ++ .../create/authmethod_create_test.go | 226 ++ .../authmethod/delete/authmethod_delete.go | 82 + .../delete/authmethod_delete_test.go | 131 + .../acl/authmethod/list/authmethod_list.go | 83 + .../authmethod/list/authmethod_list_test.go | 109 + .../acl/authmethod/read/authmethod_read.go | 96 + .../authmethod/read/authmethod_read_test.go | 118 + .../authmethod/update/authmethod_update.go | 220 ++ .../update/authmethod_update_test.go | 647 +++++ command/acl/bindingrule/bindingrule.go | 60 + .../bindingrule/create/bindingrule_create.go | 148 + .../create/bindingrule_create_test.go | 178 ++ .../bindingrule/delete/bindingrule_delete.go | 97 + .../delete/bindingrule_delete_test.go | 187 ++ .../acl/bindingrule/list/bindingrule_list.go | 98 + .../bindingrule/list/bindingrule_list_test.go | 167 ++ .../acl/bindingrule/read/bindingrule_read.go | 108 + .../bindingrule/read/bindingrule_read_test.go | 152 + .../bindingrule/update/bindingrule_update.go | 212 ++ .../update/bindingrule_update_test.go | 768 +++++ command/acl/role/delete/role_delete.go | 15 +- command/acl/role/delete/role_delete_test.go | 4 +- command/commands_oss.go | 28 + command/connect/envoy/envoy.go | 2 +- command/connect/proxy/proxy.go | 2 +- command/flags/http.go | 30 + command/login/login.go | 148 + command/login/login_test.go | 321 +++ command/logout/logout.go | 70 + command/logout/logout_test.go | 299 ++ command/watch/watch.go | 12 +- go.mod | 6 +- go.sum | 10 +- vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go | 77 + .../square/go-jose.v2/.gitcookies.sh.enc | 1 + vendor/gopkg.in/square/go-jose.v2/.gitignore | 7 + vendor/gopkg.in/square/go-jose.v2/.travis.yml | 46 + .../gopkg.in/square/go-jose.v2/BUG-BOUNTY.md | 10 + .../square/go-jose.v2/CONTRIBUTING.md | 14 + vendor/gopkg.in/square/go-jose.v2/LICENSE | 202 ++ vendor/gopkg.in/square/go-jose.v2/README.md | 118 + .../gopkg.in/square/go-jose.v2/asymmetric.go | 592 ++++ .../square/go-jose.v2/cipher/cbc_hmac.go | 196 ++ .../square/go-jose.v2/cipher/concat_kdf.go | 75 + .../square/go-jose.v2/cipher/ecdh_es.go | 62 + .../square/go-jose.v2/cipher/key_wrap.go | 109 + vendor/gopkg.in/square/go-jose.v2/crypter.go | 535 ++++ vendor/gopkg.in/square/go-jose.v2/doc.go | 27 + vendor/gopkg.in/square/go-jose.v2/encoding.go | 179 ++ .../gopkg.in/square/go-jose.v2/json/LICENSE | 27 + .../gopkg.in/square/go-jose.v2/json/README.md | 13 + .../gopkg.in/square/go-jose.v2/json/decode.go | 1183 ++++++++ .../gopkg.in/square/go-jose.v2/json/encode.go | 1197 ++++++++ .../gopkg.in/square/go-jose.v2/json/indent.go | 141 + .../square/go-jose.v2/json/scanner.go | 623 ++++ .../gopkg.in/square/go-jose.v2/json/stream.go | 480 ++++ .../gopkg.in/square/go-jose.v2/json/tags.go | 44 + vendor/gopkg.in/square/go-jose.v2/jwe.go | 294 ++ vendor/gopkg.in/square/go-jose.v2/jwk.go | 608 ++++ vendor/gopkg.in/square/go-jose.v2/jws.go | 321 +++ .../gopkg.in/square/go-jose.v2/jwt/builder.go | 334 +++ .../gopkg.in/square/go-jose.v2/jwt/claims.go | 120 + vendor/gopkg.in/square/go-jose.v2/jwt/doc.go | 22 + .../gopkg.in/square/go-jose.v2/jwt/errors.go | 53 + vendor/gopkg.in/square/go-jose.v2/jwt/jwt.go | 163 ++ .../square/go-jose.v2/jwt/validation.go | 114 + vendor/gopkg.in/square/go-jose.v2/opaque.go | 83 + vendor/gopkg.in/square/go-jose.v2/shared.go | 499 ++++ vendor/gopkg.in/square/go-jose.v2/signing.go | 389 +++ .../gopkg.in/square/go-jose.v2/symmetric.go | 482 ++++ .../apimachinery/pkg/api/errors/errors.go | 24 + .../apimachinery/pkg/apis/meta/v1/types.go | 4 + .../k8s.io/apimachinery/pkg/runtime/codec.go | 20 + .../runtime/serializer/streaming/streaming.go | 2 +- .../apimachinery/pkg/util/runtime/runtime.go | 6 +- vendor/modules.txt | 54 +- 118 files changed, 23163 insertions(+), 417 deletions(-) create mode 100644 agent/consul/acl_authmethod.go create mode 100644 agent/consul/acl_authmethod_test.go create mode 100644 agent/consul/authmethod/authmethods.go create mode 100644 agent/consul/authmethod/kubeauth/k8s.go create mode 100644 agent/consul/authmethod/kubeauth/k8s_test.go create mode 100644 agent/consul/authmethod/kubeauth/testing.go create mode 100644 agent/consul/authmethod/testauth/testing.go create mode 100644 command/acl/authmethod/authmethod.go create mode 100644 command/acl/authmethod/create/authmethod_create.go create mode 100644 command/acl/authmethod/create/authmethod_create_test.go create mode 100644 command/acl/authmethod/delete/authmethod_delete.go create mode 100644 command/acl/authmethod/delete/authmethod_delete_test.go create mode 100644 command/acl/authmethod/list/authmethod_list.go create mode 100644 command/acl/authmethod/list/authmethod_list_test.go create mode 100644 command/acl/authmethod/read/authmethod_read.go create mode 100644 command/acl/authmethod/read/authmethod_read_test.go create mode 100644 command/acl/authmethod/update/authmethod_update.go create mode 100644 command/acl/authmethod/update/authmethod_update_test.go create mode 100644 command/acl/bindingrule/bindingrule.go create mode 100644 command/acl/bindingrule/create/bindingrule_create.go create mode 100644 command/acl/bindingrule/create/bindingrule_create_test.go create mode 100644 command/acl/bindingrule/delete/bindingrule_delete.go create mode 100644 command/acl/bindingrule/delete/bindingrule_delete_test.go create mode 100644 command/acl/bindingrule/list/bindingrule_list.go create mode 100644 command/acl/bindingrule/list/bindingrule_list_test.go create mode 100644 command/acl/bindingrule/read/bindingrule_read.go create mode 100644 command/acl/bindingrule/read/bindingrule_read_test.go create mode 100644 command/acl/bindingrule/update/bindingrule_update.go create mode 100644 command/acl/bindingrule/update/bindingrule_update_test.go create mode 100644 command/login/login.go create mode 100644 command/login/login_test.go create mode 100644 command/logout/logout.go create mode 100644 command/logout/logout_test.go create mode 100644 vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc create mode 100644 vendor/gopkg.in/square/go-jose.v2/.gitignore create mode 100644 vendor/gopkg.in/square/go-jose.v2/.travis.yml create mode 100644 vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/LICENSE create mode 100644 vendor/gopkg.in/square/go-jose.v2/README.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/asymmetric.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/crypter.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/doc.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/encoding.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/LICENSE create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/README.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/decode.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/encode.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/indent.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/scanner.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/stream.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/tags.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwe.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwk.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jws.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/builder.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/claims.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/doc.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/errors.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/jwt.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/validation.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/opaque.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/shared.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/signing.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/symmetric.go diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index cafe6e11c3..12c6b313a2 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -376,6 +376,7 @@ func (s *HTTPServer) ACLTokenList(resp http.ResponseWriter, req *http.Request) ( args.Policy = req.URL.Query().Get("policy") args.Role = req.URL.Query().Get("role") + args.AuthMethod = req.URL.Query().Get("authmethod") var out structs.ACLTokenListResponse defer setMeta(resp, &out.QueryMeta) @@ -701,3 +702,329 @@ func (s *HTTPServer) ACLRoleDelete(resp http.ResponseWriter, req *http.Request, return true, nil } + +func (s *HTTPServer) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLBindingRuleListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + args.AuthMethod = req.URL.Query().Get("authmethod") + + var out structs.ACLBindingRuleListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.BindingRules == nil { + out.BindingRules = make(structs.ACLBindingRules, 0) + } + + return out.BindingRules, nil +} + +func (s *HTTPServer) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLBindingRuleRead + + case "PUT": + fn = s.ACLBindingRuleWrite + + case "DELETE": + fn = s.ACLBindingRuleDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + bindingRuleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/binding-rule/") + if bindingRuleID == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing binding rule ID"} + } + + return fn(resp, req, bindingRuleID) +} + +func (s *HTTPServer) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleGetRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLBindingRuleResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleRead", &args, &out); err != nil { + return nil, err + } + + if out.BindingRule == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + return out.BindingRule, nil +} + +func (s *HTTPServer) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLBindingRuleWrite(resp, req, "") +} + +func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} + } + + if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID { + return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"} + } else if args.BindingRule.ID == "" { + args.BindingRule.ID = bindingRuleID + } + + var out structs.ACLBindingRule + if err := s.agent.RPC("ACL.BindingRuleSet", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.BindingRuleDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLAuthMethodListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.AuthMethods == nil { + out.AuthMethods = make(structs.ACLAuthMethodListStubs, 0) + } + + return out.AuthMethods, nil +} + +func (s *HTTPServer) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLAuthMethodRead + + case "PUT": + fn = s.ACLAuthMethodWrite + + case "DELETE": + fn = s.ACLAuthMethodDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + methodName := strings.TrimPrefix(req.URL.Path, "/v1/acl/auth-method/") + if methodName == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing auth method name"} + } + + return fn(resp, req, methodName) +} + +func (s *HTTPServer) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodGetRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodRead", &args, &out); err != nil { + return nil, err + } + + if out.AuthMethod == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + fixupAuthMethodConfig(out.AuthMethod) + return out.AuthMethod, nil +} + +func (s *HTTPServer) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLAuthMethodWrite(resp, req, "") +} + +func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} + } + + if methodName != "" { + if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName { + return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"} + } else if args.AuthMethod.Name == "" { + args.AuthMethod.Name = methodName + } + } + + var out structs.ACLAuthMethod + if err := s.agent.RPC("ACL.AuthMethodSet", args, &out); err != nil { + return nil, err + } + + fixupAuthMethodConfig(&out) + return &out, nil +} + +func (s *HTTPServer) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.AuthMethodDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := &structs.ACLLoginRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + + if err := decodeBody(req, &args.Auth, nil); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)} + } + + var out structs.ACLToken + if err := s.agent.RPC("ACL.Login", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := structs.ACLLogoutRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + s.parseToken(req, &args.Token) + + if args.Token == "" { + return nil, acl.ErrNotFound + } + + var ignored bool + if err := s.agent.RPC("ACL.Logout", &args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +// A hack to fix up the config types inside of the map[string]interface{} +// so that they get formatted correctly during json.Marshal. Without this, +// string values that get converted to []uint8 end up getting output back +// to the user in base64-encoded form. +func fixupAuthMethodConfig(method *structs.ACLAuthMethod) { + for k, v := range method.Config { + if raw, ok := v.([]uint8); ok { + strVal := structs.Uint8ToString(raw) + method.Config[k] = strVal + } + } +} diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 754b931cc6..ab92c1457a 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -8,6 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/testrpc" "github.com/stretchr/testify/require" @@ -40,6 +42,17 @@ func TestACL_Disabled_Response(t *testing.T) { {"ACLTokenCreate", a.srv.ACLTokenCreate}, {"ACLTokenSelf", a.srv.ACLTokenSelf}, {"ACLTokenCRUD", a.srv.ACLTokenCRUD}, + {"ACLRoleList", a.srv.ACLRoleList}, + {"ACLRoleCreate", a.srv.ACLRoleCreate}, + {"ACLRoleCRUD", a.srv.ACLRoleCRUD}, + {"ACLBindingRuleList", a.srv.ACLBindingRuleList}, + {"ACLBindingRuleCreate", a.srv.ACLBindingRuleCreate}, + {"ACLBindingRuleCRUD", a.srv.ACLBindingRuleCRUD}, + {"ACLAuthMethodList", a.srv.ACLAuthMethodList}, + {"ACLAuthMethodCreate", a.srv.ACLAuthMethodCreate}, + {"ACLAuthMethodCRUD", a.srv.ACLAuthMethodCRUD}, + {"ACLLogin", a.srv.ACLLogin}, + {"ACLLogout", a.srv.ACLLogout}, } testrpc.WaitForLeader(t, a.RPC, "dc1") for _, tt := range tests { @@ -119,6 +132,7 @@ func TestACL_HTTP(t *testing.T) { idMap := make(map[string]string) policyMap := make(map[string]*structs.ACLPolicy) + roleMap := make(map[string]*structs.ACLRole) tokenMap := make(map[string]*structs.ACLToken) // This is all done as a subtest for a couple reasons @@ -220,7 +234,7 @@ func TestACL_HTTP(t *testing.T) { policyMap[policy.ID] = policy }) - t.Run("Update Name ID Mistmatch", func(t *testing.T) { + t.Run("Update Name ID Mismatch", func(t *testing.T) { policyInput := &structs.ACLPolicy{ ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", Name: "read-all-nodes", @@ -355,6 +369,222 @@ func TestACL_HTTP(t *testing.T) { }) }) + t.Run("Role", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: idMap["policy-test"], + Name: policyMap[idMap["policy-test"]].Name, + }, + structs.ACLRolePolicyLink{ + ID: idMap["policy-read-all-nodes"], + Name: policyMap[idMap["policy-read-all-nodes"]].Name, + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Name Chars", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "service-id-web", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-service-id-web"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Update Name ID Mismatch", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "db", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Role CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web-indexer", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.True(t, role.CreateIndex < role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("ID Supplied", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Name: "with-id", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "foobar", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/role/"+idMap["role-service-id-web"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + delete(roleMap, idMap["role-service-id-web"]) + delete(idMap, "role-service-id-web") + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/roles?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleList(resp, req) + require.NoError(t, err) + roles, ok := raw.(structs.ACLRoles) + require.True(t, ok) + + // 1 we just created + require.Len(t, roles, 1) + + for roleID, expected := range roleMap { + found := false + for _, actual := range roles { + if actual.ID == roleID { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/"+idMap["role-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + role, ok := raw.(*structs.ACLRole) + require.True(t, ok) + require.Equal(t, roleMap[idMap["role-test"]], role) + }) + }) + t.Run("Token", func(t *testing.T) { t.Run("Create", func(t *testing.T) { tokenInput := &structs.ACLToken{ @@ -594,3 +824,504 @@ func TestACL_HTTP(t *testing.T) { }) }) } + +func TestACL_LoginProcedure_HTTP(t *testing.T) { + // This tests AuthMethods, BindingRules, Login, and Logout. + t.Parallel() + a := NewTestAgent(t, t.Name(), TestACLConfig()) + defer a.Shutdown() + + testrpc.WaitForLeader(t, a.RPC, "dc1") + + idMap := make(map[string]string) + methodMap := make(map[string]*structs.ACLAuthMethod) + ruleMap := make(map[string]*structs.ACLBindingRule) + tokenMap := make(map[string]*structs.ACLToken) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + // This is all done as a subtest for a couple reasons + // 1. It uses only 1 test agent and these are + // somewhat expensive to bring up and tear down often + // 2. Instead of having to bring up a new agent and prime + // the ACL system with some data before running the test + // we can intelligently order these tests so we can still + // test everything with less actual operations and do + // so in a manner that is less prone to being flaky + // 3. While this test will be large it should + t.Run("AuthMethod", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Create other", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "other", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Update Name URL Mismatch", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/not-test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "updated description", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.True(t, method.CreateIndex < method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-methods?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodList(resp, req) + require.NoError(t, err) + methods, ok := raw.(structs.ACLAuthMethodListStubs) + require.True(t, ok) + + // 2 we just created + require.Len(t, methods, 2) + + for methodName, expected := range methodMap { + found := false + for _, actual := range methods { + if actual.Name == methodName { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/auth-method/other?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + delete(methodMap, "other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-method/test?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + method, ok := raw.(*structs.ACLAuthMethod) + require.True(t, ok) + require.Equal(t, methodMap["test"], method) + }) + }) + + t.Run("BindingRule", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "test", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("Create other", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "other", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeRole, + BindName: "fancy-role", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-other"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("BindingRule CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "updated", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.True(t, rule.CreateIndex < rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("ID Supplied", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Description: "with-id", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "vault", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rules?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleList(resp, req) + require.NoError(t, err) + rules, ok := raw.(structs.ACLBindingRules) + require.True(t, ok) + + // 2 we just created + require.Len(t, rules, 2) + + for ruleID, expected := range ruleMap { + found := false + for _, actual := range rules { + if actual.ID == ruleID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.AuthMethod, actual.AuthMethod) + require.Equal(t, expected.Selector, actual.Selector) + require.Equal(t, expected.BindType, actual.BindType) + require.Equal(t, expected.BindName, actual.BindName) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/binding-rule/"+idMap["rule-other"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + delete(ruleMap, idMap["rule-other"]) + delete(idMap, "rule-other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + rule, ok := raw.(*structs.ACLBindingRule) + require.True(t, ok) + require.Equal(t, ruleMap[idMap["rule-test"]], rule) + }) + }) + + testauth.InstallSessionToken(testSessionID, "token1", "default", "demo1", "abc123") + testauth.InstallSessionToken(testSessionID, "token2", "default", "demo2", "def456") + + t.Run("Login", func(t *testing.T) { + t.Run("Create Token 1", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token1", + Meta: map[string]string{"foo": "bar"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"foo":"bar"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo1", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-1"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + t.Run("Create Token 2", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token2", + Meta: map[string]string{"blah": "woot"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"blah":"woot"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo2", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-2"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + + t.Run("List Tokens by (incorrect) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=other", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 0) + }) + + t.Run("List Tokens by (correct) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=test", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 2) + + for tokenID, expected := range tokenMap { + found := false + for _, actual := range tokens { + if actual.AccessorID == tokenID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.Roles, actual.Roles) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Local, actual.Local) + require.Equal(t, expected.CreateTime, actual.CreateTime) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + require.True(t, found) + } + }) + + t.Run("Logout", func(t *testing.T) { + tok := tokenMap[idMap["token-test-1"]] + req, _ := http.NewRequest("POST", "/v1/acl/logout?token="+tok.SecretID, nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLLogout(resp, req) + require.NoError(t, err) + }) + + t.Run("Token is gone after Logout", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/token/"+idMap["token-test-1"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLTokenCRUD(resp, req) + require.Error(t, err) + require.True(t, acl.IsErrNotFound(err), err.Error()) + }) + }) +} diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 74ebb90385..6e8130af27 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -447,25 +447,8 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent return out, nil } - if acl.IsErrNotFound(err) { - // make sure to indicate that this identity is no longer valid within - // the cache - r.cache.PutIdentity(identity.SecretToken(), nil) - - // Do not touch the policy cache. Getting a top level ACL not found error - // only indicates that the secret token used in the request - // no longer exists - return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} - } - - if acl.IsErrPermissionDenied(err) { - // invalidate our ID cache so that identity resolution will take place - // again in the future - r.cache.RemoveIdentity(identity.SecretToken()) - - // Do not remove from the policy cache for permission denied - // what this does indicate is that our view of the token is out of date - return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr } // other RPC error - use cache if available @@ -519,25 +502,8 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity return out, nil } - if acl.IsErrNotFound(err) { - // make sure to indicate that this identity is no longer valid within - // the cache - r.cache.PutIdentity(identity.SecretToken(), nil) - - // Do not touch the cache. Getting a top level ACL not found error - // only indicates that the secret token used in the request - // no longer exists - return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} - } - - if acl.IsErrPermissionDenied(err) { - // invalidate our ID cache so that identity resolution will take place - // again in the future - r.cache.RemoveIdentity(identity.SecretToken()) - - // Do not remove from the cache for permission denied - // what this does indicate is that our view of the token is out of date - return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr } // other RPC error - use cache if available @@ -557,12 +523,39 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity insufficientCache = true } } + if insufficientCache { return nil, ACLRemoteError{Err: err} } + return out, nil } +func (r *ACLResolver) maybeHandleIdentityErrorDuringFetch(identity structs.ACLIdentity, err error) error { + if acl.IsErrNotFound(err) { + // make sure to indicate that this identity is no longer valid within + // the cache + r.cache.PutIdentity(identity.SecretToken(), nil) + + // Do not touch the cache. Getting a top level ACL not found error + // only indicates that the secret token used in the request + // no longer exists + return &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} + } + + if acl.IsErrPermissionDenied(err) { + // invalidate our ID cache so that identity resolution will take place + // again in the future + r.cache.RemoveIdentity(identity.SecretToken()) + + // Do not remove from the cache for permission denied + // what this does indicate is that our view of the token is out of date + return &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + } + + return nil +} + func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) structs.ACLPolicies { var out structs.ACLPolicies for _, policy := range policies { diff --git a/agent/consul/acl_authmethod.go b/agent/consul/acl_authmethod.go new file mode 100644 index 0000000000..ba3b3772da --- /dev/null +++ b/agent/consul/acl_authmethod.go @@ -0,0 +1,169 @@ +package consul + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-bexpr" + + // register this as a builtin auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" +) + +type authMethodValidatorEntry struct { + Validator authmethod.Validator + ModifyIndex uint64 // the raft index when this last changed +} + +// loadAuthMethodValidator returns an authmethod.Validator for the given auth +// method configuration. If the cache is up to date as-of the provided index +// then the cached version is returned, otherwise a new validator is created +// and cached. +func (s *Server) loadAuthMethodValidator(idx uint64, method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if prevIdx, v, ok := s.getCachedAuthMethodValidator(method.Name); ok && idx <= prevIdx { + return v, nil + } + + v, err := authmethod.NewValidator(method) + if err != nil { + return nil, fmt.Errorf("auth method validator for %q could not be initialized: %v", method.Name, err) + } + + v = s.getOrReplaceAuthMethodValidator(method.Name, idx, v) + + return v, nil +} + +// getCachedAuthMethodValidator returns an AuthMethodValidator for +// the given name exclusively from the cache. If one is not found in the cache +// nil is returned. +func (s *Server) getCachedAuthMethodValidator(name string) (uint64, authmethod.Validator, bool) { + s.aclAuthMethodValidatorLock.RLock() + defer s.aclAuthMethodValidatorLock.RUnlock() + + if s.aclAuthMethodValidators != nil { + v, ok := s.aclAuthMethodValidators[name] + if ok { + return v.ModifyIndex, v.Validator, true + } + } + return 0, nil, false +} + +// getOrReplaceAuthMethodValidator updates the cached validator with the +// provided one UNLESS it has been updated by another goroutine in which case +// the updated one is returned. +func (s *Server) getOrReplaceAuthMethodValidator(name string, idx uint64, v authmethod.Validator) authmethod.Validator { + s.aclAuthMethodValidatorLock.Lock() + defer s.aclAuthMethodValidatorLock.Unlock() + + if s.aclAuthMethodValidators == nil { + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + } + + prev, ok := s.aclAuthMethodValidators[name] + if ok { + if prev.ModifyIndex >= idx { + return prev.Validator + } + } + + s.logger.Printf("[DEBUG] acl: updating cached auth method validator for %q", name) + + s.aclAuthMethodValidators[name] = &authMethodValidatorEntry{ + Validator: v, + ModifyIndex: idx, + } + return v +} + +// purgeAuthMethodValidators resets the cache of validators. +func (s *Server) purgeAuthMethodValidators() { + s.aclAuthMethodValidatorLock.Lock() + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + s.aclAuthMethodValidatorLock.Unlock() +} + +// evaluateRoleBindings evaluates all current binding rules associated with the +// given auth method against the verified data returned from the authentication +// process. +// +// A list of role links and service identities are returned. +func (s *Server) evaluateRoleBindings( + validator authmethod.Validator, + verifiedFields map[string]string, +) ([]*structs.ACLServiceIdentity, []structs.ACLTokenRoleLink, error) { + // Only fetch rules that are relevant for this method. + _, rules, err := s.fsm.State().ACLBindingRuleList(nil, validator.Name()) + if err != nil { + return nil, nil, err + } else if len(rules) == 0 { + return nil, nil, nil + } + + // Convert the fields into something suitable for go-bexpr. + selectableVars := validator.MakeFieldMapSelectable(verifiedFields) + + // Find all binding rules that match the provided fields. + var matchingRules []*structs.ACLBindingRule + for _, rule := range rules { + if doesBindingRuleMatch(rule, selectableVars) { + matchingRules = append(matchingRules, rule) + } + } + if len(matchingRules) == 0 { + return nil, nil, nil + } + + // For all matching rules compute the attributes of a token. + var ( + roleLinks []structs.ACLTokenRoleLink + serviceIdentities []*structs.ACLServiceIdentity + ) + for _, rule := range matchingRules { + bindName, valid, err := computeBindingRuleBindName(rule.BindType, rule.BindName, verifiedFields) + if err != nil { + return nil, nil, fmt.Errorf("cannot compute %q bind name for bind target: %v", rule.BindType, err) + } else if !valid { + return nil, nil, fmt.Errorf("computed %q bind name for bind target is invalid: %q", rule.BindType, bindName) + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + serviceIdentities = append(serviceIdentities, &structs.ACLServiceIdentity{ + ServiceName: bindName, + }) + + case structs.BindingRuleBindTypeRole: + roleLinks = append(roleLinks, structs.ACLTokenRoleLink{ + Name: bindName, + }) + + default: + // skip unknown bind type; don't grant privileges + } + } + + return serviceIdentities, roleLinks, nil +} + +// doesBindingRuleMatch checks that a single binding rule matches the provided +// vars. +func doesBindingRuleMatch(rule *structs.ACLBindingRule, selectableVars interface{}) bool { + if rule.Selector == "" { + return true // catch-all + } + + eval, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return false // fails to match if selector is invalid + } + + result, err := eval.Evaluate(selectableVars) + if err != nil { + return false // fails to match if evaluation fails + } + + return result +} diff --git a/agent/consul/acl_authmethod_test.go b/agent/consul/acl_authmethod_test.go new file mode 100644 index 0000000000..45e3021e44 --- /dev/null +++ b/agent/consul/acl_authmethod_test.go @@ -0,0 +1,48 @@ +package consul + +import ( + "testing" + + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestDoesBindingRuleMatch(t *testing.T) { + type matchable struct { + A string `bexpr:"a"` + C string `bexpr:"c"` + } + + for _, test := range []struct { + name string + selector string + details interface{} + ok bool + }{ + {"no fields", + "a==b", nil, false}, + {"1 term ok", + "a==b", &matchable{A: "b"}, true}, + {"1 term no field", + "a==b", &matchable{C: "d"}, false}, + {"1 term wrong value", + "a==b", &matchable{A: "z"}, false}, + {"2 terms ok", + "a==b and c==d", &matchable{A: "b", C: "d"}, true}, + {"2 terms one missing field", + "a==b and c==d", &matchable{A: "b"}, false}, + {"2 terms one wrong value", + "a==b and c==d", &matchable{A: "z", C: "d"}, false}, + /////////////////////////////// + {"no fields (no selectors)", + "", nil, true}, + {"1 term ok (no selectors)", + "", &matchable{A: "b"}, true}, + } { + t.Run(test.name, func(t *testing.T) { + rule := structs.ACLBindingRule{Selector: test.selector} + ok := doesBindingRuleMatch(&rule, test.details) + require.Equal(t, test.ok, ok) + }) + } +} diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index 62c17496f9..cc3b5da2e7 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -1,6 +1,8 @@ package consul import ( + "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -10,9 +12,11 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" + "github.com/hashicorp/go-bexpr" memdb "github.com/hashicorp/go-memdb" uuid "github.com/hashicorp/go-uuid" ) @@ -29,6 +33,7 @@ var ( validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`) serviceIdentityNameMaxLength = 256 validRoleName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,256}$`) + validAuthMethod = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) ) // ACL endpoint is used to manipulate ACLs @@ -273,6 +278,10 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok return a.srv.forwardDC("ACL.TokenClone", a.srv.config.ACLDatacenter, args, reply) } + if token.AuthMethod != "" { + return fmt.Errorf("Cannot clone a token created from an auth method") + } + if token.Rules != "" { return fmt.Errorf("Cannot clone a legacy ACL with this endpoint") } @@ -324,7 +333,7 @@ func (a *ACL) TokenSet(args *structs.ACLTokenSetRequest, reply *structs.ACLToken return a.tokenSetInternal(args, reply, false) } -func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, upgrade bool) error { +func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, fromLogin bool) error { token := &args.ACLToken if !a.srv.LocalTokensEnabled() { @@ -354,6 +363,19 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. token.CreateTime = time.Now() + if fromLogin { + if token.AuthMethod == "" { + return fmt.Errorf("AuthMethod field is required during Login") + } + if !token.Local { + return fmt.Errorf("Cannot create Global token via Login") + } + } else { + if token.AuthMethod != "" { + return fmt.Errorf("AuthMethod field is disallowed outside of Login") + } + } + // Ensure an ExpirationTTL is valid if provided. if token.ExpirationTTL != 0 { if token.ExpirationTTL < 0 { @@ -418,6 +440,12 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID) } + if token.AuthMethod == "" { + token.AuthMethod = existing.AuthMethod + } else if token.AuthMethod != existing.AuthMethod { + return fmt.Errorf("Cannot change AuthMethod of %s", token.AccessorID) + } + if token.ExpirationTTL != 0 { return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) } @@ -430,11 +458,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) } - if upgrade { - token.CreateTime = time.Now() - } else { - token.CreateTime = existing.CreateTime - } + token.CreateTime = existing.CreateTime } policyIDs := make(map[string]struct{}) @@ -467,7 +491,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. roleIDs := make(map[string]struct{}) var roles []structs.ACLTokenRoleLink - // Validate all the role names and convert them to role IDs + // Validate all the role names and convert them to role IDs. for _, link := range token.Roles { if link.ID == "" { _, role, err := state.ACLRoleGetByName(nil, link.Name) @@ -502,6 +526,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) } } + token.ServiceIdentities = dedupeServiceIdentities(token.ServiceIdentities) if token.Rules != "" { return fmt.Errorf("Rules cannot be specified for this token") @@ -540,6 +565,51 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return nil } +func validateBindingRuleBindName(bindType, bindName string, availableFields []string) (bool, error) { + if bindType == "" || bindName == "" { + return false, nil + } + + fakeVarMap := make(map[string]string) + for _, v := range availableFields { + fakeVarMap[v] = "fake" + } + + _, valid, err := computeBindingRuleBindName(bindType, bindName, fakeVarMap) + if err != nil { + return false, err + } + return valid, nil +} + +// computeBindingRuleBindName processes the HIL for the provided bind type+name +// using the verified fields. +// +// - If the HIL is invalid ("", false, AN_ERROR) is returned. +// - If the computed name is not valid for the type ("INVALID_NAME", false, nil) is returned. +// - If the computed name is valid for the type ("VALID_NAME", true, nil) is returned. +func computeBindingRuleBindName(bindType, bindName string, verifiedFields map[string]string) (string, bool, error) { + bindName, err := InterpolateHIL(bindName, verifiedFields) + if err != nil { + return "", false, err + } + + valid := false + + switch bindType { + case structs.BindingRuleBindTypeService: + valid = isValidServiceIdentityName(bindName) + + case structs.BindingRuleBindTypeRole: + valid = validRoleName.MatchString(bindName) + + default: + return "", false, fmt.Errorf("unknown binding rule bind type: %s", bindType) + } + + return bindName, valid, nil +} + // isValidServiceIdentityName returns true if the provided name can be used as // an ACLServiceIdentity ServiceName. This is more restrictive than standard // catalog registration, which basically takes the view that "everything is @@ -652,7 +722,7 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role) + index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role, args.AuthMethod) if err != nil { return err } @@ -1252,17 +1322,11 @@ func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) e if svcid.ServiceName == "" { return fmt.Errorf("Service identity is missing the service name field on this role") } - // TODO(rb): ugh if a local token gets a role that has a service - // identity that has datacenters set, we won't be anble to enforce this - // next blob here. This makes me lean more towards nuking ServiceIdentity.Datacenters again - // - // if token.Local && len(svcid.Datacenters) > 0 { - // return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName) - // } if !isValidServiceIdentityName(svcid.ServiceName) { return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) } } + role.ServiceIdentities = dedupeServiceIdentities(role.ServiceIdentities) // calculate the hash for this role role.SetHash(true) @@ -1412,3 +1476,577 @@ func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.A return nil } + +var errAuthMethodsRequireTokenReplication = errors.New("Token replication is required for auth methods to function") + +func (a *ACL) BindingRuleRead(args *structs.ACLBindingRuleGetRequest, reply *structs.ACLBindingRuleResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rule, err := state.ACLBindingRuleGetByID(ws, args.BindingRuleID) + + if err != nil { + return err + } + + reply.Index, reply.BindingRule = index, rule + return nil + }) +} + +func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *structs.ACLBindingRule) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + rule := &args.BindingRule + state := a.srv.fsm.State() + + if rule.ID == "" { + // with no binding rule ID one will be generated + var err error + + rule.ID, err = lib.GenerateUUID(a.srv.checkBindingRuleUUID) + if err != nil { + return err + } + } else { + if _, err := uuid.ParseUUID(rule.ID); err != nil { + return fmt.Errorf("Binding Rule ID invalid UUID") + } + + // Verify the role exists + _, existing, err := state.ACLBindingRuleGetByID(nil, rule.ID) + if err != nil { + return fmt.Errorf("acl binding rule lookup failed: %v", err) + } else if existing == nil { + return fmt.Errorf("cannot find binding rule %s", rule.ID) + } + + if rule.AuthMethod == "" { + rule.AuthMethod = existing.AuthMethod + } else if existing.AuthMethod != rule.AuthMethod { + return fmt.Errorf("the AuthMethod field of an Binding Rule is immutable") + } + } + + if rule.AuthMethod == "" { + return fmt.Errorf("Invalid Binding Rule: no AuthMethod is set") + } + + methodIdx, method, err := state.ACLAuthMethodGetByName(nil, rule.AuthMethod) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } else if method == nil { + return fmt.Errorf("cannot find auth method with name %q", rule.AuthMethod) + } + validator, err := a.srv.loadAuthMethodValidator(methodIdx, method) + if err != nil { + return err + } + + if rule.Selector != "" { + selectableVars := validator.MakeFieldMapSelectable(map[string]string{}) + _, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return fmt.Errorf("invalid Binding Rule: Selector is invalid: %v", err) + } + } + + if rule.BindType == "" { + return fmt.Errorf("Invalid Binding Rule: no BindType is set") + } + + if rule.BindName == "" { + return fmt.Errorf("Invalid Binding Rule: no BindName is set") + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + case structs.BindingRuleBindTypeRole: + default: + return fmt.Errorf("Invalid Binding Rule: unknown BindType %q", rule.BindType) + } + + if valid, err := validateBindingRuleBindName(rule.BindType, rule.BindName, validator.AvailableFields()); err != nil { + return fmt.Errorf("Invalid Binding Rule: invalid BindName: %v", err) + } else if !valid { + return fmt.Errorf("Invalid Binding Rule: invalid BindName") + } + + req := &structs.ACLBindingRuleBatchSetRequest{ + BindingRules: structs.ACLBindingRules{rule}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, rule.ID); err == nil && rule != nil { + *reply = *rule + } + + return nil +} + +func (a *ACL) BindingRuleDelete(args *structs.ACLBindingRuleDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, args.BindingRuleID) + if err != nil { + return err + } + + if rule == nil { + return nil + } + + req := structs.ACLBindingRuleBatchDeleteRequest{ + BindingRuleIDs: []string{args.BindingRuleID}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) BindingRuleList(args *structs.ACLBindingRuleListRequest, reply *structs.ACLBindingRuleListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rules, err := state.ACLBindingRuleList(ws, args.AuthMethod) + if err != nil { + return err + } + + reply.Index, reply.BindingRules = index, rules + return nil + }) +} + +func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *structs.ACLAuthMethodResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, method, err := state.ACLAuthMethodGetByName(ws, args.AuthMethodName) + + if err != nil { + return err + } + + reply.Index, reply.AuthMethod = index, method + return nil + }) +} + +func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *structs.ACLAuthMethod) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + method := &args.AuthMethod + state := a.srv.fsm.State() + + // ensure a name is set + if method.Name == "" { + return fmt.Errorf("Invalid Auth Method: no Name is set") + } + if !validAuthMethod.MatchString(method.Name) { + return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed") + } + + // Check to see if the method exists first. + _, existing, err := state.ACLAuthMethodGetByName(nil, method.Name) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } + + if existing != nil { + if method.Type == "" { + method.Type = existing.Type + } else if existing.Type != method.Type { + return fmt.Errorf("the Type field of an Auth Method is immutable") + } + } + + if !authmethod.IsRegisteredType(method.Type) { + return fmt.Errorf("Invalid Auth Method: Type should be one of: %v", authmethod.Types()) + } + + // Instantiate a validator but do not cache it yet. This will validate the + // configuration. + if _, err := authmethod.NewValidator(method); err != nil { + return fmt.Errorf("Invalid Auth Method: %v", err) + } + + req := &structs.ACLAuthMethodBatchSetRequest{ + AuthMethods: structs.ACLAuthMethods{method}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply auth method upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, method.Name); err == nil && method != nil { + *reply = *method + } + + return nil +} + +func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, args.AuthMethodName) + if err != nil { + return err + } + + if method == nil { + return nil + } + + req := structs.ACLAuthMethodBatchDeleteRequest{ + AuthMethodNames: []string{args.AuthMethodName}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply auth method delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *structs.ACLAuthMethodListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, methods, err := state.ACLAuthMethodList(ws) + if err != nil { + return err + } + + var stubs structs.ACLAuthMethodListStubs + for _, method := range methods { + stubs = append(stubs, method.Stub()) + } + + reply.Index, reply.AuthMethods = index, stubs + return nil + }) +} + +func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token != "" { // This shouldn't happen. + return errors.New("do not provide a token when logging in") + } + + if done, err := a.srv.forward("ACL.Login", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "login"}, time.Now()) + + auth := args.Auth + + // 1. take args.Data.AuthMethod to get an AuthMethod Validator + idx, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, auth.AuthMethod) + if err != nil { + return err + } else if method == nil { + return acl.ErrNotFound + } + + validator, err := a.srv.loadAuthMethodValidator(idx, method) + if err != nil { + return err + } + + // 2. Send args.Data.BearerToken to method validator and get back a fields map + verifiedFields, err := validator.ValidateLogin(auth.BearerToken) + if err != nil { + return err + } + + // 3. send map through role bindings + serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedFields) + if err != nil { + return err + } + + if len(serviceIdentities) == 0 && len(roleLinks) == 0 { + return acl.ErrPermissionDenied + } + + description := "token created via login" + loginMeta, err := encodeLoginMeta(auth.Meta) + if err != nil { + return err + } + if loginMeta != "" { + description += ": " + loginMeta + } + + // 4. create token + createReq := structs.ACLTokenSetRequest{ + Datacenter: args.Datacenter, + ACLToken: structs.ACLToken{ + Description: description, + Local: true, + AuthMethod: auth.AuthMethod, + ServiceIdentities: serviceIdentities, + Roles: roleLinks, + }, + WriteRequest: args.WriteRequest, + } + + // 5. return token information like a TokenCreate would + return a.tokenSetInternal(&createReq, reply, true) +} + +func encodeLoginMeta(meta map[string]string) (string, error) { + if len(meta) == 0 { + return "", nil + } + + d, err := json.Marshal(meta) + if err != nil { + return "", err + } + return string(d), nil +} + +func (a *ACL) Logout(args *structs.ACLLogoutRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token == "" { + return acl.ErrNotFound + } + + if done, err := a.srv.forward("ACL.Logout", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "logout"}, time.Now()) + + _, token, err := a.srv.fsm.State().ACLTokenGetBySecret(nil, args.Token) + if err != nil { + return err + + } else if token == nil { + return acl.ErrNotFound + + } else if token.AuthMethod == "" { + // Can't "logout" of a token that wasn't a result of login. + return acl.ErrPermissionDenied + + } else if !a.srv.InACLDatacenter() && !token.Local { + // global token writes must be forwarded to the primary DC + args.Datacenter = a.srv.config.ACLDatacenter + return a.srv.forwardDC("ACL.Logout", a.srv.config.ACLDatacenter, args, reply) + } + + // No need to check expiration time because it's being deleted. + + req := &structs.ACLTokenBatchDeleteRequest{ + TokenIDs: []string{token.AccessorID}, + } + + resp, err := a.srv.raftApply(structs.ACLTokenDeleteRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply token delete request: %v", err) + } + + // Purge the identity from the cache to prevent using the previous definition of the identity + if token != nil { + a.srv.acls.cache.RemoveIdentity(token.SecretID) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} diff --git a/agent/consul/acl_endpoint_legacy.go b/agent/consul/acl_endpoint_legacy.go index 3b5ee22c6e..16379faa2e 100644 --- a/agent/consul/acl_endpoint_legacy.go +++ b/agent/consul/acl_endpoint_legacy.go @@ -255,7 +255,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, false, true, "", "") + index, tokens, err := state.ACLTokenList(ws, false, true, "", "", "") if err != nil { return err } diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index 99de10be32..bfdecafea9 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -12,6 +12,8 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/lib" @@ -964,6 +966,117 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Len(t, token.Roles, 0) }) + t.Run("Create it with AuthMethod set outside of login", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + AuthMethod: "fakemethod", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "AuthMethod field is disallowed outside of Login") + }) + + t.Run("Update auth method linked token and try to change auth method", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method1.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + // create a token in one method + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method1.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + // try to update the token and change the method + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + SecretID: methodToken.SecretID, + AuthMethod: method2.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot change AuthMethod") + }) + + t.Run("Update auth method linked token and let the SecretID and AuthMethod be defaulted", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + // SecretID: methodToken.SecretID, + // AuthMethod: method.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + require.NoError(t, acl.TokenSet(&req, &resp)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.Len(t, token.Roles, 0) + require.Equal(t, "updated token", token.Description) + require.True(t, token.Local) + require.Equal(t, methodToken.SecretID, token.SecretID) + require.Equal(t, methodToken.AuthMethod, token.AuthMethod) + }) + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", @@ -1062,6 +1175,69 @@ func TestACLEndpoint_TokenSet(t *testing.T) { }) } + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + svcid := token.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) + t.Run("Create it with invalid service identity (datacenters set on local token)", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", @@ -1241,11 +1417,37 @@ func TestACLEndpoint_TokenSet(t *testing.T) { // do not insert another test at this point: these tests need to be serial + t.Run("Update anything except expiration time is ok - omit expiration time and let it default", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description-1", + AccessorID: tokenID, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "new-description-1") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expTime, resp.ExpirationTime) + }) + t.Run("Update anything except expiration time is ok", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", ACLToken: structs.ACLToken{ - Description: "new-description", + Description: "new-description-2", AccessorID: tokenID, ExpirationTime: &expTime, }, @@ -1263,7 +1465,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { token := tokenResp.Token require.NotNil(t, token.AccessorID) - require.Equal(t, token.Description, "new-description") + require.Equal(t, token.Description, "new-description-2") require.Equal(t, token.AccessorID, resp.AccessorID) requireTimeEquals(t, &expTime, resp.ExpirationTime) }) @@ -1615,12 +1817,7 @@ func TestACLEndpoint_TokenList(t *testing.T) { t2.AccessorID, t3.AccessorID, } - - var retrievedTokens []string - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.ElementsMatch(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) time.Sleep(20 * time.Millisecond) // now 't3' is expired @@ -1642,12 +1839,7 @@ func TestACLEndpoint_TokenList(t *testing.T) { t1.AccessorID, t2.AccessorID, } - - var retrievedTokens []string - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.ElementsMatch(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) } @@ -1694,13 +1886,7 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { err = acl.TokenBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedTokens []string - - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) time.Sleep(20 * time.Millisecond) // now 't3' is expired @@ -1718,13 +1904,7 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { err = acl.TokenBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedTokens []string - - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) } @@ -1801,13 +1981,7 @@ func TestACLEndpoint_PolicyBatchRead(t *testing.T) { err = acl.PolicyBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.EqualValues(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), []string{p1.ID, p2.ID}) } func TestACLEndpoint_PolicySet(t *testing.T) { @@ -2053,12 +2227,7 @@ func TestACLEndpoint_PolicyList(t *testing.T) { p1.ID, p2.ID, } - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.ElementsMatch(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) } func TestACLEndpoint_PolicyResolve(t *testing.T) { @@ -2114,13 +2283,7 @@ func TestACLEndpoint_PolicyResolve(t *testing.T) { } err = acl.PolicyResolve(&req, &resp) require.NoError(t, err) - - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.EqualValues(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) } func TestACLEndpoint_RoleRead(t *testing.T) { @@ -2189,13 +2352,7 @@ func TestACLEndpoint_RoleBatchRead(t *testing.T) { err = acl.RoleBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedRoles []string - - for _, v := range resp.Roles { - retrievedRoles = append(retrievedRoles, v.ID) - } - require.EqualValues(t, retrievedRoles, roles) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), roles) } func TestACLEndpoint_RoleSet(t *testing.T) { @@ -2432,6 +2589,67 @@ func TestACLEndpoint_RoleSet(t *testing.T) { } }) } + + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + svcid := role.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) } func TestACLEndpoint_RoleSet_names(t *testing.T) { @@ -2589,14 +2807,2009 @@ func TestACLEndpoint_RoleList(t *testing.T) { err = acl.RoleList(&req, &resp) require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) +} - roles := []string{r1.ID, r2.ID} - var retrievedRoles []string +func TestACLEndpoint_RoleResolve(t *testing.T) { + t.Parallel() - for _, v := range resp.Roles { - retrievedRoles = append(retrievedRoles, v.ID) + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + t.Run("Normal", func(t *testing.T) { + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + // Assign the roles to a token + tokenUpsertReq := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: r1.ID, + }, + structs.ACLTokenRoleLink{ + ID: r2.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + token := structs.ACLToken{} + err = acl.TokenSet(&tokenUpsertReq, &token) + require.NoError(t, err) + require.NotEmpty(t, token.SecretID) + + resp := structs.ACLRoleBatchResponse{} + req := structs.ACLRoleBatchGetRequest{ + Datacenter: "dc1", + RoleIDs: []string{r1.ID, r2.ID}, + QueryOptions: structs.QueryOptions{Token: token.SecretID}, + } + err = acl.RoleResolve(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) + }) +} + +func TestACLEndpoint_AuthMethodSet(t *testing.T) { + t.Parallel() + + tempDir, err := ioutil.TempDir("", "consul") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + newAuthMethod := func(name string) structs.ACLAuthMethod { + return structs.ACLAuthMethod{ + Name: name, + Description: "test", + Type: "testing", + } + } + + t.Run("Create", func(t *testing.T) { + reqMethod := newAuthMethod("test") + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update fails; not allowed to change types", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Type = "invalid" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Update - allow type to default", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 1" + reqMethod.Type = "" // unset + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 1") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update - specify type", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 2" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 2") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Create with no name", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(""), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Create with invalid type", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "invalid", + Description: "invalid test", + Type: "invalid", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {strings.Repeat("x", 129), false}, + {strings.Repeat("x", 128), true}, + {"-abc", true}, + {"abc-", true}, + {"a-bc", true}, + {"_abc", true}, + {"abc_", true}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", true}, + {"aBc", true}, + {"abC", true}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create with valid name (by regex): " + test.name + } else { + testName = "Create with invalid name (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(test.name), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + + if test.ok { + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, test.name) + require.Equal(t, method.Type, "testing") + } else { + require.Error(t, err) + } + }) + } +} + +func TestACLEndpoint_AuthMethodDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + existingMethod, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: existingMethod.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", existingMethod.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: "missing", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestACLEndpoint_AuthMethodDelete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID1) + testauth.InstallSessionToken(testSessionID1, "fake-token1", "default", "abc", "abc123") + + testSessionID2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID2) + testauth.InstallSessionToken(testSessionID2, "fake-token2", "default", "abc", "abc123") + + createToken := func(methodName, bearerToken string) *structs.ACLToken { + acl := ACL{srv: s1} + + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: methodName, + BearerToken: bearerToken, + }, + Datacenter: "dc1", + }, &resp)) + + return &resp + } + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID1) + require.NoError(t, err) + i1_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i1_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i1_t1 := createToken(method1.Name, "fake-token1") + i1_t2 := createToken(method1.Name, "fake-token1") + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID2) + require.NoError(t, err) + i2_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i2_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i2_t1 := createToken(method2.Name, "fake-token2") + i2_t2 := createToken(method2.Name, "fake-token2") + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: method1.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone. + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", method1.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + + // Make sure the rules and tokens are gone. + for _, id := range []string{i1_r1.ID, i1_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + } + for _, id := range []string{i1_t1.AccessorID, i1_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, tokResp.Token) + } + + // Make sure the rules and tokens for the untouched auth method are still there. + for _, id := range []string{i2_r1.ID, i2_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, ruleResp.BindingRule) + } + for _, id := range []string{i2_t1.AccessorID, i2_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, tokResp.Token) + } +} + +func TestACLEndpoint_AuthMethodList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + i1, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + i2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLAuthMethodListResponse{} + + err = acl.AuthMethodList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.AuthMethods), []string{i1.Name, i2.Name}) +} + +func TestACLEndpoint_BindingRuleSet(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + var ruleID string + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + otherTestAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + newRule := func() structs.ACLBindingRule { + return structs.ACLBindingRule{ + Description: "foobar", + AuthMethod: testAuthMethod.Name, + Selector: "serviceaccount.name==abc", + BindType: structs.BindingRuleBindTypeService, + BindName: "abc", + } + } + + requireSetErrors := func(t *testing.T, reqRule structs.ACLBindingRule) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.Error(t, err) + } + + requireOK := func(t *testing.T, reqRule structs.ACLBindingRule) *structs.ACLBindingRule { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.ID) + return &resp + } + + t.Run("Create it", func(t *testing.T) { + reqRule := newRule() + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.name==abc", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "abc", rule.BindName) + + ruleID = rule.ID + }) + + t.Run("Update fails; cannot change method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.AuthMethod = otherTestAuthMethod.Name + requireSetErrors(t, reqRule) + }) + + t.Run("Update it - omit method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 1" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + reqRule.AuthMethod = "" // clear + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 1") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Update it - specify method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 2" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 2") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Create fails; empty method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; unknown method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "unknown" + requireSetErrors(t, reqRule) + }) + + t.Run("Create with no explicit selector", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "" + + rule := requireOK(t, reqRule) + require.Empty(t, rule.Selector, 0) + }) + + t.Run("Create fails; match selector with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name==a and serviceaccount.bizarroname==b" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; match selector invalid", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind name", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "invalid" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; bind name with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.bizarroname}" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name no template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "-abc:" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name with template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name" + requireSetErrors(t, reqRule) + }) + t.Run("Create fails; invalid bind name after template computed", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name}:blah-" + requireSetErrors(t, reqRule) + }) +} + +func TestACLEndpoint_BindingRuleDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + existingRule, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: existingRule.ID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the rule is gone + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", existingRule.ID) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: fakeID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +func TestACLEndpoint_BindingRuleList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + r1, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + r2, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLBindingRuleListResponse{} + + err = acl.BindingRuleList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.BindingRules), []string{r1.ID, r2.ID}) +} + +func TestACLEndpoint_SecureIntroEndpoints_LocalTokensDisabled(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // disable local tokens + c.ACLTokenReplication = false + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl2 := ACL{srv: s2} + var ignored bool + + errString := errAuthMethodsRequireTokenReplication.Error() + + t.Run("AuthMethodRead", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodRead(&structs.ACLAuthMethodGetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodResponse{}), + errString, + ) + }) + t.Run("AuthMethodSet", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodSet(&structs.ACLAuthMethodSetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethod{}), + errString, + ) + }) + t.Run("AuthMethodDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodDelete(&structs.ACLAuthMethodDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("AuthMethodList", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodList(&structs.ACLAuthMethodListRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodListResponse{}), + errString, + ) + }) + + t.Run("BindingRuleRead", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleRead(&structs.ACLBindingRuleGetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleResponse{}), + errString, + ) + }) + t.Run("BindingRuleSet", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleSet(&structs.ACLBindingRuleSetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRule{}), + errString, + ) + }) + t.Run("BindingRuleDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleDelete(&structs.ACLBindingRuleDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("BindingRuleList", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleList(&structs.ACLBindingRuleListRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleListResponse{}), + errString, + ) + }) + + t.Run("Login", func(t *testing.T) { + requireErrorContains(t, + acl2.Login(&structs.ACLLoginRequest{Datacenter: "dc2"}, + &structs.ACLToken{}), + errString, + ) + }) + t.Run("Logout", func(t *testing.T) { + requireErrorContains(t, + acl2.Logout(&structs.ACLLogoutRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) +} + +func TestACLEndpoint_SecureIntroEndpoints_OnlyCreateLocalData(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec1 := rpcClient(t, s1) + defer codec1.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // enable token replication so secure intro works + c.ACLTokenReplication = true + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl := ACL{srv: s1} + acl2 := ACL{srv: s2} + + // + // this order is specific so that we can do it in one pass + // + + testSessionID_1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_1) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + + testauth.InstallSessionToken( + testSessionID_1, + "fake-web1-token", + "default", "web1", "abc123", + ) + testauth.InstallSessionToken( + testSessionID_2, + "fake-web2-token", + "default", "web2", "def456", + ) + + t.Run("create auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test original", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test original", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("update auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test updated", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test updated", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("read auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodGetRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodResponse{} + require.NoError(t, acl2.AuthMethodRead(&req, &resp)) + require.NotNil(t, resp.AuthMethod) + require.Equal(t, "test updated", resp.AuthMethod.Description) + + // absent in dc1 + req = structs.ACLAuthMethodGetRequest{ + Datacenter: "dc1", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodResponse{} + require.NoError(t, acl.AuthMethodRead(&req, &resp)) + require.Nil(t, resp.AuthMethod) + }) + + t.Run("list auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodListResponse{} + require.NoError(t, acl2.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 1) + + // absent in dc1 + req = structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodListResponse{} + require.NoError(t, acl.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 0) + }) + + var ruleID string + t.Run("create binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + Description: "test original", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test original", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("update binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + ID: ruleID, + Description: "test updated", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test updated", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("read binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleGetRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleResponse{} + require.NoError(t, acl2.BindingRuleRead(&req, &resp)) + require.NotNil(t, resp.BindingRule) + require.Equal(t, "test updated", resp.BindingRule.Description) + + // absent in dc1 + req = structs.ACLBindingRuleGetRequest{ + Datacenter: "dc1", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleResponse{} + require.NoError(t, acl.BindingRuleRead(&req, &resp)) + require.Nil(t, resp.BindingRule) + }) + + t.Run("list binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleListResponse{} + require.NoError(t, acl2.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 1) + + // absent in dc1 + req = structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleListResponse{} + require.NoError(t, acl.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 0) + }) + + var remoteToken *structs.ACLToken + t.Run("login in remote", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc2", + Auth: &structs.ACLLoginParams{ + AuthMethod: "testmethod", + BearerToken: "fake-web2-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + remoteToken = &resp + + // present in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web2", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // We delay until now to setup an auth method and binding rule in the + // primary so our earlier listing tests were sane. We need to be able to + // use auth methods in both datacenters in order to verify Logout is + // properly scoped. + t.Run("initialize primary so we can test logout", func(t *testing.T) { + reqAM := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "primarymethod", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_1, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + respAM := structs.ACLAuthMethod{} + require.NoError(t, acl.AuthMethodSet(&reqAM, &respAM)) + + reqBR := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: "primarymethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + respBR := structs.ACLBindingRule{} + require.NoError(t, acl.BindingRuleSet(&reqBR, &respBR)) + }) + + var primaryToken *structs.ACLToken + t.Run("login in primary", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc1", + Auth: &structs.ACLLoginParams{ + AuthMethod: "primarymethod", + BearerToken: "fake-web1-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + primaryToken = &resp + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of remote token in remote dc", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: remoteToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of primary token in remote dc should not work", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: primaryToken.SecretID}, + } + + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // Don't trigger the auth method delete cascade so we know the individual + // endpoints follow the rules. + + t.Run("delete binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.BindingRuleDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("delete auth method", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.AuthMethodDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) +} + +func TestACLEndpoint_Login(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "fake-web", // no rules + "default", "web", "abc123", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-monolith", // 1 rule, must exist + "default", "monolith", "ghi789", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + ruleDB, err := upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==db", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==monolith", + structs.BindingRuleBindTypeRole, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("do not provide a token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + req.Token = "nope" + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "do not provide a token") + }) + + t.Run("unknown method", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name + "-notexist", + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) + + t.Run("invalid method token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid method token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + t.Run("valid method token 1 role binding must exist and does not exist", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + // create the role so that the bindtype=existing login works + var monolithRoleID string + { + arg := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: "method-monolith", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLRole + require.NoError(t, acl.RoleSet(&arg, &out)) + + monolithRoleID = out.ID + } + s1.purgeAuthMethodValidators() + + t.Run("valid bearer token 1 role binding must exist and now exists", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.ServiceIdentities, 0) + require.Len(t, resp.Roles, 1) + role := resp.Roles[0] + require.Equal(t, monolithRoleID, role.ID) + require.Equal(t, "method-monolith", role.Name) + }) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: ruleDB.AuthMethod, + BindType: structs.BindingRuleBindTypeService, + BindName: ruleDB.BindName, + Selector: "", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLBindingRule + require.NoError(t, acl.BindingRuleSet(&req, &out)) + } + + t.Run("valid bearer token 1 binding (no selectors this time)", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + { + // Update the method to force the cache to invalidate for the next + // subtest. + updated := *method + updated.Description = "updated for the test" + updated.Config = map[string]interface{}{ + "SessionID": testSessionID_2, + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: updated, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored structs.ACLAuthMethod + require.NoError(t, acl.AuthMethodSet(&req, &ignored)) + } + + t.Run("updating the method invalidates the cache", func(t *testing.T) { + // We'll try to login with the 'fake-db' cred which DOES exist in the + // old fake validator, but no longer exists in the new fake validator. + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) +} + +func TestACLEndpoint_Login_k8s(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method, err := upsertTestKubernetesAuthMethod( + codec, "root", "dc1", + testSrv.CACert(), + testSrv.Addr(), + goodJWT_A, + ) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid bearer token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default", + structs.BindingRuleBindTypeService, + "${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "demo", svcid.ServiceName) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token 1 service binding - with annotation", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "alternate-name", svcid.ServiceName) + }) +} + +func TestACLEndpoint_Logout(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("you must provide a token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + // WriteRequest: structs.WriteRequest{Token: "root"}, + } + req.Token = "" + var ignored bool + + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from deleted token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "not-found"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from non-auth method-linked token should fail", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "Permission denied") + }) + + t.Run("login then logout", func(t *testing.T) { + // Create a totally legit Login token. + loginReq := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + }, + Datacenter: "dc1", + } + loginToken := structs.ACLToken{} + + require.NoError(t, acl.Login(&loginReq, &loginToken)) + require.NotEmpty(t, loginToken.SecretID) + + // Now turn around and nuke it. + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: loginToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + }) +} + +func gatherIDs(t *testing.T, v interface{}) []string { + t.Helper() + + var out []string + switch x := v.(type) { + case []*structs.ACLRole: + for _, r := range x { + out = append(out, r.ID) + } + case structs.ACLRoles: + for _, r := range x { + out = append(out, r.ID) + } + case []*structs.ACLPolicy: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLPolicyListStubs: + for _, p := range x { + out = append(out, p.ID) + } + case []*structs.ACLToken: + for _, p := range x { + out = append(out, p.AccessorID) + } + case structs.ACLTokenListStubs: + for _, p := range x { + out = append(out, p.AccessorID) + } + case []*structs.ACLAuthMethod: + for _, p := range x { + out = append(out, p.Name) + } + case structs.ACLAuthMethodListStubs: + for _, p := range x { + out = append(out, p.Name) + } + case []*structs.ACLBindingRule: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLBindingRules: + for _, p := range x { + out = append(out, p.ID) + } + default: + t.Fatalf("unknown type: %T", x) + } + return out +} + +func TestValidateBindingRuleBindName(t *testing.T) { + t.Parallel() + + type testcase struct { + name string + bindType string + bindName string + fields string + valid bool // valid HIL, invalid contents + err bool // invalid HIL + } + + for _, test := range []testcase{ + {"no bind type", + "", "", "", false, false}, + {"bad bind type", + "invalid", "blah", "", false, true}, + // valid HIL, invalid name + {"empty", + "both", "", "", false, false}, + {"just end", + "both", "}", "", false, false}, + {"var without start", + "both", " item }", "item", false, false}, + {"two vars missing second start", + "both", "before-${ item }after--more }", "item,more", false, false}, + // names for the two types are validated differently + {"@ is disallowed", + "both", "bad@name", "", false, false}, + {"leading dash", + "role", "-name", "", true, false}, + {"leading dash", + "service", "-name", "", false, false}, + {"trailing dash", + "role", "name-", "", true, false}, + {"trailing dash", + "service", "name-", "", false, false}, + {"inner dash", + "both", "name-end", "", true, false}, + {"upper case", + "role", "NAME", "", true, false}, + {"upper case", + "service", "NAME", "", false, false}, + // valid HIL, valid name + {"no vars", + "both", "nothing", "", true, false}, + {"just var", + "both", "${item}", "item", true, false}, + {"var in middle", + "both", "before-${item}after", "item", true, false}, + {"two vars", + "both", "before-${item}after-${more}", "item,more", true, false}, + // bad + {"no bind name", + "both", "", "", false, false}, + {"just start", + "both", "${", "", false, true}, + {"backwards", + "both", "}${", "", false, true}, + {"no varname", + "both", "${}", "", false, true}, + {"missing map key", + "both", "${item}", "", false, true}, + {"var without end", + "both", "${ item ", "item", false, true}, + {"two vars missing first end", + "both", "before-${ item after-${ more }", "item,more", false, true}, + } { + var cases []testcase + if test.bindType == "both" { + test1 := test + test1.bindType = "role" + test2 := test + test2.bindType = "service" + cases = []testcase{test1, test2} + } else { + cases = []testcase{test} + } + + for _, test := range cases { + test := test + t.Run(test.bindType+"--"+test.name, func(t *testing.T) { + t.Parallel() + valid, err := validateBindingRuleBindName( + test.bindType, + test.bindName, + strings.Split(test.fields, ","), + ) + if test.err { + require.NotNil(t, err) + require.False(t, valid) + } else { + require.NoError(t, err) + require.Equal(t, test.valid, valid) + } + }) + } } - require.ElementsMatch(t, retrievedRoles, roles) } // upsertTestToken creates a token for testing purposes @@ -2836,6 +5049,166 @@ func retrieveTestRoleByName(codec rpc.ClientCodec, masterToken string, datacente return &out, nil } +func deleteTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, methodName string) error { + arg := structs.ACLAuthMethodDeleteRequest{ + Datacenter: datacenter, + AuthMethodName: methodName, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodDelete", &arg, &ignored) + return err +} +func upsertTestAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + sessionID string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "testing", + Config: map[string]interface{}{ + "SessionID": sessionID, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func upsertTestKubernetesAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + caCert, kubeHost, kubeJWT string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + if kubeHost == "" { + kubeHost = "https://abc:8443" + } + if kubeJWT == "" { + kubeJWT = goodJWT_A + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": kubeHost, + "CACert": caCert, + "ServiceAccountJWT": kubeJWT, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, name string) (*structs.ACLAuthMethodResponse, error) { + arg := structs.ACLAuthMethodGetRequest{ + Datacenter: datacenter, + AuthMethodName: name, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLAuthMethodResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func deleteTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) error { + arg := structs.ACLBindingRuleDeleteRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleDelete", &arg, &ignored) + return err +} + +func upsertTestBindingRule( + codec rpc.ClientCodec, + masterToken string, + datacenter string, + methodName string, + selector string, + bindType string, + bindName string, +) (*structs.ACLBindingRule, error) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: datacenter, + BindingRule: structs.ACLBindingRule{ + AuthMethod: methodName, + BindType: bindType, + BindName: bindName, + Selector: selector, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLBindingRule + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) (*structs.ACLBindingRuleResponse, error) { + arg := structs.ACLBindingRuleGetRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLBindingRuleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + func requireTimeEquals(t *testing.T, expect, got *time.Time) { t.Helper() if expect == nil && got == nil { @@ -2858,3 +5231,9 @@ func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) t.Fatalf("unexpected error: %v", err) } } + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/acl_replication_legacy.go b/agent/consul/acl_replication_legacy.go index 010c220a96..b933f714e9 100644 --- a/agent/consul/acl_replication_legacy.go +++ b/agent/consul/acl_replication_legacy.go @@ -138,7 +138,7 @@ func reconcileLegacyACLs(local, remote structs.ACLs, lastRemoteIndex uint64) str // FetchLocalACLs returns the ACLs in the local state store. func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) { - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "") + _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "", "") if err != nil { return nil, err } diff --git a/agent/consul/acl_replication_legacy_test.go b/agent/consul/acl_replication_legacy_test.go index f5a2601d54..171f71c359 100644 --- a/agent/consul/acl_replication_legacy_test.go +++ b/agent/consul/acl_replication_legacy_test.go @@ -396,11 +396,11 @@ func TestACLReplication_LegacyTokens(t *testing.T) { } checkSame := func() error { - index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } diff --git a/agent/consul/acl_replication_test.go b/agent/consul/acl_replication_test.go index 730527eedc..e8a6a7d693 100644 --- a/agent/consul/acl_replication_test.go +++ b/agent/consul/acl_replication_test.go @@ -351,9 +351,9 @@ func TestACLReplication_Tokens(t *testing.T) { checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) - _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) require.Len(t, local, len(remote)) @@ -444,7 +444,7 @@ func TestACLReplication_Tokens(t *testing.T) { }) // verify dc2 local tokens didn't get blown away - _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "", "") require.NoError(t, err) require.Len(t, local, 50) @@ -779,10 +779,10 @@ func TestACLReplication_AllTypes(t *testing.T) { checkSameTokens := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) // Query for all of them, so that we can prove that no globals snuck in. - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, remote, len(local)) diff --git a/agent/consul/acl_replication_types.go b/agent/consul/acl_replication_types.go index 7044442fdf..8efc229632 100644 --- a/agent/consul/acl_replication_types.go +++ b/agent/consul/acl_replication_types.go @@ -34,7 +34,7 @@ func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (i func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) { r.local = nil - idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "") + idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "") if err != nil { return 0, 0, err } diff --git a/agent/consul/acl_server.go b/agent/consul/acl_server.go index d895d922a2..34ca09584b 100644 --- a/agent/consul/acl_server.go +++ b/agent/consul/acl_server.go @@ -73,6 +73,17 @@ func (s *Server) checkRoleUUID(id string) (bool, error) { return !structs.ACLIDReserved(id), nil } +func (s *Server) checkBindingRuleUUID(id string) (bool, error) { + state := s.fsm.State() + if _, rule, err := state.ACLBindingRuleGetByID(nil, id); err != nil { + return false, err + } else if rule != nil { + return false, nil + } + + return !structs.ACLIDReserved(id), nil +} + func (s *Server) updateACLAdvertisement() { // One thing to note is that once in new ACL mode the server will // never transition to legacy ACL mode. This is not currently a diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 65f05d9875..e2c84afe21 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -2,7 +2,6 @@ package consul import ( "fmt" - "log" "os" "reflect" "strings" @@ -12,6 +11,7 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -139,6 +139,26 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "found-synthetic-policy-1": + return true, &structs.ACLToken{ + AccessorID: "f6c5a5fb-4da4-422b-9abf-2c942813fc71", + SecretID: "55cb7d69-2bea-42c3-a68f-2a1443d2abbc", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service1", + }, + }, + }, nil + case "found-synthetic-policy-2": + return true, &structs.ACLToken{ + AccessorID: "7c87dfad-be37-446e-8305-299585677cb5", + SecretID: "dfca9676-ac80-453a-837b-4c0cf923473c", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service2", + }, + }, + }, nil case "acl-ro": return true, &structs.ACLToken{ AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", @@ -430,57 +450,87 @@ type ACLResolverTestDelegate struct { roleCached bool } +func (d *ACLResolverTestDelegate) Reset() { + d.tokenCached = false + d.policyCached = false + d.roleCached = false +} + var errRPC = fmt.Errorf("Induced RPC Error") func (d *ACLResolverTestDelegate) defaultTokenReadFn(errAfterCached error) func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { return func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { if !d.tokenCached { - _, token, _ := testIdentityForToken(args.TokenID) - reply.Token = token.(*structs.ACLToken) - + err := d.plainTokenReadFn(args, reply) d.tokenCached = true - return nil + return err } return errAfterCached } } +func (d *ACLResolverTestDelegate) plainTokenReadFn(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { + _, token, err := testIdentityForToken(args.TokenID) + if token != nil { + reply.Token = token.(*structs.ACLToken) + } + return err +} + func (d *ACLResolverTestDelegate) defaultPolicyResolveFn(errAfterCached error) func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error { return func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { if !d.policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - + err := d.plainPolicyResolveFn(args, reply) d.policyCached = true - return nil + return err } return errAfterCached } } +func (d *ACLResolverTestDelegate) plainPolicyResolveFn(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, policyID := range args.PolicyIDs { + _, policy, _ := testPolicyForID(policyID) + if policy != nil { + reply.Policies = append(reply.Policies, policy) + } + } + + return nil +} + func (d *ACLResolverTestDelegate) defaultRoleResolveFn(errAfterCached error) func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error { return func(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { if !d.roleCached { - for _, roleID := range args.RoleIDs { - _, role, _ := testRoleForID(roleID) - if role != nil { - reply.Roles = append(reply.Roles, role) - } - } - + err := d.plainRoleResolveFn(args, reply) d.roleCached = true - return nil + return err } return errAfterCached } } +// plainRoleResolveFn tries to follow the normal logic of ACL.RoleResolve using +// the test fixtures. +func (d *ACLResolverTestDelegate) plainRoleResolveFn(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, roleID := range args.RoleIDs { + _, role, _ := testRoleForID(roleID) + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } + + return nil +} + func (d *ACLResolverTestDelegate) ACLsEnabled() bool { return d.enabled } @@ -549,7 +599,7 @@ func newTestACLResolver(t *testing.T, delegate ACLResolverDelegate, cb func(*ACL config.ACLDownPolicy = "extend-cache" rconf := &ACLResolverConfig{ Config: config, - Logger: log.New(os.Stdout, t.Name()+" - ", log.LstdFlags|log.Lmicroseconds), + Logger: testutil.TestLoggerWithName(t, t.Name()), CacheConfig: &structs.ACLCachesConfig{ Identities: 4, Policies: 4, @@ -1058,7 +1108,7 @@ func TestACLResolver_DownPolicy(t *testing.T) { require.NoError(t, err) require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. - require.True(t, authz == authz2) + require.True(t, authz == authz2, "\n[1]={%+v} != \n[2]={%+v}", authz, authz2) require.True(t, authz2.NodeWrite("foo", nil)) }) @@ -1445,6 +1495,23 @@ func TestACLResolver_Client(t *testing.T) { }) } +func TestACLResolver_Client_TokensPoliciesAndRoles(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: false, + localRoles: false, + } + delegate.tokenReadFn = delegate.plainTokenReadFn + delegate.policyResolveFn = delegate.plainPolicyResolveFn + delegate.roleResolveFn = delegate.plainRoleResolveFn + + testACLResolver_variousTokens(t, delegate) +} + func TestACLResolver_LocalTokensPoliciesAndRoles(t *testing.T) { t.Parallel() delegate := &ACLResolverTestDelegate{ @@ -1470,31 +1537,43 @@ func TestACLResolver_LocalPoliciesAndRoles(t *testing.T) { localTokens: false, localPolicies: true, localRoles: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - _, token, err := testIdentityForToken(args.TokenID) - - if token != nil { - reply.Token = token.(*structs.ACLToken) - } - return err - }, } + delegate.tokenReadFn = delegate.plainTokenReadFn testACLResolver_variousTokens(t, delegate) } func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelegate) { t.Helper() - r := newTestACLResolver(t, delegate, nil) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLTokenTTL = 600 * time.Second + config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond + config.Config.ACLDownPolicy = "extend-cache" + }) + reset := func() { + // prevent subtest bleedover + r.cache.Purge() + delegate.Reset() + } - t.Run("Missing Identity", func(t *testing.T) { + runTwiceAndReset := func(name string, f func(t *testing.T)) { + t.Helper() + defer reset() // reset the stateful resolve AND blow away the cache + + t.Run(name+" (no-cache)", f) + delegate.Reset() // allow the stateful resolve functions to reset + t.Run(name+" (cached)", f) + } + + runTwiceAndReset("Missing Identity", func(t *testing.T) { authz, err := r.ResolveToken("doesn't exist") require.Nil(t, authz) require.Error(t, err) require.True(t, acl.IsErrNotFound(err)) }) - t.Run("Missing Policy", func(t *testing.T) { + runTwiceAndReset("Missing Policy", func(t *testing.T) { authz, err := r.ResolveToken("missing-policy") require.NoError(t, err) require.NotNil(t, authz) @@ -1502,7 +1581,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Missing Role", func(t *testing.T) { + runTwiceAndReset("Missing Role", func(t *testing.T) { authz, err := r.ResolveToken("missing-role") require.NoError(t, err) require.NotNil(t, authz) @@ -1510,7 +1589,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Missing Policy on Role", func(t *testing.T) { + runTwiceAndReset("Missing Policy on Role", func(t *testing.T) { authz, err := r.ResolveToken("missing-policy-on-role") require.NoError(t, err) require.NotNil(t, authz) @@ -1518,7 +1597,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Policy", func(t *testing.T) { + runTwiceAndReset("Normal with Policy", func(t *testing.T) { authz, err := r.ResolveToken("found") require.NotNil(t, authz) require.NoError(t, err) @@ -1526,7 +1605,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Role", func(t *testing.T) { + runTwiceAndReset("Normal with Role", func(t *testing.T) { authz, err := r.ResolveToken("found-role") require.NotNil(t, authz) require.NoError(t, err) @@ -1534,7 +1613,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Policy and Role", func(t *testing.T) { + runTwiceAndReset("Normal with Policy and Role", func(t *testing.T) { authz, err := r.ResolveToken("found-policy-and-role") require.NotNil(t, authz) require.NoError(t, err) @@ -1543,7 +1622,41 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.ServiceRead("bar")) }) - t.Run("Anonymous", func(t *testing.T) { + runTwiceAndReset("Synthetic Policies Independently Cache", func(t *testing.T) { + // We resolve both of these tokens in the same cache session + // to verify that the keys for caching synthetic policies don't bleed + // over between each other. + { + authz, err := r.ResolveToken("found-synthetic-policy-1") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service2", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service1", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + { + authz, err := r.ResolveToken("found-synthetic-policy-2") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service1", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service2", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + }) + + runTwiceAndReset("Anonymous", func(t *testing.T) { authz, err := r.ResolveToken("") require.NotNil(t, authz) require.NoError(t, err) @@ -1551,7 +1664,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("legacy-management", func(t *testing.T) { + runTwiceAndReset("legacy-management", func(t *testing.T) { authz, err := r.ResolveToken("legacy-management") require.NotNil(t, authz) require.NoError(t, err) @@ -1559,7 +1672,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.KeyRead("foo")) }) - t.Run("legacy-client", func(t *testing.T) { + runTwiceAndReset("legacy-client", func(t *testing.T) { authz, err := r.ResolveToken("legacy-client") require.NoError(t, err) require.NotNil(t, authz) diff --git a/agent/consul/acl_token_exp_test.go b/agent/consul/acl_token_exp_test.go index 20ae878afc..a851b4dc33 100644 --- a/agent/consul/acl_token_exp_test.go +++ b/agent/consul/acl_token_exp_test.go @@ -51,7 +51,7 @@ func testACLTokenReap_Primary(t *testing.T, local, global bool) { codec := rpcClient(t, s1) defer codec.Close() - acl := ACL{s1} + acl := ACL{srv: s1} masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") require.NoError(t, err) diff --git a/agent/consul/authmethod/authmethods.go b/agent/consul/authmethod/authmethods.go new file mode 100644 index 0000000000..8fd477d0f2 --- /dev/null +++ b/agent/consul/authmethod/authmethods.go @@ -0,0 +1,112 @@ +package authmethod + +import ( + "fmt" + "sort" + "sync" + + "github.com/hashicorp/consul/agent/structs" + "github.com/mitchellh/mapstructure" +) + +type ValidatorFactory func(method *structs.ACLAuthMethod) (Validator, error) + +type Validator interface { + // Name returns the name of the auth method backing this validator. + Name() string + + // ValidateLogin takes raw user-provided auth method metadata and ensures + // it is sane, provably correct, and currently valid. Relevant identifying + // data is extracted and returned for immediate use by the role binding + // process. + // + // Depending upon the method, it may make sense to use these calls to + // continue to extend the life of the underlying token. + // + // Returns auth method specific metadata suitable for the Role Binding + // process. + ValidateLogin(loginToken string) (map[string]string, error) + + // AvailableFields returns a slice of all fields that are returned as a + // result of ValidateLogin. These are valid fields for use in any + // BindingRule tied to this auth method. + AvailableFields() []string + + // MakeFieldMapSelectable converts a field map as returned by ValidateLogin + // into a structure suitable for selection with a binding rule. + MakeFieldMapSelectable(fieldMap map[string]string) interface{} +} + +var ( + typesMu sync.RWMutex + types = make(map[string]ValidatorFactory) +) + +// Register makes an auth method with the given type available for use. If +// Register is called twice with the same name or if validator is nil, it +// panics. +func Register(name string, factory ValidatorFactory) { + typesMu.Lock() + defer typesMu.Unlock() + if factory == nil { + panic("authmethod: Register factory is nil for type " + name) + } + if _, dup := types[name]; dup { + panic("authmethod: Register called twice for type " + name) + } + types[name] = factory +} + +func IsRegisteredType(typeName string) bool { + typesMu.RLock() + _, ok := types[typeName] + typesMu.RUnlock() + return ok +} + +// NewValidator instantiates a new Validator for the given auth method +// configuration. If no auth method is registered with the provided type an +// error is returned. +func NewValidator(method *structs.ACLAuthMethod) (Validator, error) { + typesMu.RLock() + factory, ok := types[method.Type] + typesMu.RUnlock() + + if !ok { + return nil, fmt.Errorf("no auth method registered with type: %s", method.Type) + } + + return factory(method) +} + +// Types returns a sorted list of the names of the registered types. +func Types() []string { + typesMu.RLock() + defer typesMu.RUnlock() + var list []string + for name := range types { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// ParseConfig parses the config block for a auth method. +func ParseConfig(rawConfig map[string]interface{}, out interface{}) error { + decodeConf := &mapstructure.DecoderConfig{ + Result: out, + WeaklyTypedInput: true, + ErrorUnused: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return err + } + + if err := decoder.Decode(rawConfig); err != nil { + return fmt.Errorf("error decoding config: %s", err) + } + + return nil +} diff --git a/agent/consul/authmethod/kubeauth/k8s.go b/agent/consul/authmethod/kubeauth/k8s.go new file mode 100644 index 0000000000..88c4b32e3d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s.go @@ -0,0 +1,202 @@ +package kubeauth + +import ( + "errors" + "fmt" + "strings" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + cleanhttp "github.com/hashicorp/go-cleanhttp" + "gopkg.in/square/go-jose.v2/jwt" + authv1 "k8s.io/api/authentication/v1" + client_metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8s "k8s.io/client-go/kubernetes" + client_authv1 "k8s.io/client-go/kubernetes/typed/authentication/v1" + client_corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + client_rest "k8s.io/client-go/rest" + cert "k8s.io/client-go/util/cert" +) + +func init() { + // register this as an available auth method type + authmethod.Register("kubernetes", func(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + v, err := NewValidator(method) + if err != nil { + return nil, err + } + return v, nil + }) +} + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" + + serviceAccountServiceNameAnnotation = "consul.hashicorp.com/service-name" +) + +type Config struct { + // Host must be a host string, a host:port pair, or a URL to the base of + // the Kubernetes API server. + Host string `json:",omitempty"` + + // PEM encoded CA cert for use by the TLS client used to talk with the + // Kubernetes API. Every line must end with a newline: \n + CACert string `json:",omitempty"` + + // A service account JWT used to access the TokenReview API to validate + // other JWTs during login. It also must be able to read ServiceAccount + // annotations. + ServiceAccountJWT string `json:",omitempty"` +} + +// Validator is the wrapper around the relevant portions of the Kubernetes API +// that also conforms to the authmethod.Validator interface. +type Validator struct { + name string + config *Config + saGetter client_corev1.ServiceAccountsGetter + trGetter client_authv1.TokenReviewsGetter +} + +func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) { + if method.Type != "kubernetes" { + return nil, fmt.Errorf("%q is not a kubernetes auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.Host == "" { + return nil, fmt.Errorf("Config.Host is required") + } + + if config.CACert == "" { + return nil, fmt.Errorf("Config.CACert is required") + } + if _, err := cert.ParseCertsPEM([]byte(config.CACert)); err != nil { + return nil, fmt.Errorf("error parsing kubernetes ca cert: %v", err) + } + + // This is the bearer token we give the apiserver to use the API. + if config.ServiceAccountJWT == "" { + return nil, fmt.Errorf("Config.ServiceAccountJWT is required") + } + if _, err := jwt.ParseSigned(config.ServiceAccountJWT); err != nil { + return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err) + } + + transport := cleanhttp.DefaultTransport() + client, err := k8s.NewForConfig(&client_rest.Config{ + Host: config.Host, + BearerToken: config.ServiceAccountJWT, + Dial: transport.DialContext, + TLSClientConfig: client_rest.TLSClientConfig{ + CAData: []byte(config.CACert), + }, + ContentConfig: client_rest.ContentConfig{ + ContentType: "application/json", + }, + }) + if err != nil { + return nil, err + } + + return &Validator{ + name: method.Name, + config: &config, + saGetter: client.CoreV1(), + trGetter: client.AuthenticationV1(), + }, nil +} + +func (v *Validator) Name() string { return v.name } + +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + if _, err := jwt.ParseSigned(loginToken); err != nil { + return nil, fmt.Errorf("failed to parse and validate JWT: %v", err) + } + + // Check TokenReview for the bulk of the work. + trResp, err := v.trGetter.TokenReviews().Create(&authv1.TokenReview{ + Spec: authv1.TokenReviewSpec{ + Token: loginToken, + }, + }) + + if err != nil { + return nil, err + } else if trResp.Status.Error != "" { + return nil, fmt.Errorf("lookup failed: %s", trResp.Status.Error) + } + + if !trResp.Status.Authenticated { + return nil, errors.New("lookup failed: service account jwt not valid") + } + + // The username is of format: system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT) + parts := strings.Split(trResp.Status.User.Username, ":") + if len(parts) != 4 { + return nil, errors.New("lookup failed: unexpected username format") + } + + // Validate the user that comes back from token review is a service account + if parts[0] != "system" || parts[1] != "serviceaccount" { + return nil, errors.New("lookup failed: username returned is not a service account") + } + + var ( + saNamespace = parts[2] + saName = parts[3] + saUID = string(trResp.Status.User.UID) + ) + + // Check to see if there is an override name on the ServiceAccount object. + sa, err := v.saGetter.ServiceAccounts(saNamespace).Get(saName, client_metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("annotation lookup failed: %v", err) + } + + annotations := sa.GetObjectMeta().GetAnnotations() + if serviceNameOverride, ok := annotations[serviceAccountServiceNameAnnotation]; ok { + saName = serviceNameOverride + } + + return map[string]string{ + serviceAccountNamespaceField: saNamespace, + serviceAccountNameField: saName, + serviceAccountUIDField: saUID, + }, nil +} + +func (p *Validator) AvailableFields() []string { + return []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, + } +} + +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &k8sFieldDetails{ + ServiceAccount: k8sFieldDetailsServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type k8sFieldDetails struct { + ServiceAccount k8sFieldDetailsServiceAccount `bexpr:"serviceaccount"` +} + +type k8sFieldDetailsServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/authmethod/kubeauth/k8s_test.go b/agent/consul/authmethod/kubeauth/k8s_test.go new file mode 100644 index 0000000000..614538c40e --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s_test.go @@ -0,0 +1,144 @@ +package kubeauth + +import ( + "testing" + + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestValidateLogin(t *testing.T) { + testSrv := StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + "ServiceAccountJWT": goodJWT_A, + }, + } + validator, err := NewValidator(method) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + _, err := validator.ValidateLogin("invalid") + require.Error(t, err) + }) + + t.Run("valid bearer token", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "demo", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token with annotation", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "alternate-name", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) +} + +func TestNewValidator(t *testing.T) { + ca := connect.TestCA(t, nil) + + type AM = *structs.ACLAuthMethod + + makeAuthMethod := func(f func(method AM)) *structs.ACLAuthMethod { + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": "https://abc:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": goodJWT_A, + }, + } + if f != nil { + f(method) + } + return method + } + + for _, test := range []struct { + name string + method *structs.ACLAuthMethod + ok bool + }{ + // bad + {"wrong type", makeAuthMethod(func(method AM) { + method.Type = "invalid" + }), false}, + {"extra config", makeAuthMethod(func(method AM) { + method.Config["extra"] = "config" + }), false}, + {"wrong type of config", makeAuthMethod(func(method AM) { + method.Config["Host"] = []int{12345} + }), false}, + {"missing host", makeAuthMethod(func(method AM) { + delete(method.Config, "Host") + }), false}, + {"missing ca cert", makeAuthMethod(func(method AM) { + delete(method.Config, "CACert") + }), false}, + {"invalid ca cert", makeAuthMethod(func(method AM) { + method.Config["CACert"] = "invalid" + }), false}, + {"invalid jwt", makeAuthMethod(func(method AM) { + method.Config["ServiceAccountJWT"] = "invalid" + }), false}, + {"garbage host", makeAuthMethod(func(method AM) { + method.Config["Host"] = "://:12345" + }), false}, + // good + {"normal", makeAuthMethod(nil), true}, + } { + t.Run(test.name, func(t *testing.T) { + v, err := NewValidator(test.method) + if test.ok { + require.NoError(t, err) + require.NotNil(t, v) + } else { + require.NotNil(t, err) + require.Nil(t, v) + } + }) + } +} + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/authmethod/kubeauth/testing.go b/agent/consul/authmethod/kubeauth/testing.go new file mode 100644 index 0000000000..7e6340dd9d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/testing.go @@ -0,0 +1,532 @@ +package kubeauth + +import ( + "bytes" + "encoding/json" + "encoding/pem" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + authv1 "k8s.io/api/authentication/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// TestAPIServer is a way to mock the Kubernetes API server as it is used by +// the consul kubernetes auth method. +// +// - POST /apis/authentication.k8s.io/v1/tokenreviews +// - GET /api/v1/namespaces//serviceaccounts/ +// +type TestAPIServer struct { + t *testing.T + srv *httptest.Server + caCert string + + mu sync.Mutex + authorizedJWT string // token review and sa read + allowedServiceAccountJWT string // general service account + replyStatus *authv1.TokenReview // general service account + replyRead *corev1.ServiceAccount // general service account +} + +// StartTestAPIServer creates a disposable TestAPIServer and binds it to a +// random free port. +func StartTestAPIServer(t *testing.T) *TestAPIServer { + s := &TestAPIServer{t: t} + + s.srv = httptest.NewTLSServer(s) + + bs := s.srv.TLS.Certificates[0].Certificate[0] + + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})) + s.caCert = buf.String() + + return s +} + +// AuthorizeJWT whitelists the given JWT as able to use the API server. +func (s *TestAPIServer) AuthorizeJWT(jwt string) { + s.mu.Lock() + defer s.mu.Unlock() + + s.authorizedJWT = jwt +} + +// SetAllowedServiceAccount configures the singular known Service Account +// installed in this API server. If any of namespace/name/uid/jwt are empty +// it removes anything previously configured. +// +// It is up to the caller to ensure that the provided JWT matches the other +// data. +func (s *TestAPIServer) SetAllowedServiceAccount( + namespace, name, uid, overrideAnnotation, jwt string, +) { + s.mu.Lock() + defer s.mu.Unlock() + + if namespace == "" || name == "" || uid == "" || jwt == "" { + s.allowedServiceAccountJWT = "" + s.replyStatus = nil + s.replyRead = nil + return + } + + s.allowedServiceAccountJWT = jwt + s.replyRead = createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt) + s.replyStatus = createTokenReviewFound(namespace, name, uid, jwt) +} + +// Stop stops the running TestAPIServer. +func (s *TestAPIServer) Stop() { + s.srv.Close() +} + +// Addr returns the current base URL for the running webserver. +func (s *TestAPIServer) Addr() string { return s.srv.URL } + +// CACert returns the pem-encoded CA certificate used by the HTTPS server. +func (s *TestAPIServer) CACert() string { return s.caCert } + +var readServiceAccountPathRE = regexp.MustCompile("^/api/v1/namespaces/([^/]+)/serviceaccounts/([^/]+)$") + +func (s *TestAPIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + w.Header().Set("content-type", "application/json") + + if req.URL.Path == "/apis/authentication.k8s.io/v1/tokenreviews" { + s.handleTokenReview(w, req) + return + } + + if m := readServiceAccountPathRE.FindStringSubmatch(req.URL.Path); m != nil { + namespace, err := url.QueryUnescape(m[1]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + name, err := url.QueryUnescape(m[2]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + s.handleReadServiceAccount(namespace, name, w, req) + return + } + + w.WriteHeader(http.StatusNotFound) +} + +func writeJSON(w http.ResponseWriter, out interface{}) error { + enc := json.NewEncoder(w) + return enc.Encode(out) +} + +func (s *TestAPIServer) handleTokenReview(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if auth, anon := s.isAuthenticated(req); !auth { + var out interface{} + if anon { + out = createTokenReviewForbidden_NoAuthz() + } else { + out = createTokenReviewForbidden("default", "fake-account") + } + + w.WriteHeader(http.StatusForbidden) + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + if req.Body == nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer req.Body.Close() + + b, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + var trReq authv1.TokenReview + if err := json.Unmarshal(b, &trReq); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + reviewingJWT := trReq.Spec.Token + + var out interface{} + if s.replyStatus == nil || reviewingJWT != s.allowedServiceAccountJWT { + out = createTokenReviewNotFound(reviewingJWT) + } else { + out = s.replyStatus + } + w.WriteHeader(http.StatusCreated) + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) handleReadServiceAccount( + namespace, name string, + w http.ResponseWriter, + req *http.Request, +) { + if req.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var out interface{} + if auth, anon := s.isAuthenticated(req); !auth { + if anon { + out = createReadServiceAccountForbidden_NoAuthz() + } else { + out = createReadServiceAccountForbidden(namespace, name) + } + w.WriteHeader(http.StatusForbidden) + } else if s.replyRead == nil { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else if s.replyRead.Namespace != namespace || s.replyRead.Name != name { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else { + out = s.replyRead + w.WriteHeader(http.StatusOK) + } + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) isAuthenticated(req *http.Request) (auth, anonymous bool) { + authz := req.Header.Get("Authorization") + if !strings.HasPrefix(authz, "Bearer ") { + return false, true + } + jwt := strings.TrimPrefix(authz, "Bearer ") + + return s.authorizedJWT == jwt, false +} + +func createTokenReviewForbidden_NoAuthz() *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:default:admin\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewNotFound(jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImZha2UtdG9rZW4tano2YnYiLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZmFrZSIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjgxYTY1Mjg2LTU3YzEtMTFlOS1iYzJhLTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmZha2UifQ.DqjUXe34SzCP4NCwbhqV9EuksfzmTSLhJzkE_URyufeGJDn-Gw0_JS-_KmxZSdAO0XXNzB1tJNM1NCVW-V6YbThnPUw5WY4V2J6U1W72c2dzNBx_ipBxGBZ632ZnpViIRu6tL2guT36lWa8YnMDF_OY8sHhl_3kJ6MRxNxY41vAuf45mohi3gri46Kpzc3pf1g6PJ-0oogvUsZ2nBFv1mIdciGBV0zejMKc5Bnxur1L-hEQ9EgZrJ7o0yQRCWYgam_yo_M38EsB8b-suTzQJMA-pRgApOb9dHIV6YAE_b3g_pGkJjrPYzV4IJC1CiPfdz1SAjm7e0ARXtZmaoPltjQ" + }, + "status": { + "user": {}, + "error": "[invalid bearer token, Token has been invalidated]" + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + User: authv1.UserInfo{}, + Error: "[invalid bearer token, Token has been invalidated]", + }, + } +} + +func createTokenReviewFound(namespace, name, uid, jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4tbTljdm4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjlmZjUxZmY0LTU1N2UtMTFlOS05Njg3LTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.UJEphtrN261gy9WCl4ZKjm2PRDLDkc3Xg9VcDGfzyroOqFQ6sog5dVAb9voc5Nc0-H5b1yGwxDViEMucwKvZpA5pi7VEx_OskK-KTWXSmafM0Xg_AvzpU9Ed5TSRno-OhXaAraxdjXoC4myh1ay2DMeHUusJg_ibqcYJrWx-6MO1bH_ObORtAKhoST_8fzkqNAlZmsQ87FinQvYN5mzDXYukl-eeRdBgQUBkWvEb-Ju6cc0-QE4sUQ4IH_fs0fUyX_xc0om0SZGWLP909FTz4V8LxV8kr6L7irxROiS1jn3Fvyc9ur1PamVf3JOPPrOyfmKbaGRiWJM32b3buQw7cg" + }, + "status": { + "authenticated": true, + "user": { + "username": "system:serviceaccount:default:demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "groups": [ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated" + ] + } + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + Authenticated: true, + User: authv1.UserInfo{ + Username: "system:serviceaccount:" + namespace + ":" + name, + UID: uid, + Groups: []string{ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated", + }, + }, + }, + } +} + +func createReadServiceAccountForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:serviceaccount:default:admin\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \""+namespace+"\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 403, + ) +} + +func createReadServiceAccountForbidden_NoAuthz() *metav1.Status { + // missing bearer token header 403 + /* + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \"PLACEHOLDER\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: "PLACEHOLDER", + }, + 403, + ) +} + +func createReadServiceAccountNotFound(namespace, name string) *metav1.Status { + /* + STATUS: 404 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" not found", + "reason": "NotFound", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 404 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" not found", + metav1.StatusReasonNotFound, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 404, + ) +} + +func createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt string) *corev1.ServiceAccount { + /* + STATUS: 200 + { + "kind": "ServiceAccount", + "apiVersion": "v1", + "metadata": { + "name": "demo", + "namespace": "default", + "selfLink": "/api/v1/namespaces/default/serviceaccounts/demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "resourceVersion": "2101", + "creationTimestamp": "2019-04-02T19:36:34Z", + "annotations": { + "consul.hashicorp.com/service-name": "actual", + "kubectl.kubernetes.io/last-applied-configuration": "{\"apiVersion\":\"v1\",\"kind\":\"ServiceAccount\",\"metadata\":{\"annotations\":{\"consul.hashicorp.com/service-name\":\"actual\"},\"name\":\"demo\",\"namespace\":\"default\"}}\n" + } + }, + "secrets": [ + { + "name": "demo-token-m9cvn" + } + ] + } + */ + sa := &corev1.ServiceAccount{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceAccount", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + SelfLink: "/api/v1/namespaces/" + namespace + "/serviceaccounts/" + name, + UID: types.UID(uid), + ResourceVersion: "123", + CreationTimestamp: metav1.Time{Time: time.Now()}, + }, + Secrets: []corev1.ObjectReference{ + corev1.ObjectReference{ + Name: name + "-token-m9cvn", + }, + }, + } + if overrideAnnotation != "" { + sa.ObjectMeta.Annotations = map[string]string{ + "consul.hashicorp.com/service-name": overrideAnnotation, + } + } + + return sa +} + +func createStatus(status, message string, reason metav1.StatusReason, details *metav1.StatusDetails, code int32) *metav1.Status { + return &metav1.Status{ + TypeMeta: metav1.TypeMeta{ + Kind: "Status", + APIVersion: "v1", + }, + ListMeta: metav1.ListMeta{}, + Status: status, + Message: message, + Reason: reason, + Details: details, + Code: code, + } +} diff --git a/agent/consul/authmethod/testauth/testing.go b/agent/consul/authmethod/testauth/testing.go new file mode 100644 index 0000000000..638450d94d --- /dev/null +++ b/agent/consul/authmethod/testauth/testing.go @@ -0,0 +1,166 @@ +package testauth + +import ( + "fmt" + "sync" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-uuid" +) + +func init() { + authmethod.Register("testing", newValidator) +} + +var ( + tokenDatabaseMu sync.Mutex + tokenDatabase map[string]map[string]map[string]string // session => token => fieldmap +) + +func StartSession() string { + sessionID, err := uuid.GenerateUUID() + if err != nil { + panic(err) + } + return sessionID +} + +func ResetSession(sessionID string) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase != nil { + delete(tokenDatabase, sessionID) + } +} + +func InstallSessionToken(sessionID string, token string, namespace, name, uid string) { + fields := map[string]string{ + serviceAccountNamespaceField: namespace, + serviceAccountNameField: name, + serviceAccountUIDField: uid, + } + + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + tokenDatabase = make(map[string]map[string]map[string]string) + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + sdb = make(map[string]map[string]string) + tokenDatabase[sessionID] = sdb + } + sdb[token] = fields +} + +func GetSessionToken(sessionID string, token string) (map[string]string, bool) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + return nil, false + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + return nil, false + } + fields, ok := sdb[token] + if !ok { + return nil, false + } + + fmCopy := make(map[string]string) + for k, v := range fields { + fmCopy[k] = v + } + + return fmCopy, true +} + +type Config struct { + SessionID string // unique identifier for this set of tokens in the database +} + +func newValidator(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if method.Type != "testing" { + return nil, fmt.Errorf("%q is not a testing auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.SessionID == "" { + // If you don't explicitly create one, we create a random one but you + // won't have access to it. Useful if you are testing everything EXCEPT + // ValidateToken(). + config.SessionID = StartSession() + } + + return &Validator{ + name: method.Name, + config: &config, + }, nil +} + +type Validator struct { + name string + config *Config +} + +func (v *Validator) Name() string { return v.name } + +// ValidateLogin takes raw user-provided auth method metadata and ensures it is +// sane, provably correct, and currently valid. Relevant identifying data is +// extracted and returned for immediate use by the role binding process. +// +// Depending upon the method, it may make sense to use these calls to continue +// to extend the life of the underlying token. +// +// Returns auth method specific metadata suitable for the Role Binding process. +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + fields, valid := GetSessionToken(v.config.SessionID, loginToken) + if !valid { + return nil, acl.ErrNotFound + } + + return fields, nil +} + +func (v *Validator) AvailableFields() []string { return availableFields } + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" +) + +var availableFields = []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, +} + +// MakeFieldMapSelectable converts a field map as returned by ValidateLogin +// into a structure suitable for selection with a binding rule. +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &selectableVars{ + ServiceAccount: selectableServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type selectableVars struct { + ServiceAccount selectableServiceAccount `bexpr:"serviceaccount"` +} + +type selectableServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 36a09174df..f093aa0abd 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -32,6 +32,10 @@ func init() { registerCommand(structs.ConfigEntryRequestType, (*FSM).applyConfigEntryOperation) registerCommand(structs.ACLRoleSetRequestType, (*FSM).applyACLRoleSetOperation) registerCommand(structs.ACLRoleDeleteRequestType, (*FSM).applyACLRoleDeleteOperation) + registerCommand(structs.ACLBindingRuleSetRequestType, (*FSM).applyACLBindingRuleSetOperation) + registerCommand(structs.ACLBindingRuleDeleteRequestType, (*FSM).applyACLBindingRuleDeleteOperation) + registerCommand(structs.ACLAuthMethodSetRequestType, (*FSM).applyACLAuthMethodSetOperation) + registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation) } func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { @@ -476,3 +480,47 @@ func (c *FSM) applyACLRoleDeleteOperation(buf []byte, index uint64) interface{} return c.state.ACLRoleBatchDelete(index, req.RoleIDs) } + +func (c *FSM) applyACLBindingRuleSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLBindingRuleBatchSet(index, req.BindingRules) +} + +func (c *FSM) applyACLBindingRuleDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLBindingRuleBatchDelete(index, req.BindingRuleIDs) +} + +func (c *FSM) applyACLAuthMethodSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLAuthMethodBatchSet(index, req.AuthMethods) +} + +func (c *FSM) applyACLAuthMethodDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLAuthMethodBatchDelete(index, req.AuthMethodNames) +} diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 3ad281434b..195e6cf136 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -29,6 +29,8 @@ func init() { registerRestorer(structs.ACLPolicySetRequestType, restorePolicy) registerRestorer(structs.ConfigEntryRequestType, restoreConfigEntry) registerRestorer(structs.ACLRoleSetRequestType, restoreRole) + registerRestorer(structs.ACLBindingRuleSetRequestType, restoreBindingRule) + registerRestorer(structs.ACLAuthMethodSetRequestType, restoreAuthMethod) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -218,6 +220,34 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, } } + rules, err := s.state.ACLBindingRules() + if err != nil { + return err + } + + for rule := rules.Next(); rule != nil; rule = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLBindingRuleSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(rule.(*structs.ACLBindingRule)); err != nil { + return err + } + } + + methods, err := s.state.ACLAuthMethods() + if err != nil { + return err + } + + for method := methods.Next(); method != nil; method = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLAuthMethodSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(method.(*structs.ACLAuthMethod)); err != nil { + return err + } + } + return nil } @@ -626,3 +656,19 @@ func restoreRole(header *snapshotHeader, restore *state.Restore, decoder *codec. } return restore.ACLRole(&req) } + +func restoreBindingRule(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLBindingRule + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLBindingRule(&req) +} + +func restoreAuthMethod(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLAuthMethod + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLAuthMethod(&req) +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index b8b5bb1327..9571b5f1a1 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -86,7 +86,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { }) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) - policy := structs.ACLPolicy{ + policy := &structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, Name: "global-management", Description: "Builtin Policy that grants unlimited access", @@ -94,7 +94,20 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { Syntax: acl.SyntaxCurrent, } policy.SetHash(true) - require.NoError(fsm.state.ACLPolicySet(1, &policy)) + require.NoError(fsm.state.ACLPolicySet(1, policy)) + + role := &structs.ACLRole{ + ID: "86dedd19-8fae-4594-8294-4e6948a81f9a", + Name: "some-role", + Description: "test snapshot role", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + }, + }, + } + role.SetHash(true) + require.NoError(fsm.state.ACLRoleSet(1, role)) token := &structs.ACLToken{ AccessorID: "30fca056-9fbb-4455-b94a-bf0e2bc575d6", @@ -112,6 +125,26 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } require.NoError(fsm.state.ACLBootstrap(10, 0, token, false)) + method := &structs.ACLAuthMethod{ + Name: "some-method", + Type: "testing", + Description: "test snapshot auth method", + Config: map[string]interface{}{ + "SessionID": "952ebfa8-2a42-46f0-bcd3-fd98a842000e", + }, + } + require.NoError(fsm.state.ACLAuthMethodSet(1, method)) + + bindingRule := &structs.ACLBindingRule{ + ID: "85184c52-5997-4a84-9817-5945f2632a17", + Description: "test snapshot binding rule", + AuthMethod: "some-method", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + require.NoError(fsm.state.ACLBindingRuleSet(1, bindingRule)) + fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", Value: []byte("foo"), @@ -314,21 +347,40 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Verify ACL Token is restored - _, a, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + // Verify ACL Binding Rule is restored + _, bindingRule2, err := fsm2.state.ACLBindingRuleGetByID(nil, bindingRule.ID) require.NoError(err) - require.Equal(token.AccessorID, a.AccessorID) - require.Equal(token.ModifyIndex, a.ModifyIndex) + require.Equal(bindingRule, bindingRule2) + + // Verify ACL Auth Method is restored + _, method2, err := fsm2.state.ACLAuthMethodGetByName(nil, method.Name) + require.NoError(err) + require.Equal(method, method2) + + // Verify ACL Token is restored + _, token2, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(err) + { + // time.Time is tricky to compare generically when it takes a ser/deserialization round trip. + require.True(token.CreateTime.Equal(token2.CreateTime)) + token2.CreateTime = token.CreateTime + } + require.Equal(token, token2) // Verify the acl-token-bootstrap index was restored canBootstrap, index, err := fsm2.state.CanBootstrapACLToken() require.False(canBootstrap) require.True(index > 0) + // Verify ACL Role is restored + _, role2, err := fsm2.state.ACLRoleGetByID(nil, role.ID) + require.NoError(err) + require.Equal(role, role2) + // Verify ACL Policy is restored _, policy2, err := fsm2.state.ACLPolicyGetByID(nil, structs.ACLPolicyGlobalManagementID) require.NoError(err) - require.Equal(policy.Name, policy2.Name) + require.Equal(policy, policy2) // Verify tombstones are restored func() { diff --git a/agent/consul/leader.go b/agent/consul/leader.go index cbef94da47..6f7799cfe6 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -427,6 +427,10 @@ func (s *Server) initializeACLs(upgrade bool) error { // leader. s.acls.cache.Purge() + // Purge the auth method validators since they could've changed while we + // were not leader. + s.purgeAuthMethodValidators() + // Remove any token affected by CVE-2019-8336 if !s.InACLDatacenter() { _, token, err := s.fsm.State().ACLTokenGetBySecret(nil, redactedToken) diff --git a/agent/consul/server.go b/agent/consul/server.go index b44b417816..d76a46a9b8 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -115,6 +115,9 @@ type Server struct { aclTokenReapLock sync.RWMutex aclTokenReapEnabled bool + aclAuthMethodValidators map[string]*authMethodValidatorEntry + aclAuthMethodValidatorLock sync.RWMutex + // DEPRECATED (ACL-Legacy-Compat) - only needed while we support both // useNewACLs is used to determine whether we can use new ACLs or not useNewACLs int32 diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 8ff538922a..1249844c86 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -240,12 +240,20 @@ func tokensTableSchema() *memdb.TableSchema { Indexer: &TokenPoliciesIndex{}, }, "roles": &memdb.IndexSchema{ - Name: "roles", - // Need to allow missing for the anonymous token + Name: "roles", AllowMissing: true, Unique: false, Indexer: &TokenRolesIndex{}, }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: false, + }, + }, "local": &memdb.IndexSchema{ Name: "local", AllowMissing: false, @@ -349,10 +357,54 @@ func rolesTableSchema() *memdb.TableSchema { } } +func bindingRulesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-binding-rules", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: true, + }, + }, + }, + } +} + +func authMethodsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-auth-methods", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + }, + } +} + func init() { registerSchema(tokensTableSchema) registerSchema(policiesTableSchema) registerSchema(rolesTableSchema) + registerSchema(bindingRulesTableSchema) + registerSchema(authMethodsTableSchema) } // ACLTokens is used when saving a snapshot @@ -417,6 +469,46 @@ func (s *Restore) ACLRole(role *structs.ACLRole) error { return nil } +// ACLBindingRules is used when saving a snapshot +func (s *Snapshot) ACLBindingRules() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-binding-rules", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLBindingRule(rule *structs.ACLBindingRule) error { + if err := s.tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed restoring acl binding rule: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, rule.ModifyIndex, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// ACLAuthMethods is used when saving a snapshot +func (s *Snapshot) ACLAuthMethods() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-auth-methods", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLAuthMethod(method *structs.ACLAuthMethod) error { + if err := s.tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed restoring acl auth method: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, method.ModifyIndex, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + // ACLBootstrap is used to perform a one-time ACL bootstrap operation on a // cluster to get the first management token. func (s *Store) ACLBootstrap(idx, resetIndex uint64, token *structs.ACLToken, legacy bool) error { @@ -789,6 +881,15 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke return err } + if token.AuthMethod != "" { + method, err := s.getAuthMethodWithTxn(tx, nil, token.AuthMethod, "id") + if err != nil { + return err + } else if method == nil { + return fmt.Errorf("No such auth method with Name: %s", token.AuthMethod) + } + } + for _, svcid := range token.ServiceIdentities { if svcid.ServiceName == "" { return fmt.Errorf("Encountered a Token with an empty service identity name in the state store") @@ -890,7 +991,7 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st } // ACLTokenList is used to list out all of the ACLs in the state store. -func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role string) (uint64, structs.ACLTokens, error) { +func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role, methodName string) (uint64, structs.ACLTokens, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -901,57 +1002,53 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role // to false but for defaulted structs (zero values for both) we want it to list out // all tokens so our checks just ensure that global == local - if policy != "" && role != "" { - return 0, nil, fmt.Errorf("cannot filter by role and policy at the same time") - } + needLocalityFilter := false + if policy == "" && role == "" && methodName == "" { + if global == local { + iter, err = tx.Get("acl-tokens", "id") + } else if global { + iter, err = tx.Get("acl-tokens", "local", false) + } else { + iter, err = tx.Get("acl-tokens", "local", true) + } - if policy != "" { + } else if policy != "" && role == "" && methodName == "" { iter, err = tx.Get("acl-tokens", "policies", policy) - if err == nil && global != local { - iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { - token, ok := raw.(*structs.ACLToken) - if !ok { - return true - } + needLocalityFilter = true - if global && !token.Local { - return false - } else if local && token.Local { - return false - } - - return true - }) - } - } else if role != "" { + } else if policy == "" && role != "" && methodName == "" { iter, err = tx.Get("acl-tokens", "roles", role) - if err == nil && global != local { - iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { - token, ok := raw.(*structs.ACLToken) - if !ok { - return true - } + needLocalityFilter = true - if global && !token.Local { - return false - } else if local && token.Local { - return false - } + } else if policy == "" && role == "" && methodName != "" { + iter, err = tx.Get("acl-tokens", "authmethod", methodName) + needLocalityFilter = true - return true - }) - } - } else if global == local { - iter, err = tx.Get("acl-tokens", "id") - } else if global { - iter, err = tx.Get("acl-tokens", "local", false) } else { - iter, err = tx.Get("acl-tokens", "local", true) + return 0, nil, fmt.Errorf("can only filter by one of policy, role, or methodName at a time") } if err != nil { return 0, nil, fmt.Errorf("failed acl token lookup: %v", err) } + + if needLocalityFilter && global != local { + iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { + token, ok := raw.(*structs.ACLToken) + if !ok { + return true + } + + if global && !token.Local { + return false + } else if local && token.Local { + return false + } + + return true + }) + } + ws.Add(iter.WatchCh()) var result structs.ACLTokens @@ -1114,6 +1211,35 @@ func (s *Store) aclTokenDeleteTxn(tx *memdb.Txn, idx uint64, value, index string return nil } +func (s *Store) aclTokenDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-tokens", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl token lookup: %v", err) + } + + var tokens structs.ACLTokens + for raw := iter.Next(); raw != nil; raw = iter.Next() { + token := raw.(*structs.ACLToken) + tokens = append(tokens, token) + } + + if len(tokens) > 0 { + // delete them all + for _, token := range tokens { + if err := tx.Delete("acl-tokens", token); err != nil { + return fmt.Errorf("failed deleting acl token: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-tokens"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + func (s *Store) ACLPolicyBatchSet(idx uint64, policies structs.ACLPolicies) error { tx := s.db.Txn(true) defer tx.Abort() @@ -1437,7 +1563,7 @@ func (s *Store) ACLRoleBatchGet(ws memdb.WatchSet, ids []string) (uint64, struct tx := s.db.Txn(false) defer tx.Abort() - roles := make(structs.ACLRoles, 0) + roles := make(structs.ACLRoles, 0, len(ids)) for _, rid := range ids { role, err := s.getRoleWithTxn(tx, ws, rid, "id") if err != nil { @@ -1579,3 +1705,384 @@ func (s *Store) aclRoleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) } return nil } + +func (s *Store) ACLBindingRuleBatchSet(idx uint64, rules structs.ACLBindingRules) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, rule := range rules { + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLBindingRuleSet(idx uint64, rule *structs.ACLBindingRule) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleSetTxn(tx *memdb.Txn, idx uint64, rule *structs.ACLBindingRule) error { + // Check that the ID and AuthMethod are set + if rule.ID == "" { + return ErrMissingACLBindingRuleID + } else if rule.AuthMethod == "" { + return ErrMissingACLBindingRuleAuthMethod + } + + existing, err := tx.First("acl-binding-rules", "id", rule.ID) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + // Set the indexes + if existing != nil { + rule.CreateIndex = existing.(*structs.ACLBindingRule).CreateIndex + rule.ModifyIndex = idx + } else { + rule.CreateIndex = idx + rule.ModifyIndex = idx + } + + if method, err := tx.First("acl-auth-methods", "id", rule.AuthMethod); err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } else if method == nil { + return fmt.Errorf("failed inserting acl binding rule: auth method not found") + } + + if err := tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed inserting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) ACLBindingRuleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLBindingRule, error) { + return s.aclBindingRuleGet(ws, id, "id") +} + +func (s *Store) aclBindingRuleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLBindingRule, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + watchCh, rawRule, err := tx.FirstWatch("acl-binding-rules", index, value) + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(watchCh) + + var rule *structs.ACLBindingRule + if rawRule != nil { + rule = rawRule.(*structs.ACLBindingRule) + } + + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, rule, nil +} + +func (s *Store) ACLBindingRuleList(ws memdb.WatchSet, methodName string) (uint64, structs.ACLBindingRules, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var ( + iter memdb.ResultIterator + err error + ) + if methodName != "" { + iter, err = tx.Get("acl-binding-rules", "authmethod", methodName) + } else { + iter, err = tx.Get("acl-binding-rules", "id") + } + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + result = append(result, rule) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, result, nil +} + +func (s *Store) ACLBindingRuleDeleteByID(idx uint64, id string) error { + return s.aclBindingRuleDelete(idx, id, "id") +} + +func (s *Store) ACLBindingRuleBatchDelete(idx uint64, bindingRuleIDs []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, bindingRuleID := range bindingRuleIDs { + s.aclBindingRuleDeleteTxn(tx, idx, bindingRuleID, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing binding rule + rawRule, err := tx.First("acl-binding-rules", index, value) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + if rawRule == nil { + return nil + } + + rule := rawRule.(*structs.ACLBindingRule) + + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) aclBindingRuleDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-binding-rules", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + var rules structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + rules = append(rules, rule) + } + + if len(rules) > 0 { + // delete them all + for _, rule := range rules { + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + +func (s *Store) ACLAuthMethodBatchSet(idx uint64, methods structs.ACLAuthMethods) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, method := range methods { + // this is only used when doing batch insertions for upgrades and replication. Therefore + // we take whatever those said. + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLAuthMethodSet(idx uint64, method *structs.ACLAuthMethod) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodSetTxn(tx *memdb.Txn, idx uint64, method *structs.ACLAuthMethod) error { + // Check that the Name and Type are set + if method.Name == "" { + return ErrMissingACLAuthMethodName + } else if method.Type == "" { + return ErrMissingACLAuthMethodType + } + + existing, err := tx.First("acl-auth-methods", "id", method.Name) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + // Set the indexes + if existing != nil { + method.CreateIndex = existing.(*structs.ACLAuthMethod).CreateIndex + method.ModifyIndex = idx + } else { + method.CreateIndex = idx + method.ModifyIndex = idx + } + + if err := tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed inserting acl auth method: %v", err) + } + return nil +} + +func (s *Store) ACLAuthMethodGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLAuthMethod, error) { + return s.aclAuthMethodGet(ws, name, "id") +} + +func (s *Store) aclAuthMethodGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLAuthMethod, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + method, err := s.getAuthMethodWithTxn(tx, ws, value, index) + if err != nil { + return 0, nil, err + } + + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, method, nil +} + +func (s *Store) getAuthMethodWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLAuthMethod, error) { + watchCh, rawMethod, err := tx.FirstWatch("acl-auth-methods", index, value) + if err != nil { + return nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(watchCh) + + if rawMethod != nil { + return rawMethod.(*structs.ACLAuthMethod), nil + } + + return nil, nil +} + +func (s *Store) ACLAuthMethodList(ws memdb.WatchSet) (uint64, structs.ACLAuthMethods, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("acl-auth-methods", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLAuthMethods + for raw := iter.Next(); raw != nil; raw = iter.Next() { + method := raw.(*structs.ACLAuthMethod) + result = append(result, method) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, result, nil +} + +func (s *Store) ACLAuthMethodDeleteByName(idx uint64, name string) error { + return s.aclAuthMethodDelete(idx, name, "id") +} + +func (s *Store) ACLAuthMethodBatchDelete(idx uint64, names []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, name := range names { + s.aclAuthMethodDeleteTxn(tx, idx, name, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing method + rawMethod, err := tx.First("acl-auth-methods", index, value) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + if rawMethod == nil { + return nil + } + + method := rawMethod.(*structs.ACLAuthMethod) + + if err := s.aclBindingRuleDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := s.aclTokenDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := tx.Delete("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed deleting acl auth method: %v", err) + } + return nil +} diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index b561e1e54c..7dddbaccb7 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -1,6 +1,7 @@ package state import ( + "fmt" "math/rand" "strconv" "testing" @@ -53,6 +54,17 @@ func testACLStateStore(t *testing.T) *Store { return s } +func setupExtraAuthMethods(t *testing.T, s *Store) { + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) +} + func setupExtraPolicies(t *testing.T, s *Store) { policies := structs.ACLPolicies{ &structs.ACLPolicy{ @@ -205,7 +217,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { require.Equal(t, uint64(3), index) // Make sure the ACLs are in an expected state. - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 1) compareTokens(t, token1, tokens[0]) @@ -219,7 +231,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { err = s.ACLBootstrap(32, index, token2.Clone(), false) require.NoError(t, err) - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 2) } @@ -447,6 +459,19 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Error(t, err) }) + t.Run("Unresolvable AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("New", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -543,6 +568,37 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Len(t, rtoken.ServiceIdentities, 1) require.Equal(t, "db", rtoken.ServiceIdentities[0].ServiceName) }) + + t.Run("New with auth method", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) + + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) + + idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + compareTokens(t, token, rtoken) + require.Equal(t, uint64(2), rtoken.CreateIndex) + require.Equal(t, uint64(2), rtoken.ModifyIndex) + require.Equal(t, "test", rtoken.AuthMethod) + require.Len(t, rtoken.Policies, 0) + require.Len(t, rtoken.ServiceIdentities, 0) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) + }) } func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { @@ -828,6 +884,7 @@ func TestStateStore_ACLTokens_ListUpgradeable(t *testing.T) { func TestStateStore_ACLToken_List(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) tokens := structs.ACLTokens{ // the local token @@ -893,118 +950,167 @@ func TestStateStore_ACLToken_List(t *testing.T) { }, Local: true, }, + // the method specific token + &structs.ACLToken{ + AccessorID: "74277ae1-6a9b-4035-b444-2370fe6a2cb5", + SecretID: "ab8ac834-0d35-4cb7-83c3-168203f986cd", + AuthMethod: "test", + }, + // the method specific token and local + &structs.ACLToken{ + AccessorID: "211f0360-ef53-41d3-9d4d-db84396eb6c0", + SecretID: "087a0eb4-366f-4190-ab4c-a4aa3d2562aa", + AuthMethod: "test", + Local: true, + }, } require.NoError(t, s.ACLTokenBatchSet(2, tokens, false)) type testCase struct { - name string - local bool - global bool - policy string - role string - accessors []string + name string + local bool + global bool + policy string + role string + methodName string + accessors []string } cases := []testCase{ { - name: "Global", - local: false, - global: true, - policy: "", - role: "", + name: "Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { - name: "Local", - local: true, - global: false, - policy: "", - role: "", + name: "Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "", accessors: []string{ + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, { - name: "Policy", - local: true, - global: true, - policy: testPolicyID_A, - role: "", + name: "Policy", + local: true, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Local", - local: true, - global: false, - policy: testPolicyID_A, - role: "", + name: "Policy - Local", + local: true, + global: false, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Global", - local: false, - global: true, - policy: testPolicyID_A, - role: "", + name: "Policy - Global", + local: false, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global }, }, { - name: "Role", - local: true, - global: true, - policy: "", - role: testRoleID_A, + name: "Role", + local: true, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local }, }, { - name: "Role - Local", - local: true, - global: false, - policy: "", - role: testRoleID_A, + name: "Role - Local", + local: true, + global: false, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local }, }, { - name: "Role - Global", - local: false, - global: true, - policy: "", - role: testRoleID_A, + name: "Role - Global", + local: false, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { - name: "All", - local: true, - global: true, - policy: "", - role: "", + name: "AuthMethod - Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local + }, + }, + { + name: "AuthMethod - Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global + }, + }, + { + name: "All", + local: true, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local @@ -1012,16 +1118,23 @@ func TestStateStore_ACLToken_List(t *testing.T) { }, } - t.Run("can't filter on both", func(t *testing.T) { - _, _, err := s.ACLTokenList(nil, false, false, testPolicyID_A, testRoleID_A) - require.Error(t, err) - }) + for _, tc := range []struct{ policy, role, methodName string }{ + {testPolicyID_A, testRoleID_A, "test"}, + {"", testRoleID_A, "test"}, + {testPolicyID_A, "", "test"}, + {testPolicyID_A, testRoleID_A, ""}, + } { + t.Run(fmt.Sprintf("can't filter on more than one: %s/%s/%s", tc.policy, tc.role, tc.methodName), func(t *testing.T) { + _, _, err := s.ACLTokenList(nil, false, false, tc.policy, tc.role, tc.methodName) + require.Error(t, err) + }) + } for _, tc := range cases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role) + _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role, tc.methodName) require.NoError(t, err) require.Len(t, tokens, len(tc.accessors)) tokens.Sort() @@ -1082,7 +1195,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found := false @@ -1126,7 +1239,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Len(t, retrieved.Policies, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found = false @@ -1211,7 +1324,7 @@ func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { require.Equal(t, "node-read-role-renamed", retrieved.Roles[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found := false @@ -1255,7 +1368,7 @@ func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { require.Len(t, retrieved.Roles, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found = false @@ -2504,6 +2617,730 @@ func TestStateStore_ACLRole_Delete(t *testing.T) { }) } +func TestStateStore_ACLAuthMethod_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in Name+Type. + + t.Run("Missing Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "", + Type: "testing", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("Missing Type", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(3, &method)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "test", rmethod.Description) + require.Equal(t, uint64(3), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Create the initial method + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + // Now make sure we can update it + update := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + } + + require.NoError(t, s.ACLAuthMethodSet(3, &update)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "modified", rmethod.Description) + require.Equal(t, update.Config, rmethod.Config) + require.Equal(t, uint64(2), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethods_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-1", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, methods, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Seed initial data. + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Update two methods at the same time. + updates := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8444", + }, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(3, updates)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, updates, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(3), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(3), rmethods[1].ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethod_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + + require.Len(t, rmethods, 2) + rmethods.Sort() + + require.Equal(t, "test-1", rmethods[0].Name) + require.Equal(t, "testing", rmethods[0].Type) + require.Equal(t, "test-1", rmethods[0].Description) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + + require.Equal(t, "test-2", rmethods[1].Name) + require.Equal(t, "testing", rmethods[1].Type) + require.Equal(t, "test-2", rmethods[1].Description) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) +} + +func TestStateStore_ACLAuthMethod_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test")) + require.NoError(t, err) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.NotNil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodBatchDelete(3, []string{"test-1", "test-2"})) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant methods is not an error + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "not-found")) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestStateStore_ACLAuthMethod_Delete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + const ( + method1_rule1 = "dff6f8a3-0115-4b22-8661-04a497ebb23c" + method1_rule2 = "69e2d304-703d-4889-bd94-4a720c061fc3" + method2_rule1 = "997ee45c-d6ba-4da1-a98e-aaa012e7d1e2" + method2_rule2 = "9ebae132-f1f1-4b72-b1d9-a4313ac22075" + ) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: method1_rule1, + AuthMethod: "test-1", + Description: "test-m1-r1", + }, + &structs.ACLBindingRule{ + ID: method1_rule2, + AuthMethod: "test-1", + Description: "test-m1-r2", + }, + &structs.ACLBindingRule{ + ID: method2_rule1, + AuthMethod: "test-2", + Description: "test-m2-r1", + }, + &structs.ACLBindingRule{ + ID: method2_rule2, + AuthMethod: "test-2", + Description: "test-m2-r2", + }, + } + require.NoError(t, s.ACLBindingRuleBatchSet(3, rules)) + + const ( // accessors + method1_tok1 = "6d020c5d-c4fd-4348-ba79-beac37ed0b9c" + method1_tok2 = "169160dc-34ab-45c6-aba7-ff65e9ace9cb" + method2_tok1 = "8e14628e-7dde-4573-aca1-6386c0f2095d" + method2_tok2 = "291e5af9-c68e-4dd3-8824-b2bdfdcc89e6" + ) + + tokens := structs.ACLTokens{ + &structs.ACLToken{ + AccessorID: method1_tok1, + SecretID: "7a1950c6-79dc-441c-acd2-e22cd3db0240", + Description: "test-m1-t1", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method1_tok2, + SecretID: "442cee4c-353f-4957-adbb-33db2f9e267f", + Description: "test-m1-t2", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method2_tok1, + SecretID: "d9399b7d-6c34-46bd-a675-c1352fadb6fd", + Description: "test-m2-t1", + AuthMethod: "test-2", + }, + &structs.ACLToken{ + AccessorID: method2_tok2, + SecretID: "3b72fc27-9230-42ab-a1e8-02cb489ab177", + Description: "test-m2-t2", + AuthMethod: "test-2", + }, + } + require.NoError(t, s.ACLTokenBatchSet(4, tokens, false)) + + // Delete one method. + require.NoError(t, s.ACLAuthMethodDeleteByName(4, "test-1")) + + // Make sure the method is gone. + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + + // Make sure the rules and tokens are gone. + for _, ruleID := range []string{method1_rule1, method1_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.Nil(t, rrule) + } + for _, tokID := range []string{method1_tok1, method1_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.Nil(t, tok) + } + + // Make sure the rules and tokens for the untouched method are still there. + for _, ruleID := range []string{method2_rule1, method2_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.NotNil(t, rrule) + } + for _, tokID := range []string{method2_tok1, method2_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.NotNil(t, tok) + } +} + +func TestStateStore_ACLBindingRule_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in ID+AuthMethod. + + t.Run("Missing ID", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "", + AuthMethod: "test", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Missing AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Unknown AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "unknown", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &rule)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "test", rrule.Description) + require.Equal(t, uint64(3), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Create the initial rule + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + // Now make sure we can update it + update := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &update)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "modified", rrule.Description) + require.Equal(t, structs.BindingRuleBindTypeService, rrule.BindType) + require.Equal(t, "web", rrule.BindName) + require.Equal(t, uint64(2), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRules_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, rules, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Seed initial data. + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Update two rules at the same time. + updates := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(3, updates)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, updates, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(3), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(3), rrules[1].ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRule_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrules, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + + require.Len(t, rrules, 2) + rrules.Sort() + + require.Equal(t, "3ebcc27b-f8ba-4611-b385-79a065dfb983", rrules[0].ID) + require.Equal(t, "test", rrules[0].AuthMethod) + require.Equal(t, "test-1", rrules[0].Description) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + + require.Equal(t, "9669b2d7-455c-4d70-b0ac-457fd7969a2e", rrules[1].ID) + require.Equal(t, "test", rrules[1].AuthMethod) + require.Equal(t, "test-2", rrules[1].Description) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) +} + +func TestStateStore_ACLBindingRule_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rule.ID)) + require.NoError(t, err) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleBatchDelete(3, []string{rules[0].ID, rules[1].ID})) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.Nil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant rules is not an error + require.NoError(t, s.ACLBindingRuleDeleteByID(3, "ed3ce1b8-3a16-4e2f-b82e-f92e3b92410d")) + }) +} + func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { s := testStateStore(t) @@ -2651,7 +3488,7 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLRoleBatchSet(2, roles)) // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLTokenList(nil, true, true, "", "") + idx, res, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Equal(t, uint64(4), idx) require.ElementsMatch(t, tokens, res) @@ -2991,3 +3828,120 @@ func TestStateStore_ACLRoles_Snapshot_Restore(t *testing.T) { require.Equal(t, uint64(2), s.maxIndex("acl-roles")) }() } + +func TestStateStore_ACLAuthMethods_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test-1")) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLAuthMethods() + require.NoError(t, err) + + var dump structs.ACLAuthMethods + for method := iter.Next(); method != nil; method = iter.Next() { + dump = append(dump, method.(*structs.ACLAuthMethod)) + } + require.ElementsMatch(t, dump, methods) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, method := range dump { + require.NoError(t, restore.ACLAuthMethod(method)) + } + restore.Commit() + + // Read the restored methods back out and verify that they match. + idx, res, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, methods, res) + require.Equal(t, uint64(2), s.maxIndex("acl-auth-methods")) + }() +} + +func TestStateStore_ACLBindingRules_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rules[0].ID)) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLBindingRules() + require.NoError(t, err) + + var dump structs.ACLBindingRules + for rule := iter.Next(); rule != nil; rule = iter.Next() { + dump = append(dump, rule.(*structs.ACLBindingRule)) + } + require.ElementsMatch(t, dump, rules) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + setupExtraAuthMethods(t, s) + + restore := s.Restore() + for _, rule := range dump { + require.NoError(t, restore.ACLBindingRule(rule)) + } + restore.Commit() + + // Read the restored rules back out and verify that they match. + idx, res, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, rules, res) + require.Equal(t, uint64(2), s.maxIndex("acl-binding-rules")) + }() +} diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index 4dcc74ddde..f38bc42da8 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -37,17 +37,29 @@ var ( // policy with an empty Name. ErrMissingACLPolicyName = errors.New("Missing ACL Policy Name") - // ErrMissingACLRoleID is returned when an role set is called on + // ErrMissingACLRoleID is returned when a role set is called on // a role with an empty ID. ErrMissingACLRoleID = errors.New("Missing ACL Role ID") - // ErrMissingACLRoleName is returned when an role set is called on + // ErrMissingACLRoleName is returned when a role set is called on // a role with an empty Name. ErrMissingACLRoleName = errors.New("Missing ACL Role Name") - // ErrInvalidACLRoleName is returned when an role set is called on - // a role with an invalid Name. - ErrInvalidACLRoleName = errors.New("Invalid ACL Role Name") + // ErrMissingACLBindingRuleID is returned when a binding rule set + // is called on a binding rule with an empty ID. + ErrMissingACLBindingRuleID = errors.New("Missing ACL Binding Rule ID") + + // ErrMissingACLBindingRuleAuthMethod is returned when a binding rule set + // is called on a binding rule with an empty AuthMethod. + ErrMissingACLBindingRuleAuthMethod = errors.New("Missing ACL Binding Rule Auth Method") + + // ErrMissingACLAuthMethodName is returned when an auth method set is + // called on an auth method with an empty Name. + ErrMissingACLAuthMethodName = errors.New("Missing ACL Auth Method Name") + + // ErrMissingACLAuthMethodType is returned when an auth method set is + // called on an auth method with an empty Type. + ErrMissingACLAuthMethodType = errors.New("Missing ACL Auth Method Type") // ErrMissingQueryID is returned when a Query set is called on // a Query with an empty ID. diff --git a/agent/consul/util.go b/agent/consul/util.go index cb0134240a..19ff2cc2b2 100644 --- a/agent/consul/util.go +++ b/agent/consul/util.go @@ -6,10 +6,13 @@ import ( "net" "runtime" "strconv" + "strings" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-version" + "github.com/hashicorp/hil" + "github.com/hashicorp/hil/ast" "github.com/hashicorp/serf/serf" ) @@ -322,3 +325,42 @@ func ServersGetACLMode(members []serf.Member, leader string, datacenter string) return } + +// InterpolateHIL processes the string as if it were HIL and interpolates only +// the provided string->string map as possible variables. +func InterpolateHIL(s string, vars map[string]string) (string, error) { + if strings.Index(s, "${") == -1 { + // Skip going to the trouble of parsing something that has no HIL. + return s, nil + } + + tree, err := hil.Parse(s) + if err != nil { + return "", err + } + + vm := make(map[string]ast.Variable) + for k, v := range vars { + vm[k] = ast.Variable{ + Type: ast.TypeString, + Value: v, + } + } + + config := &hil.EvalConfig{ + GlobalScope: &ast.BasicScope{ + VarMap: vm, + }, + } + + result, err := hil.Eval(tree, config) + if err != nil { + return "", err + } + + if result.Type != hil.TypeString { + return "", fmt.Errorf("generated unexpected hil type: %s", result.Type) + } + + return result.Value.(string), nil +} diff --git a/agent/consul/util_test.go b/agent/consul/util_test.go index b0c6e04d30..654673b628 100644 --- a/agent/consul/util_test.go +++ b/agent/consul/util_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/go-version" "github.com/hashicorp/serf/serf" + "github.com/stretchr/testify/require" ) func TestGetPrivateIP(t *testing.T) { @@ -403,3 +404,133 @@ func TestServersMeetMinimumVersion(t *testing.T) { } } } + +func TestInterpolateHIL(t *testing.T) { + for _, test := range []struct { + name string + in string + vars map[string]string + exp string + ok bool + }{ + // valid HIL + { + "empty", + "", + map[string]string{}, + "", + true, + }, + { + "no vars", + "nothing", + map[string]string{}, + "nothing", + true, + }, + { + "just var", + "${item}", + map[string]string{"item": "value"}, + "value", + true, + }, + { + "var in middle", + "before ${item}after", + map[string]string{"item": "value"}, + "before valueafter", + true, + }, + { + "two vars", + "before ${item}after ${more}", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter xyz", + true, + }, + { + "missing map val", + "${item}", + map[string]string{"item": ""}, + "", + true, + }, + // "weird" HIL, but not technically invalid + { + "just end", + "}", + map[string]string{}, + "}", + true, + }, + { + "var without start", + " item }", + map[string]string{"item": "value"}, + " item }", + true, + }, + { + "two vars missing second start", + "before ${ item }after more }", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter more }", + true, + }, + // invalid HIL + { + "just start", + "${", + map[string]string{}, + "", + false, + }, + { + "backwards", + "}${", + map[string]string{}, + "", + false, + }, + { + "no varname", + "${}", + map[string]string{}, + "", + false, + }, + { + "missing map key", + "${item}", + map[string]string{}, + "", + false, + }, + { + "var without end", + "${ item ", + map[string]string{"item": "value"}, + "", + false, + }, + { + "two vars missing first end", + "before ${ item after ${ more }", + map[string]string{"item": "value", "more": "xyz"}, + "", + false, + }, + } { + t.Run(test.name, func(t *testing.T) { + out, err := InterpolateHIL(test.in, test.vars) + if test.ok { + require.NoError(t, err) + require.Equal(t, test.exp, out) + } else { + require.NotNil(t, err) + require.Equal(t, out, "") + } + }) + } +} diff --git a/agent/http_oss.go b/agent/http_oss.go index 6a0d5917bc..a4584a5a49 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -10,6 +10,8 @@ func init() { registerEndpoint("/v1/acl/info/", []string{"GET"}, (*HTTPServer).ACLGet) registerEndpoint("/v1/acl/clone/", []string{"PUT"}, (*HTTPServer).ACLClone) registerEndpoint("/v1/acl/list", []string{"GET"}, (*HTTPServer).ACLList) + registerEndpoint("/v1/acl/login", []string{"POST"}, (*HTTPServer).ACLLogin) + registerEndpoint("/v1/acl/logout", []string{"POST"}, (*HTTPServer).ACLLogout) registerEndpoint("/v1/acl/replication", []string{"GET"}, (*HTTPServer).ACLReplicationStatus) registerEndpoint("/v1/acl/policies", []string{"GET"}, (*HTTPServer).ACLPolicyList) registerEndpoint("/v1/acl/policy", []string{"PUT"}, (*HTTPServer).ACLPolicyCreate) @@ -18,6 +20,12 @@ func init() { registerEndpoint("/v1/acl/role", []string{"PUT"}, (*HTTPServer).ACLRoleCreate) registerEndpoint("/v1/acl/role/name/", []string{"GET"}, (*HTTPServer).ACLRoleReadByName) registerEndpoint("/v1/acl/role/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLRoleCRUD) + registerEndpoint("/v1/acl/binding-rules", []string{"GET"}, (*HTTPServer).ACLBindingRuleList) + registerEndpoint("/v1/acl/binding-rule", []string{"PUT"}, (*HTTPServer).ACLBindingRuleCreate) + registerEndpoint("/v1/acl/binding-rule/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLBindingRuleCRUD) + registerEndpoint("/v1/acl/auth-methods", []string{"GET"}, (*HTTPServer).ACLAuthMethodList) + registerEndpoint("/v1/acl/auth-method", []string{"PUT"}, (*HTTPServer).ACLAuthMethodCreate) + registerEndpoint("/v1/acl/auth-method/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLAuthMethodCRUD) registerEndpoint("/v1/acl/rules/translate", []string{"POST"}, (*HTTPServer).ACLRulesTranslate) registerEndpoint("/v1/acl/rules/translate/", []string{"GET"}, (*HTTPServer).ACLRulesTranslateLegacyToken) registerEndpoint("/v1/acl/tokens", []string{"GET"}, (*HTTPServer).ACLTokenList) diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 0b46b21e31..0fc63a12cd 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -186,9 +186,12 @@ func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy { rules := fmt.Sprintf(aclPolicyTemplateServiceIdentity, s.ServiceName, s.ServiceName) hasher := fnv.New128a() + hashID := fmt.Sprintf("%x", hasher.Sum([]byte(rules))) + policy := &ACLPolicy{} - policy.ID = fmt.Sprintf("%x", hasher.Sum([]byte(rules))) - policy.Name = fmt.Sprintf("synthetic-policy-%s", policy.ID) + policy.ID = hashID + policy.Name = fmt.Sprintf("synthetic-policy-%s", hashID) + policy.Description = "synthetic policy" policy.Rules = rules policy.Syntax = acl.SyntaxCurrent policy.Datacenters = s.Datacenters @@ -234,6 +237,9 @@ type ACLToken struct { // to the ACL datacenter and replicated to others. Local bool + // AuthMethod is the name of the auth method used to create this token. + AuthMethod string `json:",omitempty"` + // ExpirationTime represents the point after which a token should be // considered revoked and is eligible for destruction. The zero value // represents NO expiration. @@ -309,7 +315,11 @@ func (t *ACLToken) PolicyIDs() []string { } func (t *ACLToken) RoleIDs() []string { - var ids []string + if len(t.Roles) == 0 { + return nil + } + + ids := make([]string, 0, len(t.Roles)) for _, link := range t.Roles { ids = append(ids, link.ID) } @@ -345,7 +355,8 @@ func (t *ACLToken) UsesNonLegacyFields() bool { len(t.Roles) > 0 || t.Type == "" || t.HasExpirationTime() || - t.ExpirationTTL != 0 + t.ExpirationTTL != 0 || + t.AuthMethod != "" } func (t *ACLToken) EmbeddedPolicy() *ACLPolicy { @@ -428,7 +439,7 @@ func (t *ACLToken) SetHash(force bool) []byte { func (t *ACLToken) EstimateSize() int { // 41 = 16 (RaftIndex) + 8 (Hash) + 8 (ExpirationTime) + 8 (CreateTime) + 1 (Local) - size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + len(t.AuthMethod) for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } @@ -451,6 +462,7 @@ type ACLTokenListStub struct { Roles []ACLTokenRoleLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool + AuthMethod string `json:",omitempty"` ExpirationTime *time.Time `json:",omitempty"` CreateTime time.Time `json:",omitempty"` Hash []byte @@ -469,6 +481,7 @@ func (token *ACLToken) Stub() *ACLTokenListStub { Roles: token.Roles, ServiceIdentities: token.ServiceIdentities, Local: token.Local, + AuthMethod: token.AuthMethod, ExpirationTime: token.ExpirationTime, CreateTime: token.CreateTime, Hash: token.Hash, @@ -722,8 +735,6 @@ type ACLRole struct { ID string // Name is the unique name to reference the role by. - // - // Validated with structs.isValidRoleName() Name string // Description is a human readable description (Optional) @@ -819,6 +830,136 @@ func (r *ACLRole) EstimateSize() int { return size } +const ( + // BindingRuleBindTypeService is the binding rule bind type that + // assigns a Service Identity to the token that is created using the value + // of the computed BindName as the ServiceName like: + // + // &ACLToken{ + // ...other fields... + // ServiceIdentities: []*ACLServiceIdentity{ + // &ACLServiceIdentity{ + // ServiceName: "", + // }, + // }, + // } + BindingRuleBindTypeService = "service" + + // BindingRuleBindTypeRole is the binding rule bind type that only allows + // the binding rule to function if a role with the given name (BindName) + // exists at login-time. If it does the token that is created is directly + // linked to that role like: + // + // &ACLToken{ + // ...other fields... + // Roles: []ACLTokenRoleLink{ + // { Name: "" } + // } + // } + // + // If it does not exist at login-time the rule is ignored. + BindingRuleBindTypeRole = "role" +) + +type ACLBindingRule struct { + // ID is the internal UUID associated with the binding rule + ID string + + // Description is a human readable description (Optional) + Description string + + // AuthMethod is the name of the auth method for which this rule applies. + AuthMethod string + + // Selector is an expression that matches against verified identity + // attributes returned from the auth method during login. + Selector string + + // BindType adjusts how this binding rule is applied at login time. The + // valid values are: + // + // - BindingRuleBindTypeService = "service" + // - BindingRuleBindTypeRole = "role" + BindType string + + // BindName is the target of the binding. Can be lightly templated using + // HIL ${foo} syntax from available field names. How it is used depends + // upon the BindType. + BindName string + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + +func (r *ACLBindingRule) Clone() *ACLBindingRule { + r2 := *r + return &r2 +} + +type ACLBindingRules []*ACLBindingRule + +func (rules ACLBindingRules) Sort() { + sort.Slice(rules, func(i, j int) bool { + return rules[i].ID < rules[j].ID + }) +} + +type ACLAuthMethodListStub struct { + Name string + Description string + Type string + CreateIndex uint64 + ModifyIndex uint64 +} + +func (p *ACLAuthMethod) Stub() *ACLAuthMethodListStub { + return &ACLAuthMethodListStub{ + Name: p.Name, + Description: p.Description, + Type: p.Type, + CreateIndex: p.CreateIndex, + ModifyIndex: p.ModifyIndex, + } +} + +type ACLAuthMethods []*ACLAuthMethod +type ACLAuthMethodListStubs []*ACLAuthMethodListStub + +func (methods ACLAuthMethods) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +func (methods ACLAuthMethodListStubs) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +type ACLAuthMethod struct { + // Name is a unique identifier for this specific auth method. + // + // Immutable once set and only settable during create. + Name string + + // Type is the type of the auth method this is. + // + // Immutable once set and only settable during create. + Type string + + // Description is just an optional bunch of explanatory text. + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + type ACLReplicationType string const ( @@ -898,6 +1039,7 @@ type ACLTokenListRequest struct { IncludeGlobal bool // Whether global tokens should be included Policy string // Policy filter Role string // Role filter + AuthMethod string // Auth Method filter Datacenter string // The datacenter to perform the request within QueryOptions } @@ -1068,7 +1210,7 @@ func cloneStringSlice(s []string) []string { // ACLRoleSetRequest is used at the RPC layer for creation and update requests type ACLRoleSetRequest struct { - Role ACLRole // The policy to upsert + Role ACLRole // The role to upsert Datacenter string // The datacenter to perform the request within WriteRequest } @@ -1154,3 +1296,161 @@ type ACLRoleBatchSetRequest struct { type ACLRoleBatchDeleteRequest struct { RoleIDs []string } + +// ACLBindingRuleSetRequest is used at the RPC layer for creation and update requests +type ACLBindingRuleSetRequest struct { + BindingRule ACLBindingRule // The rule to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleDeleteRequest is used at the RPC layer deletion requests +type ACLBindingRuleDeleteRequest struct { + BindingRuleID string // id of the rule to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleGetRequest is used at the RPC layer to perform rule read operations +type ACLBindingRuleGetRequest struct { + BindingRuleID string // id used for the rule lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleListRequest is used at the RPC layer to request a listing of rules +type ACLBindingRuleListRequest struct { + AuthMethod string // optional filter + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLBindingRuleListResponse struct { + BindingRules ACLBindingRules + QueryMeta +} + +// ACLBindingRuleResponse returns a single binding + metadata +type ACLBindingRuleResponse struct { + BindingRule *ACLBindingRule + QueryMeta +} + +// ACLBindingRuleBatchSetRequest is used at the Raft layer for batching +// multiple rule creations and updates +type ACLBindingRuleBatchSetRequest struct { + BindingRules ACLBindingRules +} + +// ACLBindingRuleBatchDeleteRequest is used at the Raft layer for batching +// multiple rule deletions +type ACLBindingRuleBatchDeleteRequest struct { + BindingRuleIDs []string +} + +// ACLAuthMethodSetRequest is used at the RPC layer for creation and update requests +type ACLAuthMethodSetRequest struct { + AuthMethod ACLAuthMethod // The auth method to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodDeleteRequest is used at the RPC layer deletion requests +type ACLAuthMethodDeleteRequest struct { + AuthMethodName string // name of the auth method to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodGetRequest is used at the RPC layer to perform rule read operations +type ACLAuthMethodGetRequest struct { + AuthMethodName string // name used for the auth method lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodListRequest is used at the RPC layer to request a listing of auth methods +type ACLAuthMethodListRequest struct { + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLAuthMethodListResponse struct { + AuthMethods ACLAuthMethodListStubs + QueryMeta +} + +// ACLAuthMethodResponse returns a single auth method + metadata +type ACLAuthMethodResponse struct { + AuthMethod *ACLAuthMethod + QueryMeta +} + +// ACLAuthMethodBatchSetRequest is used at the Raft layer for batching +// multiple auth method creations and updates +type ACLAuthMethodBatchSetRequest struct { + AuthMethods ACLAuthMethods +} + +// ACLAuthMethodBatchDeleteRequest is used at the Raft layer for batching +// multiple auth method deletions +type ACLAuthMethodBatchDeleteRequest struct { + AuthMethodNames []string +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + +type ACLLoginRequest struct { + Auth *ACLLoginParams + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLoginRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLLogoutRequest struct { + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLogoutRequest) RequestDatacenter() string { + return r.Datacenter +} diff --git a/agent/structs/acl_cache.go b/agent/structs/acl_cache.go index 8a4f494194..1494727070 100644 --- a/agent/structs/acl_cache.go +++ b/agent/structs/acl_cache.go @@ -60,7 +60,6 @@ func (e *AuthorizerCacheEntry) Age() time.Duration { return time.Since(e.CacheTime) } -// RoleCacheEntry is the payload for by by-id and by-name caches. type RoleCacheEntry struct { Role *ACLRole CacheTime time.Time @@ -173,7 +172,7 @@ func (c *ACLCaches) GetAuthorizer(id string) *AuthorizerCacheEntry { return nil } -// GetRoleByID fetches a role from the cache by id and returns it +// GetRole fetches a role from the cache by id and returns it func (c *ACLCaches) GetRole(roleID string) *RoleCacheEntry { if c == nil || c.roles == nil { return nil @@ -228,9 +227,11 @@ func (c *ACLCaches) PutAuthorizerWithTTL(id string, authorizer acl.Authorizer, t } func (c *ACLCaches) PutRole(roleID string, role *ACLRole) { - if c != nil && c.roles != nil { - c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) + if c == nil || c.roles == nil { + return } + + c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) } func (c *ACLCaches) RemoveIdentity(id string) { @@ -246,7 +247,7 @@ func (c *ACLCaches) RemovePolicy(policyID string) { } func (c *ACLCaches) RemoveRole(roleID string) { - if c != nil && c.roles != nil && roleID != "" { + if c != nil && c.roles != nil { c.roles.Remove(roleID) } } diff --git a/agent/structs/acl_cache_test.go b/agent/structs/acl_cache_test.go index dbbf717c88..337d1860f3 100644 --- a/agent/structs/acl_cache_test.go +++ b/agent/structs/acl_cache_test.go @@ -113,6 +113,7 @@ func TestStructs_ACLCaches(t *testing.T) { require.NotNil(t, cache) cache.PutRole("foo", &ACLRole{}) + entry := cache.GetRole("foo") require.NotNil(t, entry) require.NotNil(t, entry.Role) diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index 0d69c9886d..a7860a49d6 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -188,12 +188,13 @@ node_prefix "" { expect := &ACLPolicy{ Syntax: acl.SyntaxCurrent, Datacenters: test.datacenters, + Description: "synthetic policy", Rules: test.expectRules, } got := svcid.SyntheticPolicy() require.NotEmpty(t, got.ID) - require.Equal(t, got.Name, "synthetic-policy-"+got.ID) + require.True(t, strings.HasPrefix(got.Name, "synthetic-policy-")) // strip irrelevant fields before equality got.ID = "" got.Name = "" diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 2ad0a4e97c..56a19fc4e7 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -33,31 +33,35 @@ type RaftIndex struct { // These are serialized between Consul servers and stored in Consul snapshots, // so entries must only ever be added. const ( - RegisterRequestType MessageType = 0 - DeregisterRequestType = 1 - KVSRequestType = 2 - SessionRequestType = 3 - ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) - TombstoneRequestType = 5 - CoordinateBatchUpdateType = 6 - PreparedQueryRequestType = 7 - TxnRequestType = 8 - AutopilotRequestType = 9 - AreaRequestType = 10 - ACLBootstrapRequestType = 11 - IntentionRequestType = 12 - ConnectCARequestType = 13 - ConnectCAProviderStateType = 14 - ConnectCAConfigType = 15 // FSM snapshots only. - IndexRequestType = 16 // FSM snapshots only. - ACLTokenSetRequestType = 17 - ACLTokenDeleteRequestType = 18 - ACLPolicySetRequestType = 19 - ACLPolicyDeleteRequestType = 20 - ConnectCALeafRequestType = 21 - ConfigEntryRequestType = 22 - ACLRoleSetRequestType = 23 - ACLRoleDeleteRequestType = 24 + RegisterRequestType MessageType = 0 + DeregisterRequestType = 1 + KVSRequestType = 2 + SessionRequestType = 3 + ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) + TombstoneRequestType = 5 + CoordinateBatchUpdateType = 6 + PreparedQueryRequestType = 7 + TxnRequestType = 8 + AutopilotRequestType = 9 + AreaRequestType = 10 + ACLBootstrapRequestType = 11 + IntentionRequestType = 12 + ConnectCARequestType = 13 + ConnectCAProviderStateType = 14 + ConnectCAConfigType = 15 // FSM snapshots only. + IndexRequestType = 16 // FSM snapshots only. + ACLTokenSetRequestType = 17 + ACLTokenDeleteRequestType = 18 + ACLPolicySetRequestType = 19 + ACLPolicyDeleteRequestType = 20 + ConnectCALeafRequestType = 21 + ConfigEntryRequestType = 22 + ACLRoleSetRequestType = 23 + ACLRoleDeleteRequestType = 24 + ACLBindingRuleSetRequestType = 25 + ACLBindingRuleDeleteRequestType = 26 + ACLAuthMethodSetRequestType = 27 + ACLAuthMethodDeleteRequestType = 28 ) const ( diff --git a/api/acl.go b/api/acl.go index 2713d0ddc9..3327f667c3 100644 --- a/api/acl.go +++ b/api/acl.go @@ -6,6 +6,8 @@ import ( "io/ioutil" "net/url" "time" + + "github.com/mitchellh/mapstructure" ) const ( @@ -132,6 +134,96 @@ type ACLRole struct { ModifyIndex uint64 } +// BindingRuleBindType is the type of binding rule mechanism used. +type BindingRuleBindType string + +const ( + // BindingRuleBindTypeService binds to a service identity with the given name. + BindingRuleBindTypeService BindingRuleBindType = "service" + + // BindingRuleBindTypeRole binds to pre-existing roles with the given name. + BindingRuleBindTypeRole BindingRuleBindType = "role" +) + +type ACLBindingRule struct { + ID string + Description string + AuthMethod string + Selector string + BindType BindingRuleBindType + BindName string + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethod struct { + Name string + Type string + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethodListEntry struct { + Name string + Type string + Description string + CreateIndex uint64 + ModifyIndex uint64 +} + +// ParseKubernetesAuthMethodConfig takes a raw config map and returns a parsed +// KubernetesAuthMethodConfig. +func ParseKubernetesAuthMethodConfig(raw map[string]interface{}) (*KubernetesAuthMethodConfig, error) { + var config KubernetesAuthMethodConfig + decodeConf := &mapstructure.DecoderConfig{ + Result: &config, + WeaklyTypedInput: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return nil, err + } + + if err := decoder.Decode(raw); err != nil { + return nil, fmt.Errorf("error decoding config: %s", err) + } + + return &config, nil +} + +// KubernetesAuthMethodConfig is the config for the built-in Consul auth method +// for Kubernetes. +type KubernetesAuthMethodConfig struct { + Host string `json:",omitempty"` + CACert string `json:",omitempty"` + ServiceAccountJWT string `json:",omitempty"` +} + +// RenderToConfig converts this into a map[string]interface{} suitable for use +// in the ACLAuthMethod.Config field. +func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} { + return map[string]interface{}{ + "Host": c.Host, + "CACert": c.CACert, + "ServiceAccountJWT": c.ServiceAccountJWT, + } +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + // ACL can be used to query the ACL endpoints type ACL struct { c *Client @@ -498,7 +590,7 @@ func (a *ACL) PolicyCreate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *Wri // existing policy ID func (a *ACL) PolicyUpdate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *WriteMeta, error) { if policy.ID == "" { - return nil, nil, fmt.Errorf("Must specify an ID in Policy Creation") + return nil, nil, fmt.Errorf("Must specify an ID in Policy Update") } r := a.c.newRequest("PUT", "/v1/acl/policy/"+policy.ID) @@ -654,7 +746,7 @@ func (a *ACL) RoleCreate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, // existing role ID func (a *ACL) RoleUpdate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { if role.ID == "" { - return nil, nil, fmt.Errorf("Must specify an ID in Role Creation") + return nil, nil, fmt.Errorf("Must specify an ID in Role Update") } r := a.c.newRequest("PUT", "/v1/acl/role/"+role.ID) @@ -763,3 +855,271 @@ func (a *ACL) RoleList(q *QueryOptions) ([]*ACLRole, *QueryMeta, error) { } return entries, qm, nil } + +// AuthMethodCreate will create a new auth method. +func (a *ACL) AuthMethodCreate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method") + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodUpdate updates an auth method. +func (a *ACL) AuthMethodUpdate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method/"+url.QueryEscape(method.Name)) + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodDelete deletes an auth method given its Name. +func (a *ACL) AuthMethodDelete(methodName string, q *WriteOptions) (*WriteMeta, error) { + if methodName == "" { + return nil, fmt.Errorf("Must specify a Name in Auth Method Delete") + } + + r := a.c.newRequest("DELETE", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// AuthMethodRead retrieves the auth method. Returns nil if not found. +func (a *ACL) AuthMethodRead(methodName string, q *QueryOptions) (*ACLAuthMethod, *QueryMeta, error) { + if methodName == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Read") + } + + r := a.c.newRequest("GET", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// AuthMethodList retrieves a listing of all auth methods. The listing does not +// include some metadata for the auth method as those should be retrieved by +// subsequent calls to AuthMethodRead. +func (a *ACL) AuthMethodList(q *QueryOptions) ([]*ACLAuthMethodListEntry, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/auth-methods") + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLAuthMethodListEntry + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// BindingRuleCreate will create a new binding rule. It is not allowed for the +// binding rule parameter's ID field to be set as this will be generated by +// Consul while processing the request. +func (a *ACL) BindingRuleCreate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID != "" { + return nil, nil, fmt.Errorf("Cannot specify an ID in Binding Rule Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule") + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleUpdate updates a binding rule. The ID field of the role binding +// rule parameter must be set to an existing binding rule ID. +func (a *ACL) BindingRuleUpdate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID == "" { + return nil, nil, fmt.Errorf("Must specify an ID in Binding Rule Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule/"+rule.ID) + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleDelete deletes a binding rule given its ID. +func (a *ACL) BindingRuleDelete(bindingRuleID string, q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("DELETE", "/v1/acl/binding-rule/"+bindingRuleID) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// BindingRuleRead retrieves the binding rule details. Returns nil if not found. +func (a *ACL) BindingRuleRead(bindingRuleID string, q *QueryOptions) (*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rule/"+bindingRuleID) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// BindingRuleList retrieves a listing of all binding rules. +func (a *ACL) BindingRuleList(methodName string, q *QueryOptions) ([]*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rules") + if methodName != "" { + r.params.Set("authmethod", methodName) + } + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLBindingRule + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// Login is used to exchange auth method credentials for a newly-minted Consul Token. +func (a *ACL) Login(auth *ACLLoginParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/login") + r.setWriteOptions(q) + r.obj = auth + + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLToken + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + return &out, wm, nil +} + +// Logout is used to destroy a Consul Token created via Login(). +func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/logout") + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} diff --git a/api/api.go b/api/api.go index e8370d0441..4b17ff6cda 100644 --- a/api/api.go +++ b/api/api.go @@ -30,6 +30,10 @@ const ( // the HTTP token. HTTPTokenEnvName = "CONSUL_HTTP_TOKEN" + // HTTPTokenFileEnvName defines an environment variable name which sets + // the HTTP token file. + HTTPTokenFileEnvName = "CONSUL_HTTP_TOKEN_FILE" + // HTTPAuthEnvName defines an environment variable name which sets // the HTTP authentication header. HTTPAuthEnvName = "CONSUL_HTTP_AUTH" @@ -280,6 +284,10 @@ type Config struct { // which overrides the agent's default token. Token string + // TokenFile is a file containing the current token to use for this client. + // If provided it is read once at startup and never again. + TokenFile string + TLSConfig TLSConfig } @@ -343,6 +351,10 @@ func defaultConfig(transportFn func() *http.Transport) *Config { config.Address = addr } + if tokenFile := os.Getenv(HTTPTokenFileEnvName); tokenFile != "" { + config.TokenFile = tokenFile + } + if token := os.Getenv(HTTPTokenEnvName); token != "" { config.Token = token } @@ -449,6 +461,7 @@ func (c *Config) GenerateEnv() []string { env = append(env, fmt.Sprintf("%s=%s", HTTPAddrEnvName, c.Address), fmt.Sprintf("%s=%s", HTTPTokenEnvName, c.Token), + fmt.Sprintf("%s=%s", HTTPTokenFileEnvName, c.TokenFile), fmt.Sprintf("%s=%t", HTTPSSLEnvName, c.Scheme == "https"), fmt.Sprintf("%s=%s", HTTPCAFile, c.TLSConfig.CAFile), fmt.Sprintf("%s=%s", HTTPCAPath, c.TLSConfig.CAPath), @@ -541,6 +554,19 @@ func NewClient(config *Config) (*Client, error) { config.Address = parts[1] } + // If the TokenFile is set, always use that, even if a Token is configured. + // This is because when TokenFile is set it is read into the Token field. + // We want any derived clients to have to re-read the token file. + if config.TokenFile != "" { + data, err := ioutil.ReadFile(config.TokenFile) + if err != nil { + return nil, fmt.Errorf("Error loading token file: %s", err) + } + + if token := strings.TrimSpace(string(data)); token != "" { + config.Token = token + } + } if config.Token == "" { config.Token = defConfig.Token } @@ -820,6 +846,8 @@ func (c *Client) write(endpoint string, in, out interface{}, q *WriteOptions) (* } // parseQueryMeta is used to help parse query meta-data +// +// TODO(rb): bug? the error from this function is never handled func parseQueryMeta(resp *http.Response, q *QueryMeta) error { header := resp.Header diff --git a/api/api_test.go b/api/api_test.go index eca799e022..7934ed87b1 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -875,9 +875,10 @@ func TestAPI_GenerateEnv(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "http", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "http", TLSConfig: TLSConfig{ CAFile: "", CAPath: "", @@ -891,6 +892,7 @@ func TestAPI_GenerateEnv(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=false", "CONSUL_CACERT=", "CONSUL_CAPATH=", @@ -908,9 +910,10 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "https", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "https", TLSConfig: TLSConfig{ CAFile: "/var/consul/ca.crt", CAPath: "/var/consul/ca.dir", @@ -928,6 +931,7 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=true", "CONSUL_CACERT=/var/consul/ca.crt", "CONSUL_CAPATH=/var/consul/ca.dir", diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 928f8beb55..57c51ce14b 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -1,6 +1,7 @@ package acl import ( + "encoding/json" "fmt" "strings" @@ -23,20 +24,26 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Roles:")) - for _, role := range token.Roles { - ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range token.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } if token.Rules != "" { @@ -59,20 +66,26 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Roles:")) - for _, role := range token.Roles { - ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range token.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } } @@ -112,16 +125,20 @@ func PrintRole(role *api.ACLRole, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf("Create Index: %d", role.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", role.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range role.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range role.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } } @@ -135,18 +152,74 @@ func PrintRoleListEntry(role *api.ACLRole, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf(" Create Index: %d", role.CreateIndex)) ui.Info(fmt.Sprintf(" Modify Index: %d", role.ModifyIndex)) } - ui.Info(fmt.Sprintf(" Policies:")) - for _, policy := range role.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) - } - ui.Info(fmt.Sprintf(" Service Identities:")) - for _, svcid := range role.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf(" Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } } + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf(" Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } + } +} + +func PrintAuthMethod(method *api.ACLAuthMethod, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("Name: %s", method.Name)) + ui.Info(fmt.Sprintf("Type: %s", method.Type)) + ui.Info(fmt.Sprintf("Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", method.ModifyIndex)) + } + ui.Info(fmt.Sprintf("Config:")) + output, err := json.MarshalIndent(method.Config, "", " ") + if err != nil { + ui.Error(fmt.Sprintf("Error formatting auth method configuration: %s", err)) + } + ui.Output(string(output)) +} + +func PrintAuthMethodListEntry(method *api.ACLAuthMethodListEntry, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", method.Name)) + ui.Info(fmt.Sprintf(" Type: %s", method.Type)) + ui.Info(fmt.Sprintf(" Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", method.ModifyIndex)) + } +} + +func PrintBindingRule(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("ID: %s", rule.ID)) + ui.Info(fmt.Sprintf("AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf("Description: %s", rule.Description)) + ui.Info(fmt.Sprintf("BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf("BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf("Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", rule.ModifyIndex)) + } +} + +func PrintBindingRuleListEntry(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", rule.ID)) + ui.Info(fmt.Sprintf(" AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf(" Description: %s", rule.Description)) + ui.Info(fmt.Sprintf(" BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf(" BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf(" Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", rule.ModifyIndex)) + } } func GetTokenIDFromPartial(client *api.Client, partialID string) (string, error) { @@ -309,6 +382,34 @@ func GetRoleIDByName(client *api.Client, name string) (string, error) { return "", fmt.Errorf("No such role with name %s", name) } +func GetBindingRuleIDFromPartial(client *api.Client, partialID string) (string, error) { + // the full UUID string was given + if len(partialID) == 36 { + return partialID, nil + } + + rules, _, err := client.ACL().BindingRuleList("", nil) + if err != nil { + return "", err + } + + ruleID := "" + for _, rule := range rules { + if strings.HasPrefix(rule.ID, partialID) { + if ruleID != "" { + return "", fmt.Errorf("Partial rule ID is not unique") + } + ruleID = rule.ID + } + } + + if ruleID == "" { + return "", fmt.Errorf("No such rule ID with prefix: %s", partialID) + } + + return ruleID, nil +} + func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) { var out []*api.ACLServiceIdentity for _, svcidRaw := range serviceIdents { @@ -329,3 +430,27 @@ func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity } return out, nil } + +// TestKubernetesJWT_A is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "admin-token-qlz42", +// "kubernetes.io/serviceaccount/service-account.name": "admin", +// "kubernetes.io/serviceaccount/service-account.uid": "738bc251-6532-11e9-b67f-48e6c8b8ecb5", +// "sub": "system:serviceaccount:default:admin" +// } +const TestKubernetesJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// TestKubernetesJWT_B is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "demo-token-kmb9n", +// "kubernetes.io/serviceaccount/service-account.name": "demo", +// "kubernetes.io/serviceaccount/service-account.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", +// "sub": "system:serviceaccount:default:demo" +// } +const TestKubernetesJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/command/acl/authmethod/authmethod.go b/command/acl/authmethod/authmethod.go new file mode 100644 index 0000000000..d8be7857ab --- /dev/null +++ b/command/acl/authmethod/authmethod.go @@ -0,0 +1,64 @@ +package authmethod + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Auth Methods" +const help = ` +Usage: consul acl auth-method [options] [args] + + This command has subcommands for managing Consul's ACL Auth Methods. + Here are some simple examples, and more detailed examples are available in + the subcommands or the documentation. + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube auth method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" + + List all auth methods: + + $ consul acl auth-method list + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" + + Read an auth method: + + $ consul acl auth-method read -name my-k8s + + Delete an auth method: + + $ consul acl auth-method delete -name my-k8s + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/authmethod/create/authmethod_create.go b/command/acl/authmethod/create/authmethod_create.go new file mode 100644 index 0000000000..46a55882b6 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create.go @@ -0,0 +1,186 @@ +package authmethodcreate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodType string + name string + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodType, + "type", + "", + "The new auth method's type. This flag is required.", + ) + c.flags.StringVar( + &c.name, + "name", + "", + "The new auth method's name. This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.name == "" { + c.UI.Error(fmt.Sprintf("Missing required '-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + newAuthMethod := &api.ACLAuthMethod{ + Type: c.authMethodType, + Name: c.name, + Description: c.description, + } + + if c.authMethodType == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + + newAuthMethod.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + + method, _, err := client.ACL().AuthMethodCreate(newAuthMethod, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new auth method: %v", err)) + return 1 + } + + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Auth Method" + +const help = ` +Usage: consul acl auth-method create -name NAME -type TYPE [options] + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" +` diff --git a/command/acl/authmethod/create/authmethod_create_test.go b/command/acl/authmethod/create/authmethod_create_test.go new file mode 100644 index 0000000000..a5bb222dd1 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create_test.go @@ -0,0 +1,226 @@ +package authmethodcreate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-type' flag") + }) + + t.Run("name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-name' flag") + }) + + t.Run("invalid type", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=invalid", + "-name=my-method", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Invalid Auth Method: Type should be one of") + }) + + t.Run("create testing", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestAuthMethodCreateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("k8s host required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("k8s ca cert required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + ca := connect.TestCA(t, nil) + + t.Run("k8s jwt required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("create k8s", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + caFile := filepath.Join(testDir, "ca.crt") + require.NoError(t, ioutil.WriteFile(caFile, []byte(ca.RootCert), 0600)) + + t.Run("create k8s with cert file", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", "@" + caFile, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/authmethod/delete/authmethod_delete.go b/command/acl/authmethod/delete/authmethod_delete.go new file mode 100644 index 0000000000..d8c341c989 --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete.go @@ -0,0 +1,82 @@ +package authmethoddelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to delete.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().AuthMethodDelete(c.name, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method %q deleted successfully", c.name)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Auth Method" +const help = ` +Usage: consul acl auth-method delete -name NAME [options] + + Delete an auth method: + + $ consul acl auth-method delete -name "my-auth-method" +` diff --git a/command/acl/authmethod/delete/authmethod_delete_test.go b/command/acl/authmethod/delete/authmethod_delete_test.go new file mode 100644 index 0000000000..5d0638727c --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete_test.go @@ -0,0 +1,131 @@ +package authmethoddelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("delete notfound", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, "notfound") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("delete works", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, name) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, method) + }) +} diff --git a/command/acl/authmethod/list/authmethod_list.go b/command/acl/authmethod/list/authmethod_list.go new file mode 100644 index 0000000000..837d5f9ce8 --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list.go @@ -0,0 +1,83 @@ +package authmethodlist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + methods, _, err := client.ACL().AuthMethodList(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the auth method list: %v", err)) + return 1 + } + + for _, method := range methods { + acl.PrintAuthMethodListEntry(method, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Auth Methods" +const help = ` +Usage: consul acl auth-method list [options] + + List all auth methods: + + $ consul acl auth-method list +` diff --git a/command/acl/authmethod/list/authmethod_list_test.go b/command/acl/authmethod/list/authmethod_list_test.go new file mode 100644 index 0000000000..a8a650393e --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list_test.go @@ -0,0 +1,109 @@ +package authmethodlist + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("found none", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + }) + + client := a.Client() + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + var methodNames []string + for i := 0; i < 5; i++ { + methodName := createAuthMethod(t) + methodNames = append(methodNames, methodName) + } + + t.Run("found some", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for _, methodName := range methodNames { + require.Contains(t, output, methodName) + } + }) +} diff --git a/command/acl/authmethod/read/authmethod_read.go b/command/acl/authmethod/read/authmethod_read.go new file mode 100644 index 0000000000..1a98bbf64d --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read.go @@ -0,0 +1,96 @@ +package authmethodread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to read.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + method, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading auth method %q: %v", c.name, err)) + return 1 + } else if method == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Auth Method" +const help = ` +Usage: consul acl auth-method read -name NAME [options] + + Read an auth method: + + $ consul acl auth-method read -name my-auth-method +` diff --git a/command/acl/authmethod/read/authmethod_read_test.go b/command/acl/authmethod/read/authmethod_read_test.go new file mode 100644 index 0000000000..72b78e8005 --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read_test.go @@ -0,0 +1,118 @@ +package authmethodread + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("not found", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("read by name", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, name) + }) +} diff --git a/command/acl/authmethod/update/authmethod_update.go b/command/acl/authmethod/update/authmethod_update.go new file mode 100644 index 0000000000..6f77235f51 --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update.go @@ -0,0 +1,220 @@ +package authmethodupdate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + noMerge bool + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The auth method name.", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current auth method "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the name which is immutable.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Cannot update an auth method without specifying the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Regardless of merge, we need to fetch the prior immutable fields first. + currentAuthMethod, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current auth method: %v", err)) + return 1 + } else if currentAuthMethod == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + + if c.k8sCACert != "" { + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + } + + var method *api.ACLAuthMethod + if c.noMerge { + method = &api.ACLAuthMethod{ + Name: currentAuthMethod.Name, + Type: currentAuthMethod.Type, + Description: c.description, + } + + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + method.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + } else { + methodCopy := *currentAuthMethod + method = &methodCopy + + if c.description != "" { + method.Description = c.description + } + if method.Config == nil { + method.Config = make(map[string]interface{}) + } + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost != "" { + method.Config["Host"] = c.k8sHost + } + if c.k8sCACert != "" { + method.Config["CACert"] = c.k8sCACert + } + if c.k8sServiceAccountJWT != "" { + method.Config["ServiceAccountJWT"] = c.k8sServiceAccountJWT + } + } + } + + method, _, err = client.ACL().AuthMethodUpdate(method, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method updated successfully")) + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Update an ACL Auth Method" +const help = ` +Usage: consul acl auth-method update -name NAME [options] + + Updates an auth method. By default it will merge the auth method + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" +` diff --git a/command/acl/authmethod/update/authmethod_update_test.go b/command/acl/authmethod/update/authmethod_update_test.go new file mode 100644 index 0000000000..ba5d92758e --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update_test.go @@ -0,0 +1,647 @@ +package authmethodupdate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_A, config.ServiceAccountJWT) + }) +} + +func TestAuthMethodUpdateCommand_k8s_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update missing k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("update missing k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + t.Run("update missing k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) +} diff --git a/command/acl/bindingrule/bindingrule.go b/command/acl/bindingrule/bindingrule.go new file mode 100644 index 0000000000..2b94139463 --- /dev/null +++ b/command/acl/bindingrule/bindingrule.go @@ -0,0 +1,60 @@ +package bindingrule + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule [options] [args] + + This command has subcommands for managing Consul's ACL Binding Rules. Here + are some simple examples, and more detailed examples are available in the + subcommands or the documentation. + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' + + List all binding rules: + + $ consul acl binding-rule list + + Update a binding rule: + + $ consul acl binding-rule update -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -bind-name='k8s-${serviceaccount.name}' + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba + + Delete a binding rule: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/bindingrule/create/bindingrule_create.go b/command/acl/bindingrule/create/bindingrule_create.go new file mode 100644 index 0000000000..01bcfcbe72 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create.go @@ -0,0 +1,148 @@ +package bindingrulecreate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + description string + selector string + bindType string + bindName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "The auth method's name for which this binding rule applies. "+ + "This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + newRule := &api.ACLBindingRule{ + Description: c.description, + AuthMethod: c.authMethodName, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleCreate(newRule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new binding rule: %v", err)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Binding Rule" + +const help = ` +Usage: consul acl binding-rule create [options] + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/create/bindingrule_create_test.go b/command/acl/bindingrule/create/bindingrule_create_test.go new file mode 100644 index 0000000000..0e8b510963 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create_test.go @@ -0,0 +1,178 @@ +package bindingrulecreate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("method is required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + t.Run("bind type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-type' flag") + }) + + t.Run("bind name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("create it with no selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with a match selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with type role", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=role", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/bindingrule/delete/bindingrule_delete.go b/command/acl/bindingrule/delete/bindingrule_delete.go new file mode 100644 index 0000000000..7956e1e3aa --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete.go @@ -0,0 +1,97 @@ +package bindingruledelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to delete. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + if _, err := client.ACL().BindingRuleDelete(ruleID, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule %q deleted successfully", ruleID)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule delete -id ID [options] + + Deletes an ACL binding rule by providing the ID or a unique ID prefix. + + Delete by prefix: + + $ consul acl binding-rule delete -id b6b85 + + Delete by full ID: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba +` diff --git a/command/acl/bindingrule/delete/bindingrule_delete_test.go b/command/acl/bindingrule/delete/bindingrule_delete_test.go new file mode 100644 index 0000000000..497f26b21c --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete_test.go @@ -0,0 +1,187 @@ +package bindingruledelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("delete works", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete works via prefixes", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete fails when prefix matches more than one rule", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + prefix, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) +} diff --git a/command/acl/bindingrule/list/bindingrule_list.go b/command/acl/bindingrule/list/bindingrule_list.go new file mode 100644 index 0000000000..1150ac42c2 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list.go @@ -0,0 +1,98 @@ +package bindingrulelist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "Only show rules linked to the auth method with the given name.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rules, _, err := client.ACL().BindingRuleList(c.authMethodName, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the binding rule list: %v", err)) + return 1 + } + + for _, rule := range rules { + acl.PrintBindingRuleListEntry(rule, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule list [options] + + Lists all the ACL binding rules. + + Show all: + + $ consul acl binding-rule list + + Show all for a specific auth method: + + $ consul acl binding-rule list -method="my-method" +` diff --git a/command/acl/bindingrule/list/bindingrule_list_test.go b/command/acl/bindingrule/list/bindingrule_list_test.go new file mode 100644 index 0000000000..2d935857e3 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list_test.go @@ -0,0 +1,167 @@ +package bindingrulelist + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T, methodName, description string) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: methodName, + Description: description, + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + var ruleIDs []string + for i := 0; i < 10; i++ { + name := fmt.Sprintf("test-rule-%d", i) + + var methodName string + if i%2 == 0 { + methodName = "test-1" + } else { + methodName = "test-2" + } + + id := createRule(t, methodName, name) + + ruleIDs = append(ruleIDs, id) + } + + t.Run("normal", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + }) + + t.Run("filter by method 1", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-1", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 0 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) + + t.Run("filter by method 2", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-2", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 1 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) +} diff --git a/command/acl/bindingrule/read/bindingrule_read.go b/command/acl/bindingrule/read/bindingrule_read.go new file mode 100644 index 0000000000..677a950cf2 --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read.go @@ -0,0 +1,108 @@ +package bindingruleread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to read. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading binding rule %q: %v", ruleID, err)) + return 1 + } else if rule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule read -id ID [options] + + This command will retrieve and print out the details of a single binding + rule. + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba +` diff --git a/command/acl/bindingrule/read/bindingrule_read_test.go b/command/acl/bindingrule/read/bindingrule_read_test.go new file mode 100644 index 0000000000..205e29e2fa --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read_test.go @@ -0,0 +1,152 @@ +package bindingruleread + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("read by id not found", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + t.Run("read by id", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) + + t.Run("read by id prefix", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) +} diff --git a/command/acl/bindingrule/update/bindingrule_update.go b/command/acl/bindingrule/update/bindingrule_update.go new file mode 100644 index 0000000000..0f6d23f659 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update.go @@ -0,0 +1,212 @@ +package bindingruleupdate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + description string + selector string + bindType string + bindName string + + noMerge bool + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to update. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.flags.BoolVar( + &c.noMerge, + "no-merge", + false, + "Do not merge the current binding rule "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the binding rule ID which is immutable.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Cannot update a binding rule without specifying the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + // Read the current binding rule in both cases so we can fail better if not found. + currentRule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current binding rule: %v", err)) + return 1 + } else if currentRule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + var rule *api.ACLBindingRule + if c.noMerge { + if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + rule = &api.ACLBindingRule{ + ID: ruleID, + AuthMethod: currentRule.AuthMethod, // immutable + Description: c.description, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + } else { + rule = currentRule + + if c.description != "" { + rule.Description = c.description + } + if c.bindType != "" { + rule.BindType = api.BindingRuleBindType(c.bindType) + } + if c.bindName != "" { + rule.BindName = c.bindName + } + if isFlagSet(c.flags, "selector") { + rule.Selector = c.selector // empty is valid + } + } + + rule, _, err = client.ACL().BindingRuleUpdate(rule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule updated successfully")) + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +func isFlagSet(flags *flag.FlagSet, name string) bool { + found := false + flags.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +const synopsis = "Update an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule update -id ID [options] + + Updates a binding rule. By default it will merge the binding rule + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the binding rule: + + $ consul acl binding-rule update \ + -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -description="new description" \ + -bind-type=role \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/update/bindingrule_update_test.go b/command/acl/bindingrule/update/bindingrule_update_test.go new file mode 100644 index 0000000000..82a6e1fbb4 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update_test.go @@ -0,0 +1,768 @@ +package bindingruleupdate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + uuid "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "test-${serviceaccount.name}", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but must exist", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==default", rule.Selector) + }) + + t.Run("update all fields clear selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} + +func TestBindingRuleUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeRole, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector", "foo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Empty(t, rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("missing bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id=" + id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} diff --git a/command/acl/role/delete/role_delete.go b/command/acl/role/delete/role_delete.go index 543133ae1b..5e1b17ad4b 100644 --- a/command/acl/role/delete/role_delete.go +++ b/command/acl/role/delete/role_delete.go @@ -21,7 +21,8 @@ type cmd struct { http *flags.HTTPFlags help string - roleID string + roleID string + roleName string } func (c *cmd) init() { @@ -29,6 +30,7 @@ func (c *cmd) init() { c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to delete. "+ "It may be specified as a unique ID prefix but will error if the prefix "+ "matches multiple role IDs") + c.flags.StringVar(&c.roleName, "name", "", "The name of the role to delete.") c.http = &flags.HTTPFlags{} flags.Merge(c.flags, c.http.ClientFlags()) flags.Merge(c.flags, c.http.ServerFlags()) @@ -40,8 +42,8 @@ func (c *cmd) Run(args []string) int { return 1 } - if c.roleID == "" { - c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + if c.roleID == "" && c.roleName == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id or -name parameters")) return 1 } @@ -51,7 +53,12 @@ func (c *cmd) Run(args []string) int { return 1 } - roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + var roleID string + if c.roleID != "" { + roleID, err = acl.GetRoleIDFromPartial(client, c.roleID) + } else { + roleID, err = acl.GetRoleIDByName(client, c.roleName) + } if err != nil { c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) return 1 diff --git a/command/acl/role/delete/role_delete_test.go b/command/acl/role/delete/role_delete_test.go index e6523941f6..25f2faf0af 100644 --- a/command/acl/role/delete/role_delete_test.go +++ b/command/acl/role/delete/role_delete_test.go @@ -45,7 +45,7 @@ func TestRoleDeleteCommand(t *testing.T) { client := a.Client() - t.Run("id required", func(t *testing.T) { + t.Run("id or name required", func(t *testing.T) { ui := cli.NewMockUi() cmd := New(ui) @@ -56,7 +56,7 @@ func TestRoleDeleteCommand(t *testing.T) { code := cmd.Run(args) require.Equal(t, code, 1) - require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id or -name parameters") }) t.Run("delete works", func(t *testing.T) { diff --git a/command/commands_oss.go b/command/commands_oss.go index f92fcf8bba..89c3a2d408 100644 --- a/command/commands_oss.go +++ b/command/commands_oss.go @@ -3,6 +3,18 @@ package command import ( "github.com/hashicorp/consul/command/acl" aclagent "github.com/hashicorp/consul/command/acl/agenttokens" + aclam "github.com/hashicorp/consul/command/acl/authmethod" + aclamcreate "github.com/hashicorp/consul/command/acl/authmethod/create" + aclamdelete "github.com/hashicorp/consul/command/acl/authmethod/delete" + aclamlist "github.com/hashicorp/consul/command/acl/authmethod/list" + aclamread "github.com/hashicorp/consul/command/acl/authmethod/read" + aclamupdate "github.com/hashicorp/consul/command/acl/authmethod/update" + aclbr "github.com/hashicorp/consul/command/acl/bindingrule" + aclbrcreate "github.com/hashicorp/consul/command/acl/bindingrule/create" + aclbrdelete "github.com/hashicorp/consul/command/acl/bindingrule/delete" + aclbrlist "github.com/hashicorp/consul/command/acl/bindingrule/list" + aclbrread "github.com/hashicorp/consul/command/acl/bindingrule/read" + aclbrupdate "github.com/hashicorp/consul/command/acl/bindingrule/update" aclbootstrap "github.com/hashicorp/consul/command/acl/bootstrap" aclpolicy "github.com/hashicorp/consul/command/acl/policy" aclpcreate "github.com/hashicorp/consul/command/acl/policy/create" @@ -57,6 +69,8 @@ import ( kvput "github.com/hashicorp/consul/command/kv/put" "github.com/hashicorp/consul/command/leave" "github.com/hashicorp/consul/command/lock" + login "github.com/hashicorp/consul/command/login" + logout "github.com/hashicorp/consul/command/logout" "github.com/hashicorp/consul/command/maint" "github.com/hashicorp/consul/command/members" "github.com/hashicorp/consul/command/monitor" @@ -118,6 +132,18 @@ func init() { Register("acl role read", func(ui cli.Ui) (cli.Command, error) { return aclrread.New(ui), nil }) Register("acl role update", func(ui cli.Ui) (cli.Command, error) { return aclrupdate.New(ui), nil }) Register("acl role delete", func(ui cli.Ui) (cli.Command, error) { return aclrdelete.New(ui), nil }) + Register("acl auth-method", func(cli.Ui) (cli.Command, error) { return aclam.New(), nil }) + Register("acl auth-method create", func(ui cli.Ui) (cli.Command, error) { return aclamcreate.New(ui), nil }) + Register("acl auth-method list", func(ui cli.Ui) (cli.Command, error) { return aclamlist.New(ui), nil }) + Register("acl auth-method read", func(ui cli.Ui) (cli.Command, error) { return aclamread.New(ui), nil }) + Register("acl auth-method update", func(ui cli.Ui) (cli.Command, error) { return aclamupdate.New(ui), nil }) + Register("acl auth-method delete", func(ui cli.Ui) (cli.Command, error) { return aclamdelete.New(ui), nil }) + Register("acl binding-rule", func(cli.Ui) (cli.Command, error) { return aclbr.New(), nil }) + Register("acl binding-rule create", func(ui cli.Ui) (cli.Command, error) { return aclbrcreate.New(ui), nil }) + Register("acl binding-rule list", func(ui cli.Ui) (cli.Command, error) { return aclbrlist.New(ui), nil }) + Register("acl binding-rule read", func(ui cli.Ui) (cli.Command, error) { return aclbrread.New(ui), nil }) + Register("acl binding-rule update", func(ui cli.Ui) (cli.Command, error) { return aclbrupdate.New(ui), nil }) + Register("acl binding-rule delete", func(ui cli.Ui) (cli.Command, error) { return aclbrdelete.New(ui), nil }) Register("agent", func(ui cli.Ui) (cli.Command, error) { return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil }) @@ -153,6 +179,8 @@ func init() { Register("kv put", func(ui cli.Ui) (cli.Command, error) { return kvput.New(ui), nil }) Register("leave", func(ui cli.Ui) (cli.Command, error) { return leave.New(ui), nil }) Register("lock", func(ui cli.Ui) (cli.Command, error) { return lock.New(ui), nil }) + Register("login", func(ui cli.Ui) (cli.Command, error) { return login.New(ui), nil }) + Register("logout", func(ui cli.Ui) (cli.Command, error) { return logout.New(ui), nil }) Register("maint", func(ui cli.Ui) (cli.Command, error) { return maint.New(ui), nil }) Register("members", func(ui cli.Ui) (cli.Command, error) { return members.New(ui), nil }) Register("monitor", func(ui cli.Ui) (cli.Command, error) { return monitor.New(ui, MakeShutdownCh()), nil }) diff --git a/command/connect/envoy/envoy.go b/command/connect/envoy/envoy.go index 3500b69476..5aa9bea182 100644 --- a/command/connect/envoy/envoy.go +++ b/command/connect/envoy/envoy.go @@ -104,7 +104,7 @@ func (c *cmd) Run(args []string) int { // enabled. c.grpcAddr = "localhost:8502" } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { // Extra check needed since CONSUL_HTTP_TOKEN has not been consulted yet but // calling SetToken with empty will force that to override the if proxyToken := os.Getenv(proxyAgent.EnvProxyToken); proxyToken != "" { diff --git a/command/connect/proxy/proxy.go b/command/connect/proxy/proxy.go index a60f2217e7..4c99ab5730 100644 --- a/command/connect/proxy/proxy.go +++ b/command/connect/proxy/proxy.go @@ -129,7 +129,7 @@ func (c *cmd) Run(args []string) int { if c.sidecarFor == "" { c.sidecarFor = os.Getenv(proxyAgent.EnvSidecarFor) } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { c.http.SetToken(os.Getenv(proxyAgent.EnvProxyToken)) } diff --git a/command/flags/http.go b/command/flags/http.go index 7d02f6ab3b..e2688fab8c 100644 --- a/command/flags/http.go +++ b/command/flags/http.go @@ -2,6 +2,8 @@ package flags import ( "flag" + "io/ioutil" + "strings" "github.com/hashicorp/consul/api" ) @@ -10,6 +12,7 @@ type HTTPFlags struct { // client api flags address StringValue token StringValue + tokenFile StringValue caFile StringValue caPath StringValue certFile StringValue @@ -33,6 +36,10 @@ func (f *HTTPFlags) ClientFlags() *flag.FlagSet { "ACL token to use in the request. This can also be specified via the "+ "CONSUL_HTTP_TOKEN environment variable. If unspecified, the query will "+ "default to the token of the Consul agent at the HTTP address.") + fs.Var(&f.tokenFile, "token-file", + "File containing the ACL token to use in the request instead of one specified "+ + "via the -token argument or CONSUL_HTTP_TOKEN environment variable. "+ + "This can also be specified via the CONSUL_HTTP_TOKEN_FILE environment variable.") fs.Var(&f.caFile, "ca-file", "Path to a CA file to use for TLS when communicating with Consul. This "+ "can also be specified via the CONSUL_CACERT environment variable.") @@ -88,6 +95,28 @@ func (f *HTTPFlags) SetToken(v string) error { return f.token.Set(v) } +func (f *HTTPFlags) TokenFile() string { + return f.tokenFile.String() +} + +func (f *HTTPFlags) SetTokenFile(v string) error { + return f.tokenFile.Set(v) +} + +func (f *HTTPFlags) ReadTokenFile() (string, error) { + tokenFile := f.tokenFile.String() + if tokenFile == "" { + return "", nil + } + + data, err := ioutil.ReadFile(tokenFile) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(data)), nil +} + func (f *HTTPFlags) APIClient() (*api.Client, error) { c := api.DefaultConfig() @@ -99,6 +128,7 @@ func (f *HTTPFlags) APIClient() (*api.Client, error) { func (f *HTTPFlags) MergeOntoConfig(c *api.Config) { f.address.Merge(&c.Address) f.token.Merge(&c.Token) + f.tokenFile.Merge(&c.TokenFile) f.caFile.Merge(&c.TLSConfig.CAFile) f.caPath.Merge(&c.TLSConfig.CAPath) f.certFile.Merge(&c.TLSConfig.CertFile) diff --git a/command/login/login.go b/command/login/login.go new file mode 100644 index 0000000000..ada268cac8 --- /dev/null +++ b/command/login/login.go @@ -0,0 +1,148 @@ +package login + +import ( + "flag" + "fmt" + "io/ioutil" + "strings" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/lib/file" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + shutdownCh <-chan struct{} + + bearerToken string + + // flags + authMethodName string + bearerTokenFile string + tokenSinkFile string + meta map[string]string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar(&c.authMethodName, "method", "", + "Name of the auth method to login to.") + + c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "", + "Path to a file containing a secret bearer token to use with this auth method.") + + c.flags.StringVar(&c.tokenSinkFile, "token-sink-file", "", + "The most recent token's SecretID is kept up to date in this file.") + + c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta", + "Metadata to set on the token, formatted as key=value. This flag "+ + "may be specified multiple times to set multiple meta fields.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + return 1 + } + if c.tokenSinkFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-token-sink-file' flag")) + return 1 + } + + if c.bearerTokenFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag")) + return 1 + } + + data, err := ioutil.ReadFile(c.bearerTokenFile) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + c.bearerToken = strings.TrimSpace(string(data)) + + if c.bearerToken == "" { + c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile)) + return 1 + } + + // Ensure that we don't try to use a token when performing a login + // operation. + c.http.SetToken("") + c.http.SetTokenFile("") + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Do the login. + req := &api.ACLLoginParams{ + AuthMethod: c.authMethodName, + BearerToken: c.bearerToken, + Meta: c.meta, + } + tok, _, err := client.ACL().Login(req, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error logging in: %s", err)) + return 1 + } + + if err := c.writeToSink(tok); err != nil { + c.UI.Error(fmt.Sprintf("Error writing token to file sink: %s", err)) + return 1 + } + + return 0 +} + +func (c *cmd) writeToSink(tok *api.ACLToken) error { + payload := []byte(tok.SecretID) + return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0600) +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Login to Consul using an Auth Method" + +const help = ` +Usage: consul login [options] + + The login command will exchange the provided third party credentials with the + requested auth method for a newly minted Consul ACL Token. The companion + command 'consul logout' should be used to destroy any tokens created this way + to avoid a resource leak. +` diff --git a/command/login/login_test.go b/command/login/login_test.go new file mode 100644 index 0000000000..c2988d8626 --- /dev/null +++ b/command/login/login_test.go @@ -0,0 +1,321 @@ +package login + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLoginCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLoginCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("method is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + tokenSinkFile := filepath.Join(testDir, "test.token") + + t.Run("token-sink-file is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-token-sink-file' flag") + }) + + t.Run("bearer-token-file is required", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag") + }) + + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + t.Run("bearer-token-file is empty", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(""), 0600)) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "No bearer token found in") + }) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte("demo-token"), 0600)) + + t.Run("try login with no method configured", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured but no binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 1, code, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} + +func TestLoginCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + tokenSinkFile := filepath.Join(testDir, "test.token") + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + // the "B" jwt will be the one being reviewed + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=k8s", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} diff --git a/command/logout/logout.go b/command/logout/logout.go new file mode 100644 index 0000000000..eca9c416be --- /dev/null +++ b/command/logout/logout.go @@ -0,0 +1,70 @@ +package logout + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().Logout(nil); err != nil { + c.UI.Error(fmt.Sprintf("Error destroying token: %v", err)) + return 1 + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Destroy a Consul Token created with Login" + +const help = ` +Usage: consul logout [options] + + The logout command will destroy the provided token if it was created from + 'consul login'. +` diff --git a/command/logout/logout_test.go b/command/logout/logout_test.go new file mode 100644 index 0000000000..5596297b9c --- /dev/null +++ b/command/logout/logout_test.go @@ -0,0 +1,299 @@ +package logout + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLogout_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLogoutCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "demo-token", + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestLogoutCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + // go to the trouble of creating a login token + // require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "k8s", + BearerToken: acl.TestKubernetesJWT_B, + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/watch/watch.go b/command/watch/watch.go index fd44e81fbb..92178a1f06 100644 --- a/command/watch/watch.go +++ b/command/watch/watch.go @@ -87,6 +87,14 @@ func (c *cmd) Run(args []string) int { return 1 } + token := c.http.Token() + if tokenFromFile, err := c.http.ReadTokenFile(); err != nil { + c.UI.Error(fmt.Sprintf("Error loading token file: %s", err)) + return 1 + } else if tokenFromFile != "" { + token = tokenFromFile + } + // Compile the watch parameters params := make(map[string]interface{}) if c.watchType != "" { @@ -95,8 +103,8 @@ func (c *cmd) Run(args []string) int { if c.http.Datacenter() != "" { params["datacenter"] = c.http.Datacenter() } - if c.http.Token() != "" { - params["token"] = c.http.Token() + if token != "" { + params["token"] = token } if c.key != "" { params["key"] = c.key diff --git a/go.mod b/go.mod index 4d87ae9f00..642f2d47d7 100644 --- a/go.mod +++ b/go.mod @@ -123,7 +123,9 @@ require ( gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 // indirect gopkg.in/ory-am/dockertest.v3 v3.3.4 // indirect + gopkg.in/square/go-jose.v2 v2.3.1 gotest.tools v2.2.0+incompatible // indirect - k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 // indirect - k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 // indirect + k8s.io/api v0.0.0-20190325185214-7544f9db76f6 + k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 + k8s.io/client-go v8.0.0+incompatible ) diff --git a/go.sum b/go.sum index 8bba6865b0..649673f1f2 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,8 @@ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 h1:/saqWwm73dLmuzbNhe92F0QsZ/ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw= gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= @@ -391,10 +393,10 @@ gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.0.0-20180806132203-61b11ee65332/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 h1:lV0+KGoNkvZOt4zGT4H83hQrzWMt/US/LSz4z4+BQS4= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6 h1:9MWtbqhwTyDvF4cS1qAhxDb9Mi8taXiAu+5nEacl7gY= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= k8s.io/apimachinery v0.0.0-20180821005732-488889b0007f/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 h1:/Z1m/6oEN6hE2SzWP4BHW2yATeUrBRr+1GxNf1Ny58Y= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 h1:Q4RZrHNtlC/mSdC1sTrcZ5RchC/9vxLVj57pWiCBKv4= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= k8s.io/client-go v8.0.0+incompatible h1:tTI4hRmb1DRMl4fG6Vclfdi6nTM82oIrTT7HfitmxC4= k8s.io/client-go v8.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s= diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go new file mode 100644 index 0000000000..593f653008 --- /dev/null +++ b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC +2898 / PKCS #5 v2.0. + +A key derivation function is useful when encrypting data based on a password +or any other not-fully-random data. It uses a pseudorandom function to derive +a secure encryption key based on the password. + +While v2.0 of the standard defines only one pseudorandom function to use, +HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved +Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To +choose, you can pass the `New` functions from the different SHA packages to +pbkdf2.Key. +*/ +package pbkdf2 // import "golang.org/x/crypto/pbkdf2" + +import ( + "crypto/hmac" + "hash" +) + +// Key derives a key from the password, salt and iteration count, returning a +// []byte of length keylen that can be used as cryptographic key. The key is +// derived based on the method described as PBKDF2 with the HMAC variant using +// the supplied hash function. +// +// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you +// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by +// doing: +// +// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) +// +// Remember to get a good random salt. At least 8 bytes is recommended by the +// RFC. +// +// Using a higher iteration count will increase the cost of an exhaustive +// search but will also make derivation proportionally slower. +func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc new file mode 100644 index 0000000000..730e569b06 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc @@ -0,0 +1 @@ +'|Ê&{tÄU|gGê(ìCy=+¨œòcû:u:/pœ#~žü["±4¤!­nÙAªDK<ŠufÿhÅa¿Â:ºü¸¡´B/£Ø¤¹¤ò_hÎÛSãT*wÌx¼¯¹-ç|àÀÓƒÑÄäóÌ㣗A$$â6£ÁâG)8nÏpûÆË¡3ÌšœoïÏvŽB–3¿­]xÝ“Ó2l§G•|qRÞ¯ ö2 5R–Ó×Ç$´ñ½Yè¡ÞÝ™l‘Ë«yAI"ÛŒ˜®íû¹¼kÄ|Kåþ[9ÆâÒå=°úÿŸñ|@S•3 ó#æx?¾V„,¾‚SÆÝõœwPíogÒ6&V6 ©D.dBŠ 7 \ No newline at end of file diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitignore b/vendor/gopkg.in/square/go-jose.v2/.gitignore new file mode 100644 index 0000000000..5b4d73b681 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitignore @@ -0,0 +1,7 @@ +*~ +.*.swp +*.out +*.test +*.pem +*.cov +jose-util/jose-util diff --git a/vendor/gopkg.in/square/go-jose.v2/.travis.yml b/vendor/gopkg.in/square/go-jose.v2/.travis.yml new file mode 100644 index 0000000000..fc501ca9b7 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.travis.yml @@ -0,0 +1,46 @@ +language: go + +sudo: false + +matrix: + fast_finish: true + allow_failures: + - go: tip + +go: +- '1.7.x' +- '1.8.x' +- '1.9.x' +- '1.10.x' +- '1.11.x' + +go_import_path: gopkg.in/square/go-jose.v2 + +before_script: +- export PATH=$HOME/.local/bin:$PATH + +before_install: +# Install encrypted gitcookies to get around bandwidth-limits +# that is causing Travis-CI builds to fail. For more info, see +# https://github.com/golang/go/issues/12933 +- openssl aes-256-cbc -K $encrypted_1528c3c2cafd_key -iv $encrypted_1528c3c2cafd_iv -in .gitcookies.sh.enc -out .gitcookies.sh -d || true +- bash .gitcookies.sh || true +- go get github.com/wadey/gocovmerge +- go get github.com/mattn/goveralls +- go get github.com/stretchr/testify/assert +- go get golang.org/x/tools/cmd/cover || true +- go get code.google.com/p/go.tools/cmd/cover || true +- pip install cram --user + +script: +- go test . -v -covermode=count -coverprofile=profile.cov +- go test ./cipher -v -covermode=count -coverprofile=cipher/profile.cov +- go test ./jwt -v -covermode=count -coverprofile=jwt/profile.cov +- go test ./json -v # no coverage for forked encoding/json package +- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t +- cd .. + +after_success: +- gocovmerge *.cov */*.cov > merged.coverprofile +- $HOME/gopath/bin/goveralls -coverprofile merged.coverprofile -service=travis-ci + diff --git a/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md new file mode 100644 index 0000000000..3305db0f65 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md @@ -0,0 +1,10 @@ +Serious about security +====================== + +Square recognizes the important contributions the security research community +can make. We therefore encourage reporting security issues with the code +contained in this repository. + +If you believe you have discovered a security vulnerability, please follow the +guidelines at . + diff --git a/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md new file mode 100644 index 0000000000..61b183651c --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +If you would like to contribute code to go-jose you can do so through GitHub by +forking the repository and sending a pull request. + +When submitting code, please make every effort to follow existing conventions +and style in order to keep the code as readable as possible. Please also make +sure all tests pass by running `go test`, and format your code with `go fmt`. +We also recommend using `golint` and `errcheck`. + +Before your code can be accepted into the project you must also sign the +[Individual Contributor License Agreement][1]. + + [1]: https://spreadsheets.google.com/spreadsheet/viewform?formkey=dDViT2xzUHAwRkI3X3k5Z0lQM091OGc6MQ&ndplr=1 diff --git a/vendor/gopkg.in/square/go-jose.v2/LICENSE b/vendor/gopkg.in/square/go-jose.v2/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/gopkg.in/square/go-jose.v2/README.md b/vendor/gopkg.in/square/go-jose.v2/README.md new file mode 100644 index 0000000000..1791bfa8f6 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/README.md @@ -0,0 +1,118 @@ +# Go JOSE + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) +[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/square/go-jose/master/LICENSE) +[![build](https://travis-ci.org/square/go-jose.svg?branch=v2)](https://travis-ci.org/square/go-jose) +[![coverage](https://coveralls.io/repos/github/square/go-jose/badge.svg?branch=v2)](https://coveralls.io/r/square/go-jose) + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. This includes support for JSON Web Encryption, +JSON Web Signature, and JSON Web Token standards. + +**Disclaimer**: This library contains encryption software that is subject to +the U.S. Export Administration Regulations. You may not export, re-export, +transfer or download this code or any part of it in violation of any United +States law, directive or regulation. In particular this software may not be +exported or re-exported in any form or on any media to Iran, North Sudan, +Syria, Cuba, or North Korea, or to denied persons or entities mentioned on any +US maintained blocked list. + +## Overview + +The implementation follows the +[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516), +[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and +[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519). +Tables of supported algorithms are shown below. The library supports both +the compact and full serialization formats, and has optional support for +multiple recipients. It also comes with a small command-line utility +([`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util)) +for dealing with JOSE messages in a shell. + +**Note**: We use a forked version of the `encoding/json` package from the Go +standard library which uses case-sensitive matching for member names (instead +of [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html)). +This is to avoid differences in interpretation of messages between go-jose and +libraries in other languages. + +### Versions + +We use [gopkg.in](https://gopkg.in) for versioning. + +[Version 2](https://gopkg.in/square/go-jose.v2) +([branch](https://github.com/square/go-jose/tree/v2), +[doc](https://godoc.org/gopkg.in/square/go-jose.v2)) is the current version: + + import "gopkg.in/square/go-jose.v2" + +The old `v1` branch ([go-jose.v1](https://gopkg.in/square/go-jose.v1)) will +still receive backported bug fixes and security fixes, but otherwise +development is frozen. All new feature development takes place on the `v2` +branch. Version 2 also contains additional sub-packages such as the +[jwt](https://godoc.org/gopkg.in/square/go-jose.v2/jwt) implementation +contributed by [@shaxbee](https://github.com/shaxbee). + +### Supported algorithms + +See below for a table of supported algorithms. Algorithm identifiers match +the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518) +standard where possible. The Godoc reference has a list of constants. + + Key encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSA-PKCS#1v1.5 | RSA1_5 + RSA-OAEP | RSA-OAEP, RSA-OAEP-256 + AES key wrap | A128KW, A192KW, A256KW + AES-GCM key wrap | A128GCMKW, A192GCMKW, A256GCMKW + ECDH-ES + AES key wrap | ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW + ECDH-ES (direct) | ECDH-ES1 + Direct encryption | dir1 + +1. Not supported in multi-recipient mode + + Signing / MAC | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSASSA-PKCS#1v1.5 | RS256, RS384, RS512 + RSASSA-PSS | PS256, PS384, PS512 + HMAC | HS256, HS384, HS512 + ECDSA | ES256, ES384, ES512 + Ed25519 | EdDSA2 + +2. Only available in version 2 of the package + + Content encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + AES-CBC+HMAC | A128CBC-HS256, A192CBC-HS384, A256CBC-HS512 + AES-GCM | A128GCM, A192GCM, A256GCM + + Compression | Algorithm identifiers(s) + :------------------------- | ------------------------------- + DEFLATE (RFC 1951) | DEF + +### Supported key types + +See below for a table of supported key types. These are understood by the +library, and can be passed to corresponding functions such as `NewEncrypter` or +`NewSigner`. Each of these keys can also be wrapped in a JWK if desired, which +allows attaching a key id. + + Algorithm(s) | Corresponding types + :------------------------- | ------------------------------- + RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey) + ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey) + EdDSA1 | [ed25519.PublicKey](https://godoc.org/golang.org/x/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/golang.org/x/crypto/ed25519#PrivateKey) + AES, HMAC | []byte + +1. Only available in version 2 of the package + +## Examples + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) + +Examples can be found in the Godoc +reference for this package. The +[`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util) +subdirectory also contains a small command-line utility which might be useful +as an example. diff --git a/vendor/gopkg.in/square/go-jose.v2/asymmetric.go b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go new file mode 100644 index 0000000000..67935561bc --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go @@ -0,0 +1,592 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto" + "crypto/aes" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "math/big" + + "golang.org/x/crypto/ed25519" + "gopkg.in/square/go-jose.v2/cipher" + "gopkg.in/square/go-jose.v2/json" +) + +// A generic RSA-based encrypter/verifier +type rsaEncrypterVerifier struct { + publicKey *rsa.PublicKey +} + +// A generic RSA-based decrypter/signer +type rsaDecrypterSigner struct { + privateKey *rsa.PrivateKey +} + +// A generic EC-based encrypter/verifier +type ecEncrypterVerifier struct { + publicKey *ecdsa.PublicKey +} + +type edEncrypterVerifier struct { + publicKey ed25519.PublicKey +} + +// A key generator for ECDH-ES +type ecKeyGenerator struct { + size int + algID string + publicKey *ecdsa.PublicKey +} + +// A generic EC-based decrypter/signer +type ecDecrypterSigner struct { + privateKey *ecdsa.PrivateKey +} + +type edDecrypterSigner struct { + privateKey ed25519.PrivateKey +} + +// newRSARecipient creates recipientKeyInfo based on the given key. +func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case RSA1_5, RSA_OAEP, RSA_OAEP_256: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &rsaEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newRSASigner creates a recipientSigInfo based on the given key. +func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case RS256, RS384, RS512, PS256, PS384, PS512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &rsaDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +func newEd25519Signer(sigAlg SignatureAlgorithm, privateKey ed25519.PrivateKey) (recipientSigInfo, error) { + if sigAlg != EdDSA { + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &edDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// newECDHRecipient creates recipientKeyInfo based on the given key. +func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &ecEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newECDSASigner creates a recipientSigInfo based on the given key. +func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case ES256, ES384, ES512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &ecDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// Encrypt the given payload and update the object. +func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + encryptedKey, err := ctx.encrypt(cek, alg) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: encryptedKey, + header: &rawHeader{}, + }, nil +} + +// Encrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { + switch alg { + case RSA1_5: + return rsa.EncryptPKCS1v15(RandReader, ctx.publicKey, cek) + case RSA_OAEP: + return rsa.EncryptOAEP(sha1.New(), RandReader, ctx.publicKey, cek, []byte{}) + case RSA_OAEP_256: + return rsa.EncryptOAEP(sha256.New(), RandReader, ctx.publicKey, cek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Decrypt the given payload and return the content encryption key. +func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + return ctx.decrypt(recipient.encryptedKey, headers.getAlgorithm(), generator) +} + +// Decrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) { + // Note: The random reader on decrypt operations is only used for blinding, + // so stubbing is meanlingless (hence the direct use of rand.Reader). + switch alg { + case RSA1_5: + defer func() { + // DecryptPKCS1v15SessionKey sometimes panics on an invalid payload + // because of an index out of bounds error, which we want to ignore. + // This has been fixed in Go 1.3.1 (released 2014/08/13), the recover() + // only exists for preventing crashes with unpatched versions. + // See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k + // See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33 + _ = recover() + }() + + // Perform some input validation. + keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8 + if keyBytes != len(jek) { + // Input size is incorrect, the encrypted payload should always match + // the size of the public modulus (e.g. using a 2048 bit key will + // produce 256 bytes of output). Reject this since it's invalid input. + return nil, ErrCryptoFailure + } + + cek, _, err := generator.genKey() + if err != nil { + return nil, ErrCryptoFailure + } + + // When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to + // prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing + // the Million Message Attack on Cryptographic Message Syntax". We are + // therefore deliberately ignoring errors here. + _ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek) + + return cek, nil + case RSA_OAEP: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + case RSA_OAEP_256: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Sign the given payload +func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return Signature{}, ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + var out []byte + var err error + + switch alg { + case RS256, RS384, RS512: + out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed) + case PS256, PS384, PS512: + out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + }) + } + + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + switch alg { + case RS256, RS384, RS512: + return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature) + case PS256, PS384, PS512: + return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil) + } + + return ErrUnsupportedAlgorithm +} + +// Encrypt the given payload and update the object. +func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + switch alg { + case ECDH_ES: + // ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key. + return recipientInfo{ + header: &rawHeader{}, + }, nil + case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientInfo{}, ErrUnsupportedAlgorithm + } + + generator := ecKeyGenerator{ + algID: string(alg), + publicKey: ctx.publicKey, + } + + switch alg { + case ECDH_ES_A128KW: + generator.size = 16 + case ECDH_ES_A192KW: + generator.size = 24 + case ECDH_ES_A256KW: + generator.size = 32 + } + + kek, header, err := generator.genKey() + if err != nil { + return recipientInfo{}, err + } + + block, err := aes.NewCipher(kek) + if err != nil { + return recipientInfo{}, err + } + + jek, err := josecipher.KeyWrap(block, cek) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: jek, + header: &header, + }, nil +} + +// Get key size for EC key generator +func (ctx ecKeyGenerator) keySize() int { + return ctx.size +} + +// Get a content encryption key for ECDH-ES +func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) { + priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, RandReader) + if err != nil { + return nil, rawHeader{}, err + } + + out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size) + + b, err := json.Marshal(&JSONWebKey{ + Key: &priv.PublicKey, + }) + if err != nil { + return nil, nil, err + } + + headers := rawHeader{ + headerEPK: makeRawMessage(b), + } + + return out, headers, nil +} + +// Decrypt the given payload and return the content encryption key. +func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + epk, err := headers.getEPK() + if err != nil { + return nil, errors.New("square/go-jose: invalid epk header") + } + if epk == nil { + return nil, errors.New("square/go-jose: missing epk header") + } + + publicKey, ok := epk.Key.(*ecdsa.PublicKey) + if publicKey == nil || !ok { + return nil, errors.New("square/go-jose: invalid epk header") + } + + if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return nil, errors.New("square/go-jose: invalid public key in epk header") + } + + apuData, err := headers.getAPU() + if err != nil { + return nil, errors.New("square/go-jose: invalid apu header") + } + apvData, err := headers.getAPV() + if err != nil { + return nil, errors.New("square/go-jose: invalid apv header") + } + + deriveKey := func(algID string, size int) []byte { + return josecipher.DeriveECDHES(algID, apuData.bytes(), apvData.bytes(), ctx.privateKey, publicKey, size) + } + + var keySize int + + algorithm := headers.getAlgorithm() + switch algorithm { + case ECDH_ES: + // ECDH-ES uses direct key agreement, no key unwrapping necessary. + return deriveKey(string(headers.getEncryption()), generator.keySize()), nil + case ECDH_ES_A128KW: + keySize = 16 + case ECDH_ES_A192KW: + keySize = 24 + case ECDH_ES_A256KW: + keySize = 32 + default: + return nil, ErrUnsupportedAlgorithm + } + + key := deriveKey(string(algorithm), keySize) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + return josecipher.KeyUnwrap(block, recipient.encryptedKey) +} + +func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + if alg != EdDSA { + return Signature{}, ErrUnsupportedAlgorithm + } + + sig, err := ctx.privateKey.Sign(RandReader, payload, crypto.Hash(0)) + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: sig, + protected: &rawHeader{}, + }, nil +} + +func (ctx edEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + if alg != EdDSA { + return ErrUnsupportedAlgorithm + } + ok := ed25519.Verify(ctx.publicKey, payload, signature) + if !ok { + return errors.New("square/go-jose: ed25519 signature failed to verify") + } + return nil +} + +// Sign the given payload +func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var expectedBitSize int + var hash crypto.Hash + + switch alg { + case ES256: + expectedBitSize = 256 + hash = crypto.SHA256 + case ES384: + expectedBitSize = 384 + hash = crypto.SHA384 + case ES512: + expectedBitSize = 521 + hash = crypto.SHA512 + } + + curveBits := ctx.privateKey.Curve.Params().BitSize + if expectedBitSize != curveBits { + return Signature{}, fmt.Errorf("square/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r, s, err := ecdsa.Sign(RandReader, ctx.privateKey, hashed) + if err != nil { + return Signature{}, err + } + + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + + // We serialize the outputs (r and s) into big-endian byte arrays and pad + // them with zeros on the left to make sure the sizes work out. Both arrays + // must be keyBytes long, and the output must be 2*keyBytes long. + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + + out := append(rBytesPadded, sBytesPadded...) + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var keySize int + var hash crypto.Hash + + switch alg { + case ES256: + keySize = 32 + hash = crypto.SHA256 + case ES384: + keySize = 48 + hash = crypto.SHA384 + case ES512: + keySize = 66 + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + if len(signature) != 2*keySize { + return fmt.Errorf("square/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r := big.NewInt(0).SetBytes(signature[:keySize]) + s := big.NewInt(0).SetBytes(signature[keySize:]) + + match := ecdsa.Verify(ctx.publicKey, hashed, r, s) + if !match { + return errors.New("square/go-jose: ecdsa signature failed to verify") + } + + return nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go new file mode 100644 index 0000000000..126b85ce25 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go @@ -0,0 +1,196 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "encoding/binary" + "errors" + "hash" +) + +const ( + nonceBytes = 16 +) + +// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC. +func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) { + keySize := len(key) / 2 + integrityKey := key[:keySize] + encryptionKey := key[keySize:] + + blockCipher, err := newBlockCipher(encryptionKey) + if err != nil { + return nil, err + } + + var hash func() hash.Hash + switch keySize { + case 16: + hash = sha256.New + case 24: + hash = sha512.New384 + case 32: + hash = sha512.New + } + + return &cbcAEAD{ + hash: hash, + blockCipher: blockCipher, + authtagBytes: keySize, + integrityKey: integrityKey, + }, nil +} + +// An AEAD based on CBC+HMAC +type cbcAEAD struct { + hash func() hash.Hash + authtagBytes int + integrityKey []byte + blockCipher cipher.Block +} + +func (ctx *cbcAEAD) NonceSize() int { + return nonceBytes +} + +func (ctx *cbcAEAD) Overhead() int { + // Maximum overhead is block size (for padding) plus auth tag length, where + // the length of the auth tag is equivalent to the key size. + return ctx.blockCipher.BlockSize() + ctx.authtagBytes +} + +// Seal encrypts and authenticates the plaintext. +func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte { + // Output buffer -- must take care not to mangle plaintext input. + ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)] + copy(ciphertext, plaintext) + ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize()) + + cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce) + + cbc.CryptBlocks(ciphertext, ciphertext) + authtag := ctx.computeAuthTag(data, nonce, ciphertext) + + ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag))) + copy(out, ciphertext) + copy(out[len(ciphertext):], authtag) + + return ret +} + +// Open decrypts and authenticates the ciphertext. +func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { + if len(ciphertext) < ctx.authtagBytes { + return nil, errors.New("square/go-jose: invalid ciphertext (too short)") + } + + offset := len(ciphertext) - ctx.authtagBytes + expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset]) + match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:]) + if match != 1 { + return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)") + } + + cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce) + + // Make copy of ciphertext buffer, don't want to modify in place + buffer := append([]byte{}, []byte(ciphertext[:offset])...) + + if len(buffer)%ctx.blockCipher.BlockSize() > 0 { + return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)") + } + + cbc.CryptBlocks(buffer, buffer) + + // Remove padding + plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize()) + if err != nil { + return nil, err + } + + ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext))) + copy(out, plaintext) + + return ret, nil +} + +// Compute an authentication tag +func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte { + buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8) + n := 0 + n += copy(buffer, aad) + n += copy(buffer[n:], nonce) + n += copy(buffer[n:], ciphertext) + binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8) + + // According to documentation, Write() on hash.Hash never fails. + hmac := hmac.New(ctx.hash, ctx.integrityKey) + _, _ = hmac.Write(buffer) + + return hmac.Sum(nil)[:ctx.authtagBytes] +} + +// resize ensures the the given slice has a capacity of at least n bytes. +// If the capacity of the slice is less than n, a new slice is allocated +// and the existing data will be copied. +func resize(in []byte, n uint64) (head, tail []byte) { + if uint64(cap(in)) >= n { + head = in[:n] + } else { + head = make([]byte, n) + copy(head, in) + } + + tail = head[len(in):] + return +} + +// Apply padding +func padBuffer(buffer []byte, blockSize int) []byte { + missing := blockSize - (len(buffer) % blockSize) + ret, out := resize(buffer, uint64(len(buffer))+uint64(missing)) + padding := bytes.Repeat([]byte{byte(missing)}, missing) + copy(out, padding) + return ret +} + +// Remove padding +func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) { + if len(buffer)%blockSize != 0 { + return nil, errors.New("square/go-jose: invalid padding") + } + + last := buffer[len(buffer)-1] + count := int(last) + + if count == 0 || count > blockSize || count > len(buffer) { + return nil, errors.New("square/go-jose: invalid padding") + } + + padding := bytes.Repeat([]byte{last}, count) + if !bytes.HasSuffix(buffer, padding) { + return nil, errors.New("square/go-jose: invalid padding") + } + + return buffer[:len(buffer)-count], nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go new file mode 100644 index 0000000000..f62c3bdba5 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go @@ -0,0 +1,75 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "encoding/binary" + "hash" + "io" +) + +type concatKDF struct { + z, info []byte + i uint32 + cache []byte + hasher hash.Hash +} + +// NewConcatKDF builds a KDF reader based on the given inputs. +func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader { + buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo))) + n := 0 + n += copy(buffer, algID) + n += copy(buffer[n:], ptyUInfo) + n += copy(buffer[n:], ptyVInfo) + n += copy(buffer[n:], supPubInfo) + copy(buffer[n:], supPrivInfo) + + hasher := hash.New() + + return &concatKDF{ + z: z, + info: buffer, + hasher: hasher, + cache: []byte{}, + i: 1, + } +} + +func (ctx *concatKDF) Read(out []byte) (int, error) { + copied := copy(out, ctx.cache) + ctx.cache = ctx.cache[copied:] + + for copied < len(out) { + ctx.hasher.Reset() + + // Write on a hash.Hash never fails + _ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i) + _, _ = ctx.hasher.Write(ctx.z) + _, _ = ctx.hasher.Write(ctx.info) + + hash := ctx.hasher.Sum(nil) + chunkCopied := copy(out[copied:], hash) + copied += chunkCopied + ctx.cache = hash[chunkCopied:] + + ctx.i++ + } + + return copied, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go new file mode 100644 index 0000000000..c128e327f3 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go @@ -0,0 +1,62 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "crypto/ecdsa" + "encoding/binary" +) + +// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA. +// It is an error to call this function with a private/public key that are not on the same +// curve. Callers must ensure that the keys are valid before calling this function. Output +// size may be at most 1<<16 bytes (64 KiB). +func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte { + if size > 1<<16 { + panic("ECDH-ES output size too large, must be less than or equal to 1<<16") + } + + // algId, partyUInfo, partyVInfo inputs must be prefixed with the length + algID := lengthPrefixed([]byte(alg)) + ptyUInfo := lengthPrefixed(apuData) + ptyVInfo := lengthPrefixed(apvData) + + // suppPubInfo is the encoded length of the output size in bits + supPubInfo := make([]byte, 4) + binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8) + + if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) { + panic("public key not on same curve as private key") + } + + z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes()) + reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{}) + + key := make([]byte, size) + + // Read on the KDF will never fail + _, _ = reader.Read(key) + return key +} + +func lengthPrefixed(data []byte) []byte { + out := make([]byte, len(data)+4) + binary.BigEndian.PutUint32(out, uint32(len(data))) + copy(out[4:], data) + return out +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go new file mode 100644 index 0000000000..1d36d50151 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go @@ -0,0 +1,109 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto/cipher" + "crypto/subtle" + "encoding/binary" + "errors" +) + +var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6} + +// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher. +func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) { + if len(cek)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := len(cek) / 8 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], cek[i*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer, defaultIV) + + for t := 0; t < 6*n; t++ { + copy(buffer[8:], r[t%n]) + + block.Encrypt(buffer, buffer) + + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(r[t%n], buffer[8:]) + } + + out := make([]byte, (n+1)*8) + copy(out, buffer[:8]) + for i := range r { + copy(out[(i+1)*8:], r[i]) + } + + return out, nil +} + +// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher. +func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) { + if len(ciphertext)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := (len(ciphertext) / 8) - 1 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], ciphertext[(i+1)*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer[:8], ciphertext[:8]) + + for t := 6*n - 1; t >= 0; t-- { + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(buffer[8:], r[t%n]) + + block.Decrypt(buffer, buffer) + + copy(r[t%n], buffer[8:]) + } + + if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 { + return nil, errors.New("square/go-jose: failed to unwrap key") + } + + out := make([]byte, n*8) + for i := range r { + copy(out[i*8:], r[i]) + } + + return out, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/crypter.go b/vendor/gopkg.in/square/go-jose.v2/crypter.go new file mode 100644 index 0000000000..c45c71206b --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/crypter.go @@ -0,0 +1,535 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto/ecdsa" + "crypto/rsa" + "errors" + "fmt" + "reflect" + + "gopkg.in/square/go-jose.v2/json" +) + +// Encrypter represents an encrypter which produces an encrypted JWE object. +type Encrypter interface { + Encrypt(plaintext []byte) (*JSONWebEncryption, error) + EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error) + Options() EncrypterOptions +} + +// A generic content cipher +type contentCipher interface { + keySize() int + encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error) + decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error) +} + +// A key generator (for generating/getting a CEK) +type keyGenerator interface { + keySize() int + genKey() ([]byte, rawHeader, error) +} + +// A generic key encrypter +type keyEncrypter interface { + encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key +} + +// A generic key decrypter +type keyDecrypter interface { + decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key +} + +// A generic encrypter based on the given key encrypter and content cipher. +type genericEncrypter struct { + contentAlg ContentEncryption + compressionAlg CompressionAlgorithm + cipher contentCipher + recipients []recipientKeyInfo + keyGenerator keyGenerator + extraHeaders map[HeaderKey]interface{} +} + +type recipientKeyInfo struct { + keyID string + keyAlg KeyAlgorithm + keyEncrypter keyEncrypter +} + +// EncrypterOptions represents options that can be set on new encrypters. +type EncrypterOptions struct { + Compression CompressionAlgorithm + + // Optional map of additional keys to be inserted into the protected header + // of a JWS object. Some specifications which make use of JWS like to insert + // additional values here. All values must be JSON-serializable. + ExtraHeaders map[HeaderKey]interface{} +} + +// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it +// if necessary. It returns itself and so can be used in a fluent style. +func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions { + if eo.ExtraHeaders == nil { + eo.ExtraHeaders = map[HeaderKey]interface{}{} + } + eo.ExtraHeaders[k] = v + return eo +} + +// WithContentType adds a content type ("cty") header and returns the updated +// EncrypterOptions. +func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderContentType, contentType) +} + +// WithType adds a type ("typ") header and returns the updated EncrypterOptions. +func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderType, typ) +} + +// Recipient represents an algorithm/key to encrypt messages to. +// +// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used +// on the password-based encryption algorithms PBES2-HS256+A128KW, +// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe +// default of 100000 will be used for the count and a 128-bit random salt will +// be generated. +type Recipient struct { + Algorithm KeyAlgorithm + Key interface{} + KeyID string + PBES2Count int + PBES2Salt []byte +} + +// NewEncrypter creates an appropriate encrypter based on the key type +func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) { + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: getContentCipher(enc), + } + if opts != nil { + encrypter.compressionAlg = opts.Compression + encrypter.extraHeaders = opts.ExtraHeaders + } + + if encrypter.cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + + var keyID string + var rawKey interface{} + switch encryptionKey := rcpt.Key.(type) { + case JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + case *JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + default: + rawKey = encryptionKey + } + + switch rcpt.Algorithm { + case DIRECT: + // Direct encryption mode must be treated differently + if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) { + return nil, ErrUnsupportedKeyType + } + if encrypter.cipher.keySize() != len(rawKey.([]byte)) { + return nil, ErrInvalidKeySize + } + encrypter.keyGenerator = staticKeyGenerator{ + key: rawKey.([]byte), + } + recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + case ECDH_ES: + // ECDH-ES (w/o key wrapping) is similar to DIRECT mode + typeOf := reflect.TypeOf(rawKey) + if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) { + return nil, ErrUnsupportedKeyType + } + encrypter.keyGenerator = ecKeyGenerator{ + size: encrypter.cipher.keySize(), + algID: string(enc), + publicKey: rawKey.(*ecdsa.PublicKey), + } + recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + default: + // Can just add a standard recipient + encrypter.keyGenerator = randomKeyGenerator{ + size: encrypter.cipher.keySize(), + } + err := encrypter.addRecipient(rcpt) + return encrypter, err + } +} + +// NewMultiEncrypter creates a multi-encrypter based on the given parameters +func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) { + cipher := getContentCipher(enc) + + if cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + if rcpts == nil || len(rcpts) == 0 { + return nil, fmt.Errorf("square/go-jose: recipients is nil or empty") + } + + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: cipher, + keyGenerator: randomKeyGenerator{ + size: cipher.keySize(), + }, + } + + if opts != nil { + encrypter.compressionAlg = opts.Compression + } + + for _, recipient := range rcpts { + err := encrypter.addRecipient(recipient) + if err != nil { + return nil, err + } + } + + return encrypter, nil +} + +func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) { + var recipientInfo recipientKeyInfo + + switch recipient.Algorithm { + case DIRECT, ECDH_ES: + return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm) + } + + recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key) + if recipient.KeyID != "" { + recipientInfo.keyID = recipient.KeyID + } + + switch recipient.Algorithm { + case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW: + if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok { + sr.p2c = recipient.PBES2Count + sr.p2s = recipient.PBES2Salt + } + } + + if err == nil { + ctx.recipients = append(ctx.recipients, recipientInfo) + } + return err +} + +func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) { + switch encryptionKey := encryptionKey.(type) { + case *rsa.PublicKey: + return newRSARecipient(alg, encryptionKey) + case *ecdsa.PublicKey: + return newECDHRecipient(alg, encryptionKey) + case []byte: + return newSymmetricRecipient(alg, encryptionKey) + case string: + return newSymmetricRecipient(alg, []byte(encryptionKey)) + case *JSONWebKey: + recipient, err := makeJWERecipient(alg, encryptionKey.Key) + recipient.keyID = encryptionKey.KeyID + return recipient, err + default: + return recipientKeyInfo{}, ErrUnsupportedKeyType + } +} + +// newDecrypter creates an appropriate decrypter based on the key type +func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) { + switch decryptionKey := decryptionKey.(type) { + case *rsa.PrivateKey: + return &rsaDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case *ecdsa.PrivateKey: + return &ecDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case []byte: + return &symmetricKeyCipher{ + key: decryptionKey, + }, nil + case string: + return &symmetricKeyCipher{ + key: []byte(decryptionKey), + }, nil + case JSONWebKey: + return newDecrypter(decryptionKey.Key) + case *JSONWebKey: + return newDecrypter(decryptionKey.Key) + default: + return nil, ErrUnsupportedKeyType + } +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) { + return ctx.EncryptWithAuthData(plaintext, nil) +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) { + obj := &JSONWebEncryption{} + obj.aad = aad + + obj.protected = &rawHeader{} + err := obj.protected.set(headerEncryption, ctx.contentAlg) + if err != nil { + return nil, err + } + + obj.recipients = make([]recipientInfo, len(ctx.recipients)) + + if len(ctx.recipients) == 0 { + return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to") + } + + cek, headers, err := ctx.keyGenerator.genKey() + if err != nil { + return nil, err + } + + obj.protected.merge(&headers) + + for i, info := range ctx.recipients { + recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg) + if err != nil { + return nil, err + } + + err = recipient.header.set(headerAlgorithm, info.keyAlg) + if err != nil { + return nil, err + } + + if info.keyID != "" { + err = recipient.header.set(headerKeyID, info.keyID) + if err != nil { + return nil, err + } + } + obj.recipients[i] = recipient + } + + if len(ctx.recipients) == 1 { + // Move per-recipient headers into main protected header if there's + // only a single recipient. + obj.protected.merge(obj.recipients[0].header) + obj.recipients[0].header = nil + } + + if ctx.compressionAlg != NONE { + plaintext, err = compress(ctx.compressionAlg, plaintext) + if err != nil { + return nil, err + } + + err = obj.protected.set(headerCompression, ctx.compressionAlg) + if err != nil { + return nil, err + } + } + + for k, v := range ctx.extraHeaders { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + (*obj.protected)[k] = makeRawMessage(b) + } + + authData := obj.computeAuthData() + parts, err := ctx.cipher.encrypt(cek, authData, plaintext) + if err != nil { + return nil, err + } + + obj.iv = parts.iv + obj.ciphertext = parts.ciphertext + obj.tag = parts.tag + + return obj, nil +} + +func (ctx *genericEncrypter) Options() EncrypterOptions { + return EncrypterOptions{ + Compression: ctx.compressionAlg, + ExtraHeaders: ctx.extraHeaders, + } +} + +// Decrypt and validate the object and return the plaintext. Note that this +// function does not support multi-recipient, if you desire multi-recipient +// decryption use DecryptMulti instead. +func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) { + headers := obj.mergedHeaders(nil) + + if len(obj.recipients) > 1 { + return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one") + } + + critical, err := headers.getCritical() + if err != nil { + return nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return nil, err + } + + cipher := getContentCipher(headers.getEncryption()) + if cipher == nil { + return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption())) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + var plaintext []byte + recipient := obj.recipients[0] + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + } + + if plaintext == nil { + return nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + return plaintext, err +} + +// DecryptMulti decrypts and validates the object and returns the plaintexts, +// with support for multiple recipients. It returns the index of the recipient +// for which the decryption was successful, the merged headers for that recipient, +// and the plaintext. +func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) { + globalHeaders := obj.mergedHeaders(nil) + + critical, err := globalHeaders.getCritical() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return -1, Header{}, nil, err + } + + encryption := globalHeaders.getEncryption() + cipher := getContentCipher(encryption) + if cipher == nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption)) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + index := -1 + var plaintext []byte + var headers rawHeader + + for i, recipient := range obj.recipients { + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + if err == nil { + index = i + headers = recipientHeaders + break + } + } + } + + if plaintext == nil || err != nil { + return -1, Header{}, nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + sanitized, err := headers.sanitized() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err) + } + + return index, sanitized, plaintext, err +} diff --git a/vendor/gopkg.in/square/go-jose.v2/doc.go b/vendor/gopkg.in/square/go-jose.v2/doc.go new file mode 100644 index 0000000000..dd1387f3f0 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/doc.go @@ -0,0 +1,27 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. It implements encryption and signing based on +the JSON Web Encryption and JSON Web Signature standards, with optional JSON +Web Token support available in a sub-package. The library supports both the +compact and full serialization formats, and has optional support for multiple +recipients. + +*/ +package jose diff --git a/vendor/gopkg.in/square/go-jose.v2/encoding.go b/vendor/gopkg.in/square/go-jose.v2/encoding.go new file mode 100644 index 0000000000..b9687c647d --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/encoding.go @@ -0,0 +1,179 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "bytes" + "compress/flate" + "encoding/base64" + "encoding/binary" + "io" + "math/big" + "regexp" + + "gopkg.in/square/go-jose.v2/json" +) + +var stripWhitespaceRegex = regexp.MustCompile("\\s") + +// Helper function to serialize known-good objects. +// Precondition: value is not a nil pointer. +func mustSerializeJSON(value interface{}) []byte { + out, err := json.Marshal(value) + if err != nil { + panic(err) + } + // We never want to serialize the top-level value "null," since it's not a + // valid JOSE message. But if a caller passes in a nil pointer to this method, + // MarshalJSON will happily serialize it as the top-level value "null". If + // that value is then embedded in another operation, for instance by being + // base64-encoded and fed as input to a signing algorithm + // (https://github.com/square/go-jose/issues/22), the result will be + // incorrect. Because this method is intended for known-good objects, and a nil + // pointer is not a known-good object, we are free to panic in this case. + // Note: It's not possible to directly check whether the data pointed at by an + // interface is a nil pointer, so we do this hacky workaround. + // https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I + if string(out) == "null" { + panic("Tried to serialize a nil pointer.") + } + return out +} + +// Strip all newlines and whitespace +func stripWhitespace(data string) string { + return stripWhitespaceRegex.ReplaceAllString(data, "") +} + +// Perform compression based on algorithm +func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return deflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Perform decompression based on algorithm +func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return inflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Compress with DEFLATE +func deflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + + // Writing to byte buffer, err is always nil + writer, _ := flate.NewWriter(output, 1) + _, _ = io.Copy(writer, bytes.NewBuffer(input)) + + err := writer.Close() + return output.Bytes(), err +} + +// Decompress with DEFLATE +func inflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + reader := flate.NewReader(bytes.NewBuffer(input)) + + _, err := io.Copy(output, reader) + if err != nil { + return nil, err + } + + err = reader.Close() + return output.Bytes(), err +} + +// byteBuffer represents a slice of bytes that can be serialized to url-safe base64. +type byteBuffer struct { + data []byte +} + +func newBuffer(data []byte) *byteBuffer { + if data == nil { + return nil + } + return &byteBuffer{ + data: data, + } +} + +func newFixedSizeBuffer(data []byte, length int) *byteBuffer { + if len(data) > length { + panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)") + } + pad := make([]byte, length-len(data)) + return newBuffer(append(pad, data...)) +} + +func newBufferFromInt(num uint64) *byteBuffer { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, num) + return newBuffer(bytes.TrimLeft(data, "\x00")) +} + +func (b *byteBuffer) MarshalJSON() ([]byte, error) { + return json.Marshal(b.base64()) +} + +func (b *byteBuffer) UnmarshalJSON(data []byte) error { + var encoded string + err := json.Unmarshal(data, &encoded) + if err != nil { + return err + } + + if encoded == "" { + return nil + } + + decoded, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + return err + } + + *b = *newBuffer(decoded) + + return nil +} + +func (b *byteBuffer) base64() string { + return base64.RawURLEncoding.EncodeToString(b.data) +} + +func (b *byteBuffer) bytes() []byte { + // Handling nil here allows us to transparently handle nil slices when serializing. + if b == nil { + return nil + } + return b.data +} + +func (b byteBuffer) bigInt() *big.Int { + return new(big.Int).SetBytes(b.data) +} + +func (b byteBuffer) toInt() int { + return int(b.bigInt().Int64()) +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/LICENSE b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE new file mode 100644 index 0000000000..7448756763 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/README.md b/vendor/gopkg.in/square/go-jose.v2/json/README.md new file mode 100644 index 0000000000..86de5e5581 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/README.md @@ -0,0 +1,13 @@ +# Safe JSON + +This repository contains a fork of the `encoding/json` package from Go 1.6. + +The following changes were made: + +* Object deserialization uses case-sensitive member name matching instead of + [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html). + This is to avoid differences in the interpretation of JOSE messages between + go-jose and libraries written in other languages. +* When deserializing a JSON object, we check for duplicate keys and reject the + input whenever we detect a duplicate. Rather than trying to work with malformed + data, we prefer to reject it right away. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/decode.go b/vendor/gopkg.in/square/go-jose.v2/json/decode.go new file mode 100644 index 0000000000..37457e5a83 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/decode.go @@ -0,0 +1,1183 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "errors" + "fmt" + "reflect" + "runtime" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// Unmarshal will only set exported fields of the struct. +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a string-keyed map, Unmarshal first +// establishes a map to use, If the map is nil, Unmarshal allocates a new map. +// Otherwise Unmarshal reuses the existing map, keeping existing entries. +// Unmarshal then stores key-value pairs from the JSON object into the map. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// ``not present,'' unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// +func Unmarshal(data []byte, v interface{}) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes +} + +func (e *UnmarshalTypeError) Error() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// (No longer used; kept for compatibility.) +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// isValidNumber reports whether s is a valid JSON number literal. +func isValidNumber(s string) bool { + // This function implements the JSON numbers grammar. + // See https://tools.ietf.org/html/rfc7159#section-6 + // and http://json.org/number.gif + + if s == "" { + return false + } + + // Optional - + if s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + + // Digits + switch { + default: + return false + + case s[0] == '0': + s = s[1:] + + case '1' <= s[0] && s[0] <= '9': + s = s[1:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // . followed by 1 or more digits. + if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' { + s = s[2:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // e or E followed by an optional - or + and + // 1 or more digits. + if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') { + s = s[1:] + if s[0] == '+' || s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // Make sure we are at the end. + return s == "" +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError error + useNumber bool +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = errors.New("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. +func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := d.data[d.off] + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if !v.IsValid() { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.redo { + // rewind. + d.scan.redo = false + d.scan.step = stateBeginValue + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + + n := len(d.scan.parseState) + if n > 0 && d.scan.parseState[n-1] == parseObjectKey { + // d.scan thinks we just read an object key; finish the object + d.scan.step(&d.scan, ':') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '}') + } + + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() interface{} { + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(reflect.Value{}) + + case scanBeginObject: + d.object(reflect.Value{}) + + case scanBeginLiteral: + switch v := d.literalInterface().(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. +func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + v = v.Elem() + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + } + + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + v.Set(reflect.ValueOf(d.arrayInterface())) + return + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + case reflect.Array: + case reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if v.Kind() == reflect.Slice { + // Grow slice if necessary + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + d.value(v.Index(i)) + } else { + // Ran out of fixed array: skip. + d.value(reflect.Value{}) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Array. Zero the rest. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } +} + +var nullLiteral = []byte("null") + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + switch v.Kind() { + case reflect.Map: + // map must have string kind + t := v.Type() + if t.Key().Kind() != reflect.String { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + + default: + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + + var mapElem reflect.Value + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, []byte(key)) { + f = ff + break + } + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + d.literalStore(nullLiteral, subv, false) + case string: + d.literalStore([]byte(qv), subv, true) + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + d.value(subv) + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kv := reflect.ValueOf(key).Convert(v.Type().Key()) + v.SetMapIndex(kv, subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + + d.literalStore(d.data[start:d.off], v, false) +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (interface{}, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{"number " + s, reflect.TypeOf(0.0), int64(d.off)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) { + // Check for unmarshaler. + if len(item) == 0 { + //Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return + } + wantptr := item[0] == 'n' // null + u, ut, pv := d.indirect(v, wantptr) + if u != nil { + err := u.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + return + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + err := ut.UnmarshalText(s) + if err != nil { + d.error(err) + } + return + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + v.Set(reflect.Zero(v.Type())) + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := c == 't' + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + v.SetString(s) + if !isValidNumber(s) { + d.error(fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)) + } + break + } + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + } + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetFloat(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + panic("unreachable") + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v = make([]interface{}, 0) + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + r, err := strconv.ParseUint(string(s[2:6]), 16, 64) + if err != nil { + return -1 + } + return rune(r) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/encode.go b/vendor/gopkg.in/square/go-jose.v2/json/encode.go new file mode 100644 index 0000000000..1dae8bb7cd --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/encode.go @@ -0,0 +1,1197 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. The mapping between JSON objects and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "fmt" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. If no MarshalJSON method is present but the +// value implements encoding.TextMarshaler instead, Marshal calls +// its MarshalText method. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" +// to keep some browsers from misinterpreting JSON output as HTML. +// Ampersand "&" is also escaped to "\u0026" for the same reason. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string, and a nil slice +// encodes as the null JSON object. +// +// Struct values encode as JSON objects. Each exported struct field +// becomes a member of the object unless +// - the field's tag is "-", or +// - the field is empty and its tag specifies the "omitempty" option. +// The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or string of +// length zero. The object's default key string is the struct field name +// but can be specified in the struct field's tag value. The "json" key in +// the struct field's tag value is the key name, followed by an optional comma +// and options. Examples: +// +// // Field is ignored by this package. +// Field int `json:"-"` +// +// // Field appears in JSON as key "myName". +// Field int `json:"myName"` +// +// // Field appears in JSON as key "myName" and +// // the field is omitted from the object if its value is empty, +// // as defined above. +// Field int `json:"myName,omitempty"` +// +// // Field appears in JSON as key "Field" (the default), but +// // the field is skipped if empty. +// // Note the leading comma. +// Field int `json:",omitempty"` +// +// The "string" option signals that a field is stored as JSON inside a +// JSON-encoded string. It applies only to fields of string, floating point, +// integer, or boolean types. This extra level of encoding is sometimes used +// when communicating with JavaScript programs: +// +// Int64String int64 `json:",string"` +// +// The key name will be used if it's a non-empty string consisting of +// only Unicode letters, digits, dollar signs, percent signs, hyphens, +// underscores and slashes. +// +// Anonymous struct fields are usually marshaled as if their inner exported fields +// were fields in the outer struct, subject to the usual Go visibility rules amended +// as described in the next paragraph. +// An anonymous struct field with a name given in its JSON tag is treated as +// having that name, rather than being anonymous. +// An anonymous struct field of interface type is treated the same as having +// that type as its name, rather than being anonymous. +// +// The Go visibility rules for struct fields are amended for JSON when +// deciding which field to marshal or unmarshal. If there are +// multiple fields at the same level, and that level is the least +// nested (and would therefore be the nesting level selected by the +// usual Go rules), the following extra rules apply: +// +// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered, +// even if there are multiple untagged fields that would otherwise conflict. +// 2) If there is exactly one field (tagged or not according to the first rule), that is selected. +// 3) Otherwise there are multiple fields, and all are ignored; no error occurs. +// +// Handling of anonymous struct fields is new in Go 1.1. +// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of +// an anonymous struct field in both current and earlier versions, give the field +// a JSON tag of "-". +// +// Map values encode as JSON objects. +// The map's key type must be string; the map keys are used as JSON object +// keys, subject to the UTF-8 coercion described for string values above. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an UnsupportedTypeError. +// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. +// +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029 +// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029 +// so that the JSON will be safe to embed inside HTML