diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index 64eaad0de8..12c6b313a2 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -254,6 +254,7 @@ func (s *HTTPServer) ACLPolicyRead(resp http.ResponseWriter, req *http.Request, } if out.Policy == nil { + // TODO(rb): should this return a normal 404? return nil, acl.ErrNotFound } @@ -268,15 +269,35 @@ func (s *HTTPServer) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request return s.ACLPolicyWrite(resp, req, "") } -// fixCreateTimeAndHash is used to help in decoding the CreateTime and Hash +// fixTimeAndHashFields is used to help in decoding the ExpirationTTL, ExpirationTime, CreateTime, and Hash // attributes from the ACL Token/Policy create/update requests. It is needed // to help mapstructure decode things properly when decodeBody is used. -func fixCreateTimeAndHash(raw interface{}) error { +func fixTimeAndHashFields(raw interface{}) error { rawMap, ok := raw.(map[string]interface{}) if !ok { return nil } + if val, ok := rawMap["ExpirationTTL"]; ok { + if sval, ok := val.(string); ok { + d, err := time.ParseDuration(sval) + if err != nil { + return err + } + rawMap["ExpirationTTL"] = d + } + } + + if val, ok := rawMap["ExpirationTime"]; ok { + if sval, ok := val.(string); ok { + t, err := time.Parse(time.RFC3339, sval) + if err != nil { + return err + } + rawMap["ExpirationTime"] = t + } + } + if val, ok := rawMap["CreateTime"]; ok { if sval, ok := val.(string); ok { t, err := time.Parse(time.RFC3339, sval) @@ -301,7 +322,7 @@ func (s *HTTPServer) ACLPolicyWrite(resp http.ResponseWriter, req *http.Request, } s.parseToken(req, &args.Token) - if err := decodeBody(req, &args.Policy, fixCreateTimeAndHash); err != nil { + if err := decodeBody(req, &args.Policy, fixTimeAndHashFields); err != nil { return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)} } @@ -354,6 +375,8 @@ func (s *HTTPServer) ACLTokenList(resp http.ResponseWriter, req *http.Request) ( } args.Policy = req.URL.Query().Get("policy") + args.Role = req.URL.Query().Get("role") + args.AuthMethod = req.URL.Query().Get("authmethod") var out structs.ACLTokenListResponse defer setMeta(resp, &out.QueryMeta) @@ -472,7 +495,7 @@ func (s *HTTPServer) ACLTokenSet(resp http.ResponseWriter, req *http.Request, to } s.parseToken(req, &args.Token) - if err := decodeBody(req, &args.ACLToken, fixCreateTimeAndHash); err != nil { + if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil { return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} } @@ -513,7 +536,7 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request, Datacenter: s.agent.config.Datacenter, } - if err := decodeBody(req, &args.ACLToken, fixCreateTimeAndHash); err != nil && err.Error() != "EOF" { + if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil && err.Error() != "EOF" { return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} } s.parseToken(req, &args.Token) @@ -528,3 +551,480 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request, return &out, nil } + +func (s *HTTPServer) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLRoleListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + args.Policy = req.URL.Query().Get("policy") + + var out structs.ACLRoleListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.RoleList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.Roles == nil { + out.Roles = make(structs.ACLRoles, 0) + } + + return out.Roles, nil +} + +func (s *HTTPServer) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLRoleReadByID + + case "PUT": + fn = s.ACLRoleWrite + + case "DELETE": + fn = s.ACLRoleDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + roleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/") + if roleID == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing role ID"} + } + + return fn(resp, req, roleID) +} + +func (s *HTTPServer) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/") + if roleName == "" { + return nil, BadRequestError{Reason: "Missing role Name"} + } + + return s.ACLRoleRead(resp, req, "", roleName) +} + +func (s *HTTPServer) ACLRoleReadByID(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + return s.ACLRoleRead(resp, req, roleID, "") +} + +func (s *HTTPServer) ACLRoleRead(resp http.ResponseWriter, req *http.Request, roleID, roleName string) (interface{}, error) { + args := structs.ACLRoleGetRequest{ + Datacenter: s.agent.config.Datacenter, + RoleID: roleID, + RoleName: roleName, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLRoleResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.RoleRead", &args, &out); err != nil { + return nil, err + } + + if out.Role == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + return out.Role, nil +} + +func (s *HTTPServer) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLRoleWrite(resp, req, "") +} + +func (s *HTTPServer) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + args := structs.ACLRoleSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.Role, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)} + } + + if args.Role.ID != "" && args.Role.ID != roleID { + return nil, BadRequestError{Reason: "Role ID in URL and payload do not match"} + } else if args.Role.ID == "" { + args.Role.ID = roleID + } + + var out structs.ACLRole + if err := s.agent.RPC("ACL.RoleSet", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLRoleDelete(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error) { + args := structs.ACLRoleDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + RoleID: roleID, + } + s.parseToken(req, &args.Token) + + var ignored string + if err := s.agent.RPC("ACL.RoleDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLBindingRuleListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + args.AuthMethod = req.URL.Query().Get("authmethod") + + var out structs.ACLBindingRuleListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.BindingRules == nil { + out.BindingRules = make(structs.ACLBindingRules, 0) + } + + return out.BindingRules, nil +} + +func (s *HTTPServer) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLBindingRuleRead + + case "PUT": + fn = s.ACLBindingRuleWrite + + case "DELETE": + fn = s.ACLBindingRuleDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + bindingRuleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/binding-rule/") + if bindingRuleID == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing binding rule ID"} + } + + return fn(resp, req, bindingRuleID) +} + +func (s *HTTPServer) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleGetRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLBindingRuleResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleRead", &args, &out); err != nil { + return nil, err + } + + if out.BindingRule == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + return out.BindingRule, nil +} + +func (s *HTTPServer) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLBindingRuleWrite(resp, req, "") +} + +func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} + } + + if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID { + return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"} + } else if args.BindingRule.ID == "" { + args.BindingRule.ID = bindingRuleID + } + + var out structs.ACLBindingRule + if err := s.agent.RPC("ACL.BindingRuleSet", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.BindingRuleDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLAuthMethodListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.AuthMethods == nil { + out.AuthMethods = make(structs.ACLAuthMethodListStubs, 0) + } + + return out.AuthMethods, nil +} + +func (s *HTTPServer) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLAuthMethodRead + + case "PUT": + fn = s.ACLAuthMethodWrite + + case "DELETE": + fn = s.ACLAuthMethodDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + methodName := strings.TrimPrefix(req.URL.Path, "/v1/acl/auth-method/") + if methodName == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing auth method name"} + } + + return fn(resp, req, methodName) +} + +func (s *HTTPServer) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodGetRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodRead", &args, &out); err != nil { + return nil, err + } + + if out.AuthMethod == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + fixupAuthMethodConfig(out.AuthMethod) + return out.AuthMethod, nil +} + +func (s *HTTPServer) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLAuthMethodWrite(resp, req, "") +} + +func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} + } + + if methodName != "" { + if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName { + return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"} + } else if args.AuthMethod.Name == "" { + args.AuthMethod.Name = methodName + } + } + + var out structs.ACLAuthMethod + if err := s.agent.RPC("ACL.AuthMethodSet", args, &out); err != nil { + return nil, err + } + + fixupAuthMethodConfig(&out) + return &out, nil +} + +func (s *HTTPServer) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.AuthMethodDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := &structs.ACLLoginRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + + if err := decodeBody(req, &args.Auth, nil); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)} + } + + var out structs.ACLToken + if err := s.agent.RPC("ACL.Login", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := structs.ACLLogoutRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + s.parseToken(req, &args.Token) + + if args.Token == "" { + return nil, acl.ErrNotFound + } + + var ignored bool + if err := s.agent.RPC("ACL.Logout", &args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +// A hack to fix up the config types inside of the map[string]interface{} +// so that they get formatted correctly during json.Marshal. Without this, +// string values that get converted to []uint8 end up getting output back +// to the user in base64-encoded form. +func fixupAuthMethodConfig(method *structs.ACLAuthMethod) { + for k, v := range method.Config { + if raw, ok := v.([]uint8); ok { + strVal := structs.Uint8ToString(raw) + method.Config[k] = strVal + } + } +} diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 754b931cc6..ab92c1457a 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -8,6 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/testrpc" "github.com/stretchr/testify/require" @@ -40,6 +42,17 @@ func TestACL_Disabled_Response(t *testing.T) { {"ACLTokenCreate", a.srv.ACLTokenCreate}, {"ACLTokenSelf", a.srv.ACLTokenSelf}, {"ACLTokenCRUD", a.srv.ACLTokenCRUD}, + {"ACLRoleList", a.srv.ACLRoleList}, + {"ACLRoleCreate", a.srv.ACLRoleCreate}, + {"ACLRoleCRUD", a.srv.ACLRoleCRUD}, + {"ACLBindingRuleList", a.srv.ACLBindingRuleList}, + {"ACLBindingRuleCreate", a.srv.ACLBindingRuleCreate}, + {"ACLBindingRuleCRUD", a.srv.ACLBindingRuleCRUD}, + {"ACLAuthMethodList", a.srv.ACLAuthMethodList}, + {"ACLAuthMethodCreate", a.srv.ACLAuthMethodCreate}, + {"ACLAuthMethodCRUD", a.srv.ACLAuthMethodCRUD}, + {"ACLLogin", a.srv.ACLLogin}, + {"ACLLogout", a.srv.ACLLogout}, } testrpc.WaitForLeader(t, a.RPC, "dc1") for _, tt := range tests { @@ -119,6 +132,7 @@ func TestACL_HTTP(t *testing.T) { idMap := make(map[string]string) policyMap := make(map[string]*structs.ACLPolicy) + roleMap := make(map[string]*structs.ACLRole) tokenMap := make(map[string]*structs.ACLToken) // This is all done as a subtest for a couple reasons @@ -220,7 +234,7 @@ func TestACL_HTTP(t *testing.T) { policyMap[policy.ID] = policy }) - t.Run("Update Name ID Mistmatch", func(t *testing.T) { + t.Run("Update Name ID Mismatch", func(t *testing.T) { policyInput := &structs.ACLPolicy{ ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", Name: "read-all-nodes", @@ -355,6 +369,222 @@ func TestACL_HTTP(t *testing.T) { }) }) + t.Run("Role", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: idMap["policy-test"], + Name: policyMap[idMap["policy-test"]].Name, + }, + structs.ACLRolePolicyLink{ + ID: idMap["policy-read-all-nodes"], + Name: policyMap[idMap["policy-read-all-nodes"]].Name, + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Name Chars", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "service-id-web", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-service-id-web"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Update Name ID Mismatch", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "db", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Role CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web-indexer", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.True(t, role.CreateIndex < role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("ID Supplied", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Name: "with-id", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "foobar", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/role/"+idMap["role-service-id-web"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + delete(roleMap, idMap["role-service-id-web"]) + delete(idMap, "role-service-id-web") + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/roles?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleList(resp, req) + require.NoError(t, err) + roles, ok := raw.(structs.ACLRoles) + require.True(t, ok) + + // 1 we just created + require.Len(t, roles, 1) + + for roleID, expected := range roleMap { + found := false + for _, actual := range roles { + if actual.ID == roleID { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/"+idMap["role-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + role, ok := raw.(*structs.ACLRole) + require.True(t, ok) + require.Equal(t, roleMap[idMap["role-test"]], role) + }) + }) + t.Run("Token", func(t *testing.T) { t.Run("Create", func(t *testing.T) { tokenInput := &structs.ACLToken{ @@ -594,3 +824,504 @@ func TestACL_HTTP(t *testing.T) { }) }) } + +func TestACL_LoginProcedure_HTTP(t *testing.T) { + // This tests AuthMethods, BindingRules, Login, and Logout. + t.Parallel() + a := NewTestAgent(t, t.Name(), TestACLConfig()) + defer a.Shutdown() + + testrpc.WaitForLeader(t, a.RPC, "dc1") + + idMap := make(map[string]string) + methodMap := make(map[string]*structs.ACLAuthMethod) + ruleMap := make(map[string]*structs.ACLBindingRule) + tokenMap := make(map[string]*structs.ACLToken) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + // This is all done as a subtest for a couple reasons + // 1. It uses only 1 test agent and these are + // somewhat expensive to bring up and tear down often + // 2. Instead of having to bring up a new agent and prime + // the ACL system with some data before running the test + // we can intelligently order these tests so we can still + // test everything with less actual operations and do + // so in a manner that is less prone to being flaky + // 3. While this test will be large it should + t.Run("AuthMethod", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Create other", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "other", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Update Name URL Mismatch", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/not-test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "updated description", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.True(t, method.CreateIndex < method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-methods?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodList(resp, req) + require.NoError(t, err) + methods, ok := raw.(structs.ACLAuthMethodListStubs) + require.True(t, ok) + + // 2 we just created + require.Len(t, methods, 2) + + for methodName, expected := range methodMap { + found := false + for _, actual := range methods { + if actual.Name == methodName { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/auth-method/other?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + delete(methodMap, "other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-method/test?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + method, ok := raw.(*structs.ACLAuthMethod) + require.True(t, ok) + require.Equal(t, methodMap["test"], method) + }) + }) + + t.Run("BindingRule", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "test", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("Create other", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "other", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeRole, + BindName: "fancy-role", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-other"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("BindingRule CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "updated", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.True(t, rule.CreateIndex < rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("ID Supplied", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Description: "with-id", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "vault", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rules?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleList(resp, req) + require.NoError(t, err) + rules, ok := raw.(structs.ACLBindingRules) + require.True(t, ok) + + // 2 we just created + require.Len(t, rules, 2) + + for ruleID, expected := range ruleMap { + found := false + for _, actual := range rules { + if actual.ID == ruleID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.AuthMethod, actual.AuthMethod) + require.Equal(t, expected.Selector, actual.Selector) + require.Equal(t, expected.BindType, actual.BindType) + require.Equal(t, expected.BindName, actual.BindName) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/binding-rule/"+idMap["rule-other"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + delete(ruleMap, idMap["rule-other"]) + delete(idMap, "rule-other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + rule, ok := raw.(*structs.ACLBindingRule) + require.True(t, ok) + require.Equal(t, ruleMap[idMap["rule-test"]], rule) + }) + }) + + testauth.InstallSessionToken(testSessionID, "token1", "default", "demo1", "abc123") + testauth.InstallSessionToken(testSessionID, "token2", "default", "demo2", "def456") + + t.Run("Login", func(t *testing.T) { + t.Run("Create Token 1", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token1", + Meta: map[string]string{"foo": "bar"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"foo":"bar"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo1", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-1"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + t.Run("Create Token 2", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token2", + Meta: map[string]string{"blah": "woot"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"blah":"woot"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo2", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-2"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + + t.Run("List Tokens by (incorrect) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=other", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 0) + }) + + t.Run("List Tokens by (correct) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=test", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 2) + + for tokenID, expected := range tokenMap { + found := false + for _, actual := range tokens { + if actual.AccessorID == tokenID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.Roles, actual.Roles) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Local, actual.Local) + require.Equal(t, expected.CreateTime, actual.CreateTime) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + require.True(t, found) + } + }) + + t.Run("Logout", func(t *testing.T) { + tok := tokenMap[idMap["token-test-1"]] + req, _ := http.NewRequest("POST", "/v1/acl/logout?token="+tok.SecretID, nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLLogout(resp, req) + require.NoError(t, err) + }) + + t.Run("Token is gone after Logout", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/token/"+idMap["token-test-1"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLTokenCRUD(resp, req) + require.Error(t, err) + require.True(t, acl.IsErrNotFound(err), err.Error()) + }) + }) +} diff --git a/agent/agent.go b/agent/agent.go index 36ca1a2958..c1455a78d8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1001,6 +1001,9 @@ func (a *Agent) consulConfig() (*consul.Config, error) { if a.config.ACLPolicyTTL != 0 { base.ACLPolicyTTL = a.config.ACLPolicyTTL } + if a.config.ACLRoleTTL != 0 { + base.ACLRoleTTL = a.config.ACLRoleTTL + } if a.config.ACLDefaultPolicy != "" { base.ACLDefaultPolicy = a.config.ACLDefaultPolicy } diff --git a/agent/config/builder.go b/agent/config/builder.go index 5ef7e35044..9cd5956ed0 100644 --- a/agent/config/builder.go +++ b/agent/config/builder.go @@ -702,6 +702,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) { ACLReplicationToken: b.stringValWithDefault(c.ACL.Tokens.Replication, b.stringVal(c.ACLReplicationToken)), ACLTokenTTL: b.durationValWithDefault("acl.token_ttl", c.ACL.TokenTTL, b.durationVal("acl_ttl", c.ACLTTL)), ACLPolicyTTL: b.durationVal("acl.policy_ttl", c.ACL.PolicyTTL), + ACLRoleTTL: b.durationVal("acl.role_ttl", c.ACL.RoleTTL), ACLToken: b.stringValWithDefault(c.ACL.Tokens.Default, b.stringVal(c.ACLToken)), ACLTokenReplication: b.boolValWithDefault(c.ACL.TokenReplication, b.boolValWithDefault(c.EnableACLReplication, enableTokenReplication)), ACLEnableTokenPersistence: b.boolValWithDefault(c.ACL.EnableTokenPersistence, false), diff --git a/agent/config/config.go b/agent/config/config.go index 958f53f26c..1e769118d6 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -635,6 +635,7 @@ type ACL struct { Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"` TokenReplication *bool `json:"enable_token_replication,omitempty" hcl:"enable_token_replication" mapstructure:"enable_token_replication"` PolicyTTL *string `json:"policy_ttl,omitempty" hcl:"policy_ttl" mapstructure:"policy_ttl"` + RoleTTL *string `json:"role_ttl,omitempty" hcl:"role_ttl" mapstructure:"role_ttl"` TokenTTL *string `json:"token_ttl,omitempty" hcl:"token_ttl" mapstructure:"token_ttl"` DownPolicy *string `json:"down_policy,omitempty" hcl:"down_policy" mapstructure:"down_policy"` DefaultPolicy *string `json:"default_policy,omitempty" hcl:"default_policy" mapstructure:"default_policy"` diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 96bbe1e3e4..dc9d567f81 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -155,6 +155,12 @@ type RuntimeConfig struct { // hcl: acl.token_ttl = "duration" ACLPolicyTTL time.Duration + // ACLRoleTTL is used to control the time-to-live of cached ACL roles. This has + // a major impact on performance. By default, it is set to 30 seconds. + // + // hcl: acl.role_ttl = "duration" + ACLRoleTTL time.Duration + // ACLToken is the default token used to make requests if a per-request // token is not provided. If not configured the 'anonymous' token is used. // diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 0e22b95d73..18d9a7e354 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -2901,6 +2901,7 @@ func TestFullConfig(t *testing.T) { "enable_key_list_policy": false, "enable_token_persistence": true, "policy_ttl": "1123s", + "role_ttl": "9876s", "token_ttl": "3321s", "enable_token_replication" : true, "tokens" : { @@ -3464,6 +3465,7 @@ func TestFullConfig(t *testing.T) { enable_key_list_policy = false enable_token_persistence = true policy_ttl = "1123s" + role_ttl = "9876s" token_ttl = "3321s" enable_token_replication = true tokens = { @@ -4145,6 +4147,7 @@ func TestFullConfig(t *testing.T) { ACLReplicationToken: "5795983a", ACLTokenTTL: 3321 * time.Second, ACLPolicyTTL: 1123 * time.Second, + ACLRoleTTL: 9876 * time.Second, ACLToken: "418fdff1", ACLTokenReplication: true, AdvertiseAddrLAN: ipAddr("17.99.29.16"), @@ -4975,6 +4978,7 @@ func TestSanitize(t *testing.T) { "ACLMasterToken": "hidden", "ACLPolicyTTL": "0s", "ACLReplicationToken": "hidden", + "ACLRoleTTL": "0s", "ACLTokenReplication": false, "ACLTokenTTL": "0s", "ACLToken": "hidden", diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 7553291fef..6e8130af27 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -4,10 +4,11 @@ import ( "fmt" "log" "os" + "sort" "sync" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" @@ -31,9 +32,16 @@ const ( // with all tokens in it. aclUpgradeBatchSize = 128 - // aclUpgradeRateLimit is the number of batch upgrade requests per second. + // aclUpgradeRateLimit is the number of batch upgrade requests per second allowed. aclUpgradeRateLimit rate.Limit = 1.0 + // aclTokenReapingRateLimit is the number of batch token reaping requests per second allowed. + aclTokenReapingRateLimit rate.Limit = 1.0 + + // aclTokenReapingBurst is the number of batch token reaping requests per second + // that can burst after a period of idleness. + aclTokenReapingBurst = 5 + // aclBatchDeleteSize is the number of deletions to send in a single batch operation. 4096 should produce a batch that is <150KB // in size but should be sufficiently large to handle 1 replication round in a single batch aclBatchDeleteSize = 4096 @@ -57,6 +65,10 @@ const ( // Maximum number of re-resolution requests to be made if the token is modified between // resolving the token and resolving its policies that would remove one of its policies. tokenPolicyResolutionMaxRetries = 5 + + // Maximum number of re-resolution requests to be made if the token is modified between + // resolving the token and resolving its roles that would remove one of its roles. + tokenRoleResolutionMaxRetries = 5 ) func minTTL(a time.Duration, b time.Duration) time.Duration { @@ -85,15 +97,16 @@ type ACLResolverDelegate interface { UseLegacyACLs() bool ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) + ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) RPC(method string, args interface{}, reply interface{}) error } -type policyTokenError struct { +type policyOrRoleTokenError struct { Err error token string } -func (e policyTokenError) Error() string { +func (e policyOrRoleTokenError) Error() string { return e.Err.Error() } @@ -121,19 +134,21 @@ type ACLResolverConfig struct { // Supports: // - Resolving tokens locally via the ACLResolverDelegate // - Resolving policies locally via the ACLResolverDelegate +// - Resolving roles locally via the ACLResolverDelegate // - Resolving legacy tokens remotely via a ACL.GetPolicy RPC // - Resolving tokens remotely via an ACL.TokenRead RPC // - Resolving policies remotely via an ACL.PolicyResolve RPC +// - Resolving roles remotely via an ACL.RoleResolve RPC // // Remote Resolution: -// Remote resolution can be done syncrhonously or asynchronously depending +// Remote resolution can be done synchronously or asynchronously depending // on the ACLDownPolicy in the Config passed to the resolver. // // When the down policy is set to async-cache and we have already cached values // then go routines will be spawned to perform the RPCs in the background // and then will update the cache with either the positive or negative result. // -// When the down policy is set to extend-cache or the token/policy is not already +// When the down policy is set to extend-cache or the token/policy/role is not already // cached then the same go routines are spawned to do the RPCs in the background. // However in this mode channels are created to receive the results of the RPC // and are registered with the resolver. Those channels are immediately read/blocked @@ -149,6 +164,7 @@ type ACLResolver struct { cache *structs.ACLCaches identityGroup singleflight.Group policyGroup singleflight.Group + roleGroup singleflight.Group legacyGroup singleflight.Group down acl.Authorizer @@ -431,25 +447,8 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent return out, nil } - if acl.IsErrNotFound(err) { - // make sure to indicate that this identity is no longer valid within - // the cache - r.cache.PutIdentity(identity.SecretToken(), nil) - - // Do not touch the policy cache. Getting a top level ACL not found error - // only indicates that the secret token used in the request - // no longer exists - return nil, &policyTokenError{acl.ErrNotFound, identity.SecretToken()} - } - - if acl.IsErrPermissionDenied(err) { - // invalidate our ID cache so that identity resolution will take place - // again in the future - r.cache.RemoveIdentity(identity.SecretToken()) - - // Do not remove from the policy cache for permission denied - // what this does indicate is that our view of the token is out of date - return nil, &policyTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr } // other RPC error - use cache if available @@ -475,6 +474,88 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent return out, nil } +func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity, roleIDs []string, cached map[string]*structs.RoleCacheEntry) (map[string]*structs.ACLRole, error) { + req := structs.ACLRoleBatchGetRequest{ + Datacenter: r.delegate.ACLDatacenter(false), + RoleIDs: roleIDs, + QueryOptions: structs.QueryOptions{ + Token: identity.SecretToken(), + AllowStale: true, + }, + } + + var resp structs.ACLRoleBatchResponse + err := r.delegate.RPC("ACL.RoleResolve", &req, &resp) + if err == nil { + out := make(map[string]*structs.ACLRole) + for _, role := range resp.Roles { + out[role.ID] = role + } + + for _, roleID := range roleIDs { + if role, ok := out[roleID]; ok { + r.cache.PutRole(roleID, role) + } else { + r.cache.PutRole(roleID, nil) + } + } + return out, nil + } + + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr + } + + // other RPC error - use cache if available + + extendCache := r.config.ACLDownPolicy == "extend-cache" || r.config.ACLDownPolicy == "async-cache" + + out := make(map[string]*structs.ACLRole) + insufficientCache := false + for _, roleID := range roleIDs { + if entry, ok := cached[roleID]; extendCache && ok { + r.cache.PutRole(roleID, entry.Role) + if entry.Role != nil { + out[roleID] = entry.Role + } + } else { + r.cache.PutRole(roleID, nil) + insufficientCache = true + } + } + + if insufficientCache { + return nil, ACLRemoteError{Err: err} + } + + return out, nil +} + +func (r *ACLResolver) maybeHandleIdentityErrorDuringFetch(identity structs.ACLIdentity, err error) error { + if acl.IsErrNotFound(err) { + // make sure to indicate that this identity is no longer valid within + // the cache + r.cache.PutIdentity(identity.SecretToken(), nil) + + // Do not touch the cache. Getting a top level ACL not found error + // only indicates that the secret token used in the request + // no longer exists + return &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} + } + + if acl.IsErrPermissionDenied(err) { + // invalidate our ID cache so that identity resolution will take place + // again in the future + r.cache.RemoveIdentity(identity.SecretToken()) + + // Do not remove from the cache for permission denied + // what this does indicate is that our view of the token is out of date + return &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + } + + return nil +} + func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) structs.ACLPolicies { var out structs.ACLPolicies for _, policy := range policies { @@ -496,7 +577,10 @@ func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) struct func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) (structs.ACLPolicies, error) { policyIDs := identity.PolicyIDs() - if len(policyIDs) == 0 { + roleIDs := identity.RoleIDs() + serviceIdentities := identity.ServiceIdentityList() + + if len(policyIDs) == 0 && len(serviceIdentities) == 0 && len(roleIDs) == 0 { policy := identity.EmbeddedPolicy() if policy != nil { return []*structs.ACLPolicy{policy}, nil @@ -506,9 +590,116 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( return nil, nil } + // Collect all of the roles tied to this token. + roles, err := r.collectRolesForIdentity(identity, roleIDs) + if err != nil { + return nil, err + } + + // Merge the policies and service identities across Token and Role fields. + for _, role := range roles { + for _, link := range role.Policies { + policyIDs = append(policyIDs, link.ID) + } + serviceIdentities = append(serviceIdentities, role.ServiceIdentities...) + } + + // Now deduplicate any policies or service identities that occur more than once. + policyIDs = dedupeStringSlice(policyIDs) + serviceIdentities = dedupeServiceIdentities(serviceIdentities) + + // Generate synthetic policies for all service identities in effect. + syntheticPolicies := r.synthesizePoliciesForServiceIdentities(serviceIdentities) + // For the new ACLs policy replication is mandatory for correct operation on servers. Therefore // we only attempt to resolve policies locally - policies := make([]*structs.ACLPolicy, 0, len(policyIDs)) + policies, err := r.collectPoliciesForIdentity(identity, policyIDs, len(syntheticPolicies)) + if err != nil { + return nil, err + } + + policies = append(policies, syntheticPolicies...) + filtered := r.filterPoliciesByScope(policies) + return filtered, nil +} + +func (r *ACLResolver) synthesizePoliciesForServiceIdentities(serviceIdentities []*structs.ACLServiceIdentity) []*structs.ACLPolicy { + if len(serviceIdentities) == 0 { + return nil + } + + syntheticPolicies := make([]*structs.ACLPolicy, 0, len(serviceIdentities)) + for _, s := range serviceIdentities { + syntheticPolicies = append(syntheticPolicies, s.SyntheticPolicy()) + } + + return syntheticPolicies +} + +func dedupeServiceIdentities(in []*structs.ACLServiceIdentity) []*structs.ACLServiceIdentity { + // From: https://github.com/golang/go/wiki/SliceTricks#in-place-deduplicate-comparable + + if len(in) <= 1 { + return in + } + + sort.Slice(in, func(i, j int) bool { + return in[i].ServiceName < in[j].ServiceName + }) + + j := 0 + for i := 1; i < len(in); i++ { + if in[j].ServiceName == in[i].ServiceName { + // Prefer increasing scope. + if len(in[j].Datacenters) == 0 || len(in[i].Datacenters) == 0 { + in[j].Datacenters = nil + } else { + in[j].Datacenters = mergeStringSlice(in[j].Datacenters, in[i].Datacenters) + } + continue + } + j++ + in[j] = in[i] + } + + // Discard the skipped items. + for i := j + 1; i < len(in); i++ { + in[i] = nil + } + + return in[:j+1] +} + +func mergeStringSlice(a, b []string) []string { + out := make([]string, 0, len(a)+len(b)) + out = append(out, a...) + out = append(out, b...) + return dedupeStringSlice(out) +} + +func dedupeStringSlice(in []string) []string { + // From: https://github.com/golang/go/wiki/SliceTricks#in-place-deduplicate-comparable + + if len(in) <= 1 { + return in + } + + sort.Strings(in) + + j := 0 + for i := 1; i < len(in); i++ { + if in[j] == in[i] { + continue + } + j++ + in[j] = in[i] + } + + return in[:j+1] +} + +func (r *ACLResolver) collectPoliciesForIdentity(identity structs.ACLIdentity, policyIDs []string, extraCap int) ([]*structs.ACLPolicy, error) { + policies := make([]*structs.ACLPolicy, 0, len(policyIDs)+extraCap) // Get all associated policies var missing []string @@ -538,7 +729,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( } if entry.Policy == nil { - // this happens when we cache a negative response for the policies existence + // this happens when we cache a negative response for the policy's existence continue } @@ -552,7 +743,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( // Hot-path if we have no missing or expired policies if len(missing)+len(expired) == 0 { - return r.filterPoliciesByScope(policies), nil + return policies, nil } hasMissing := len(missing) > 0 @@ -572,7 +763,7 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( if !waitForResult { // waitForResult being false requires that all the policies were cached already policies = append(policies, expired...) - return r.filterPoliciesByScope(policies), nil + return policies, nil } res := <-waitChan @@ -589,7 +780,100 @@ func (r *ACLResolver) resolvePoliciesForIdentity(identity structs.ACLIdentity) ( } } - return r.filterPoliciesByScope(policies), nil + return policies, nil +} + +func (r *ACLResolver) resolveRolesForIdentity(identity structs.ACLIdentity) (structs.ACLRoles, error) { + return r.collectRolesForIdentity(identity, identity.RoleIDs()) +} + +func (r *ACLResolver) collectRolesForIdentity(identity structs.ACLIdentity, roleIDs []string) (structs.ACLRoles, error) { + if len(roleIDs) == 0 { + return nil, nil + } + + // For the new ACLs policy & role replication is mandatory for correct operation + // on servers. Therefore we only attempt to resolve roles locally + roles := make([]*structs.ACLRole, 0, len(roleIDs)) + + var missing []string + var expired []*structs.ACLRole + expCacheMap := make(map[string]*structs.RoleCacheEntry) + + for _, roleID := range roleIDs { + if done, role, err := r.delegate.ResolveRoleFromID(roleID); done { + if err != nil && !acl.IsErrNotFound(err) { + return nil, err + } + + if role != nil { + roles = append(roles, role) + } else { + r.logger.Printf("[WARN] acl: role %q not found for identity %q", roleID, identity.ID()) + } + + continue + } + + // create the missing list which we can execute an RPC to get all the missing roles at once + entry := r.cache.GetRole(roleID) + if entry == nil { + missing = append(missing, roleID) + continue + } + + if entry.Role == nil { + // this happens when we cache a negative response for the role's existence + continue + } + + if entry.Age() >= r.config.ACLRoleTTL { + expired = append(expired, entry.Role) + expCacheMap[roleID] = entry + } else { + roles = append(roles, entry.Role) + } + } + + // Hot-path if we have no missing or expired roles + if len(missing)+len(expired) == 0 { + return roles, nil + } + + hasMissing := len(missing) > 0 + + fetchIDs := missing + for _, role := range expired { + fetchIDs = append(fetchIDs, role.ID) + } + + waitChan := r.roleGroup.DoChan(identity.SecretToken(), func() (interface{}, error) { + roles, err := r.fetchAndCacheRolesForIdentity(identity, fetchIDs, expCacheMap) + return roles, err + }) + + waitForResult := hasMissing || r.config.ACLDownPolicy != "async-cache" + if !waitForResult { + // waitForResult being false requires that all the roles were cached already + roles = append(roles, expired...) + return roles, nil + } + + res := <-waitChan + + if res.Err != nil { + return nil, res.Err + } + + if res.Val != nil { + foundRoles := res.Val.(map[string]*structs.ACLRole) + + for _, role := range foundRoles { + roles = append(roles, role) + } + } + + return roles, nil } func (r *ACLResolver) resolveTokenToPolicies(token string) (structs.ACLPolicies, error) { @@ -608,6 +892,8 @@ func (r *ACLResolver) resolveTokenToIdentityAndPolicies(token string) (structs.A return nil, nil, err } else if identity == nil { return nil, nil, acl.ErrNotFound + } else if identity.IsExpired(time.Now()) { + return nil, nil, acl.ErrNotFound } lastIdentity = identity @@ -618,13 +904,52 @@ func (r *ACLResolver) resolveTokenToIdentityAndPolicies(token string) (structs.A } lastErr = err - if tokenErr, ok := err.(*policyTokenError); ok { + if tokenErr, ok := err.(*policyOrRoleTokenError); ok { if acl.IsErrNotFound(err) && tokenErr.token == identity.SecretToken() { // token was deleted while resolving policies return nil, nil, acl.ErrNotFound } - // other types of policyTokenErrors should cause retrying the whole token + // other types of policyOrRoleTokenErrors should cause retrying the whole token + // resolution process + } else { + return identity, nil, err + } + } + + return lastIdentity, nil, lastErr +} + +func (r *ACLResolver) resolveTokenToIdentityAndRoles(token string) (structs.ACLIdentity, structs.ACLRoles, error) { + var lastErr error + var lastIdentity structs.ACLIdentity + + for i := 0; i < tokenRoleResolutionMaxRetries; i++ { + // Resolve the token to an ACLIdentity + identity, err := r.resolveIdentityFromToken(token) + if err != nil { + return nil, nil, err + } else if identity == nil { + return nil, nil, acl.ErrNotFound + } else if identity.IsExpired(time.Now()) { + return nil, nil, acl.ErrNotFound + } + + lastIdentity = identity + + roles, err := r.resolveRolesForIdentity(identity) + if err == nil { + return identity, roles, nil + } + lastErr = err + + if tokenErr, ok := err.(*policyOrRoleTokenError); ok { + if acl.IsErrNotFound(err) && tokenErr.token == identity.SecretToken() { + // token was deleted while resolving roles + return nil, nil, acl.ErrNotFound + } + + // other types of policyOrRoleTokenErrors should cause retrying the whole token // resolution process } else { return identity, nil, err diff --git a/agent/consul/acl_authmethod.go b/agent/consul/acl_authmethod.go new file mode 100644 index 0000000000..ba3b3772da --- /dev/null +++ b/agent/consul/acl_authmethod.go @@ -0,0 +1,169 @@ +package consul + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-bexpr" + + // register this as a builtin auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" +) + +type authMethodValidatorEntry struct { + Validator authmethod.Validator + ModifyIndex uint64 // the raft index when this last changed +} + +// loadAuthMethodValidator returns an authmethod.Validator for the given auth +// method configuration. If the cache is up to date as-of the provided index +// then the cached version is returned, otherwise a new validator is created +// and cached. +func (s *Server) loadAuthMethodValidator(idx uint64, method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if prevIdx, v, ok := s.getCachedAuthMethodValidator(method.Name); ok && idx <= prevIdx { + return v, nil + } + + v, err := authmethod.NewValidator(method) + if err != nil { + return nil, fmt.Errorf("auth method validator for %q could not be initialized: %v", method.Name, err) + } + + v = s.getOrReplaceAuthMethodValidator(method.Name, idx, v) + + return v, nil +} + +// getCachedAuthMethodValidator returns an AuthMethodValidator for +// the given name exclusively from the cache. If one is not found in the cache +// nil is returned. +func (s *Server) getCachedAuthMethodValidator(name string) (uint64, authmethod.Validator, bool) { + s.aclAuthMethodValidatorLock.RLock() + defer s.aclAuthMethodValidatorLock.RUnlock() + + if s.aclAuthMethodValidators != nil { + v, ok := s.aclAuthMethodValidators[name] + if ok { + return v.ModifyIndex, v.Validator, true + } + } + return 0, nil, false +} + +// getOrReplaceAuthMethodValidator updates the cached validator with the +// provided one UNLESS it has been updated by another goroutine in which case +// the updated one is returned. +func (s *Server) getOrReplaceAuthMethodValidator(name string, idx uint64, v authmethod.Validator) authmethod.Validator { + s.aclAuthMethodValidatorLock.Lock() + defer s.aclAuthMethodValidatorLock.Unlock() + + if s.aclAuthMethodValidators == nil { + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + } + + prev, ok := s.aclAuthMethodValidators[name] + if ok { + if prev.ModifyIndex >= idx { + return prev.Validator + } + } + + s.logger.Printf("[DEBUG] acl: updating cached auth method validator for %q", name) + + s.aclAuthMethodValidators[name] = &authMethodValidatorEntry{ + Validator: v, + ModifyIndex: idx, + } + return v +} + +// purgeAuthMethodValidators resets the cache of validators. +func (s *Server) purgeAuthMethodValidators() { + s.aclAuthMethodValidatorLock.Lock() + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + s.aclAuthMethodValidatorLock.Unlock() +} + +// evaluateRoleBindings evaluates all current binding rules associated with the +// given auth method against the verified data returned from the authentication +// process. +// +// A list of role links and service identities are returned. +func (s *Server) evaluateRoleBindings( + validator authmethod.Validator, + verifiedFields map[string]string, +) ([]*structs.ACLServiceIdentity, []structs.ACLTokenRoleLink, error) { + // Only fetch rules that are relevant for this method. + _, rules, err := s.fsm.State().ACLBindingRuleList(nil, validator.Name()) + if err != nil { + return nil, nil, err + } else if len(rules) == 0 { + return nil, nil, nil + } + + // Convert the fields into something suitable for go-bexpr. + selectableVars := validator.MakeFieldMapSelectable(verifiedFields) + + // Find all binding rules that match the provided fields. + var matchingRules []*structs.ACLBindingRule + for _, rule := range rules { + if doesBindingRuleMatch(rule, selectableVars) { + matchingRules = append(matchingRules, rule) + } + } + if len(matchingRules) == 0 { + return nil, nil, nil + } + + // For all matching rules compute the attributes of a token. + var ( + roleLinks []structs.ACLTokenRoleLink + serviceIdentities []*structs.ACLServiceIdentity + ) + for _, rule := range matchingRules { + bindName, valid, err := computeBindingRuleBindName(rule.BindType, rule.BindName, verifiedFields) + if err != nil { + return nil, nil, fmt.Errorf("cannot compute %q bind name for bind target: %v", rule.BindType, err) + } else if !valid { + return nil, nil, fmt.Errorf("computed %q bind name for bind target is invalid: %q", rule.BindType, bindName) + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + serviceIdentities = append(serviceIdentities, &structs.ACLServiceIdentity{ + ServiceName: bindName, + }) + + case structs.BindingRuleBindTypeRole: + roleLinks = append(roleLinks, structs.ACLTokenRoleLink{ + Name: bindName, + }) + + default: + // skip unknown bind type; don't grant privileges + } + } + + return serviceIdentities, roleLinks, nil +} + +// doesBindingRuleMatch checks that a single binding rule matches the provided +// vars. +func doesBindingRuleMatch(rule *structs.ACLBindingRule, selectableVars interface{}) bool { + if rule.Selector == "" { + return true // catch-all + } + + eval, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return false // fails to match if selector is invalid + } + + result, err := eval.Evaluate(selectableVars) + if err != nil { + return false // fails to match if evaluation fails + } + + return result +} diff --git a/agent/consul/acl_authmethod_test.go b/agent/consul/acl_authmethod_test.go new file mode 100644 index 0000000000..45e3021e44 --- /dev/null +++ b/agent/consul/acl_authmethod_test.go @@ -0,0 +1,48 @@ +package consul + +import ( + "testing" + + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestDoesBindingRuleMatch(t *testing.T) { + type matchable struct { + A string `bexpr:"a"` + C string `bexpr:"c"` + } + + for _, test := range []struct { + name string + selector string + details interface{} + ok bool + }{ + {"no fields", + "a==b", nil, false}, + {"1 term ok", + "a==b", &matchable{A: "b"}, true}, + {"1 term no field", + "a==b", &matchable{C: "d"}, false}, + {"1 term wrong value", + "a==b", &matchable{A: "z"}, false}, + {"2 terms ok", + "a==b and c==d", &matchable{A: "b", C: "d"}, true}, + {"2 terms one missing field", + "a==b and c==d", &matchable{A: "b"}, false}, + {"2 terms one wrong value", + "a==b and c==d", &matchable{A: "z", C: "d"}, false}, + /////////////////////////////// + {"no fields (no selectors)", + "", nil, true}, + {"1 term ok (no selectors)", + "", &matchable{A: "b"}, true}, + } { + t.Run(test.name, func(t *testing.T) { + rule := structs.ACLBindingRule{Selector: test.selector} + ok := doesBindingRuleMatch(&rule, test.details) + require.Equal(t, test.ok, ok) + }) + } +} diff --git a/agent/consul/acl_client.go b/agent/consul/acl_client.go index 06b9d78b5e..1951d43412 100644 --- a/agent/consul/acl_client.go +++ b/agent/consul/acl_client.go @@ -25,6 +25,8 @@ var clientACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ ParsedPolicies: 128, // Authorizers - number of compiled multi-policy effective policies that can be cached Authorizers: 256, + // Roles - number of ACL roles that can be cached + Roles: 128, } func (c *Client) UseLegacyACLs() bool { @@ -96,6 +98,11 @@ func (c *Client) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, return false, nil, nil } +func (c *Client) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + // clients do no local role resolution at the moment + return false, nil, nil +} + func (c *Client) ResolveToken(token string) (acl.Authorizer, error) { return c.acls.ResolveToken(token) } diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index 5f9738c9cc..cc3b5da2e7 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -1,6 +1,8 @@ package consul import ( + "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -8,13 +10,15 @@ import ( "regexp" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" - "github.com/hashicorp/go-memdb" - "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-bexpr" + memdb "github.com/hashicorp/go-memdb" + uuid "github.com/hashicorp/go-uuid" ) const ( @@ -24,7 +28,13 @@ const ( ) // Regex for matching -var validPolicyName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) +var ( + validPolicyName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) + validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`) + serviceIdentityNameMaxLength = 256 + validRoleName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,256}$`) + validAuthMethod = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) +) // ACL endpoint is used to manipulate ACLs type ACL struct { @@ -221,6 +231,10 @@ func (a *ACL) TokenRead(args *structs.ACLTokenGetRequest, reply *structs.ACLToke index, token, err = state.ACLTokenGetBySecret(ws, args.TokenID) } + if token != nil && token.IsExpired(time.Now()) { + token = nil + } + if err != nil { return err } @@ -256,7 +270,7 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok _, token, err := a.srv.fsm.State().ACLTokenGetByAccessor(nil, args.ACLToken.AccessorID) if err != nil { return err - } else if token == nil { + } else if token == nil || token.IsExpired(time.Now()) { return acl.ErrNotFound } else if !a.srv.InACLDatacenter() && !token.Local { // global token writes must be forwarded to the primary DC @@ -264,6 +278,10 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok return a.srv.forwardDC("ACL.TokenClone", a.srv.config.ACLDatacenter, args, reply) } + if token.AuthMethod != "" { + return fmt.Errorf("Cannot clone a token created from an auth method") + } + if token.Rules != "" { return fmt.Errorf("Cannot clone a legacy ACL with this endpoint") } @@ -271,9 +289,11 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok cloneReq := structs.ACLTokenSetRequest{ Datacenter: args.Datacenter, ACLToken: structs.ACLToken{ - Policies: token.Policies, - Local: token.Local, - Description: token.Description, + Policies: token.Policies, + ServiceIdentities: token.ServiceIdentities, + Local: token.Local, + Description: token.Description, + ExpirationTime: token.ExpirationTime, }, WriteRequest: args.WriteRequest, } @@ -313,7 +333,7 @@ func (a *ACL) TokenSet(args *structs.ACLTokenSetRequest, reply *structs.ACLToken return a.tokenSetInternal(args, reply, false) } -func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, upgrade bool) error { +func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, fromLogin bool) error { token := &args.ACLToken if !a.srv.LocalTokensEnabled() { @@ -342,6 +362,47 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. } token.CreateTime = time.Now() + + if fromLogin { + if token.AuthMethod == "" { + return fmt.Errorf("AuthMethod field is required during Login") + } + if !token.Local { + return fmt.Errorf("Cannot create Global token via Login") + } + } else { + if token.AuthMethod != "" { + return fmt.Errorf("AuthMethod field is disallowed outside of Login") + } + } + + // Ensure an ExpirationTTL is valid if provided. + if token.ExpirationTTL != 0 { + if token.ExpirationTTL < 0 { + return fmt.Errorf("Token Expiration TTL '%s' should be > 0", token.ExpirationTTL) + } + if token.HasExpirationTime() { + return fmt.Errorf("Token Expiration TTL and Expiration Time cannot both be set") + } + + token.ExpirationTime = timePointer(token.CreateTime.Add(token.ExpirationTTL)) + token.ExpirationTTL = 0 + } + + if token.HasExpirationTime() { + if token.CreateTime.After(*token.ExpirationTime) { + return fmt.Errorf("ExpirationTime cannot be before CreateTime") + } + + expiresIn := token.ExpirationTime.Sub(token.CreateTime) + if expiresIn > a.srv.config.ACLTokenMaxExpirationTTL { + return fmt.Errorf("ExpirationTime cannot be more than %s in the future (was %s)", + a.srv.config.ACLTokenMaxExpirationTTL, expiresIn) + } else if expiresIn < a.srv.config.ACLTokenMinExpirationTTL { + return fmt.Errorf("ExpirationTime cannot be less than %s in the future (was %s)", + a.srv.config.ACLTokenMinExpirationTTL, expiresIn) + } + } } else { // Token Update if _, err := uuid.ParseUUID(token.AccessorID); err != nil { @@ -365,7 +426,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. if err != nil { return fmt.Errorf("Failed to lookup the acl token %q: %v", token.AccessorID, err) } - if existing == nil { + if existing == nil || existing.IsExpired(time.Now()) { return fmt.Errorf("Cannot find token %q", token.AccessorID) } if token.SecretID == "" { @@ -379,11 +440,25 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID) } - if upgrade { - token.CreateTime = time.Now() - } else { - token.CreateTime = existing.CreateTime + if token.AuthMethod == "" { + token.AuthMethod = existing.AuthMethod + } else if token.AuthMethod != existing.AuthMethod { + return fmt.Errorf("Cannot change AuthMethod of %s", token.AccessorID) } + + if token.ExpirationTTL != 0 { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } + + if !token.HasExpirationTime() { + token.ExpirationTime = existing.ExpirationTime + } else if !existing.HasExpirationTime() { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } else if !token.ExpirationTime.Equal(*existing.ExpirationTime) { + return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) + } + + token.CreateTime = existing.CreateTime } policyIDs := make(map[string]struct{}) @@ -413,6 +488,46 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. } token.Policies = policies + roleIDs := make(map[string]struct{}) + var roles []structs.ACLTokenRoleLink + + // Validate all the role names and convert them to role IDs. + for _, link := range token.Roles { + if link.ID == "" { + _, role, err := state.ACLRoleGetByName(nil, link.Name) + if err != nil { + return fmt.Errorf("Error looking up role for name %q: %v", link.Name, err) + } + if role == nil { + return fmt.Errorf("No such ACL role with name %q", link.Name) + } + link.ID = role.ID + } + + // Do not store the role name within raft/memdb as the role could be renamed in the future. + link.Name = "" + + // dedup role links by id + if _, ok := roleIDs[link.ID]; !ok { + roles = append(roles, link) + roleIDs[link.ID] = struct{}{} + } + } + token.Roles = roles + + for _, svcid := range token.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Service identity is missing the service name field on this token") + } + if token.Local && len(svcid.Datacenters) > 0 { + return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName) + } + if !isValidServiceIdentityName(svcid.ServiceName) { + return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) + } + } + token.ServiceIdentities = dedupeServiceIdentities(token.ServiceIdentities) + if token.Rules != "" { return fmt.Errorf("Rules cannot be specified for this token") } @@ -440,6 +555,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return respErr } + // Don't check expiration times here as it doesn't really matter. if _, updatedToken, err := a.srv.fsm.State().ACLTokenGetByAccessor(nil, token.AccessorID); err == nil && token != nil { *reply = *updatedToken } else { @@ -449,6 +565,62 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return nil } +func validateBindingRuleBindName(bindType, bindName string, availableFields []string) (bool, error) { + if bindType == "" || bindName == "" { + return false, nil + } + + fakeVarMap := make(map[string]string) + for _, v := range availableFields { + fakeVarMap[v] = "fake" + } + + _, valid, err := computeBindingRuleBindName(bindType, bindName, fakeVarMap) + if err != nil { + return false, err + } + return valid, nil +} + +// computeBindingRuleBindName processes the HIL for the provided bind type+name +// using the verified fields. +// +// - If the HIL is invalid ("", false, AN_ERROR) is returned. +// - If the computed name is not valid for the type ("INVALID_NAME", false, nil) is returned. +// - If the computed name is valid for the type ("VALID_NAME", true, nil) is returned. +func computeBindingRuleBindName(bindType, bindName string, verifiedFields map[string]string) (string, bool, error) { + bindName, err := InterpolateHIL(bindName, verifiedFields) + if err != nil { + return "", false, err + } + + valid := false + + switch bindType { + case structs.BindingRuleBindTypeService: + valid = isValidServiceIdentityName(bindName) + + case structs.BindingRuleBindTypeRole: + valid = validRoleName.MatchString(bindName) + + default: + return "", false, fmt.Errorf("unknown binding rule bind type: %s", bindType) + } + + return bindName, valid, nil +} + +// isValidServiceIdentityName returns true if the provided name can be used as +// an ACLServiceIdentity ServiceName. This is more restrictive than standard +// catalog registration, which basically takes the view that "everything is +// valid". +func isValidServiceIdentityName(name string) bool { + if len(name) < 1 || len(name) > serviceIdentityNameMaxLength { + return false + } + return validServiceIdentityName.MatchString(name) +} + func (a *ACL) TokenDelete(args *structs.ACLTokenDeleteRequest, reply *string) error { if err := a.aclPreCheck(); err != nil { return err @@ -490,6 +662,8 @@ func (a *ACL) TokenDelete(args *structs.ACLTokenDeleteRequest, reply *string) er return fmt.Errorf("Deletion of the request's authorization token is not permitted") } + // No need to check expiration time because it's being deleted. + if !a.srv.InACLDatacenter() && !token.Local { args.Datacenter = a.srv.config.ACLDatacenter return a.srv.forwardDC("ACL.TokenDelete", a.srv.config.ACLDatacenter, args, reply) @@ -548,13 +722,18 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy) + index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role, args.AuthMethod) if err != nil { return err } + now := time.Now() + stubs := make([]*structs.ACLTokenListStub, 0, len(tokens)) for _, token := range tokens { + if token.IsExpired(now) { + continue + } stubs = append(stubs, token.Stub()) } reply.Index, reply.Tokens = index, stubs @@ -589,6 +768,8 @@ func (a *ACL) TokenBatchRead(args *structs.ACLTokenBatchGetRequest, reply *struc return err } + // This RPC is used for replication, so don't filter out expired tokens here. + a.srv.filterACLWithAuthorizer(rule, &tokens) reply.Index, reply.Tokens = index, tokens @@ -966,3 +1147,906 @@ func (a *ACL) ReplicationStatus(args *structs.DCSpecificRequest, a.srv.aclReplicationStatusLock.RUnlock() return nil } + +func timePointer(t time.Time) *time.Time { + return &t +} + +func (a *ACL) RoleRead(args *structs.ACLRoleGetRequest, reply *structs.ACLRoleResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + var ( + index uint64 + role *structs.ACLRole + err error + ) + if args.RoleID != "" { + index, role, err = state.ACLRoleGetByID(ws, args.RoleID) + } else { + index, role, err = state.ACLRoleGetByName(ws, args.RoleName) + } + + if err != nil { + return err + } + + reply.Index, reply.Role = index, role + return nil + }) +} + +func (a *ACL) RoleBatchRead(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleBatchRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, roles, err := state.ACLRoleBatchGet(ws, args.RoleIDs) + if err != nil { + return err + } + + reply.Index, reply.Roles = index, roles + return nil + }) +} + +func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.InACLDatacenter() { + args.Datacenter = a.srv.config.ACLDatacenter + } + + if done, err := a.srv.forward("ACL.RoleSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "role", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + role := &args.Role + state := a.srv.fsm.State() + + // Almost all of the checks here are also done in the state store. However, + // we want to prevent the raft operations when we know they are going to fail + // so we still do them here. + + // ensure a name is set + if role.Name == "" { + return fmt.Errorf("Invalid Role: no Name is set") + } + + if !validRoleName.MatchString(role.Name) { + return fmt.Errorf("Invalid Role: invalid Name. Only alphanumeric characters, '-' and '_' are allowed") + } + + if role.ID == "" { + // with no role ID one will be generated + var err error + + role.ID, err = lib.GenerateUUID(a.srv.checkRoleUUID) + if err != nil { + return err + } + + // validate the name is unique + if _, existing, err := state.ACLRoleGetByName(nil, role.Name); err != nil { + return fmt.Errorf("acl role lookup by name failed: %v", err) + } else if existing != nil { + return fmt.Errorf("Invalid Role: A Role with Name %q already exists", role.Name) + } + } else { + if _, err := uuid.ParseUUID(role.ID); err != nil { + return fmt.Errorf("Role ID invalid UUID") + } + + // Verify the role exists + _, existing, err := state.ACLRoleGetByID(nil, role.ID) + if err != nil { + return fmt.Errorf("acl role lookup failed: %v", err) + } else if existing == nil { + return fmt.Errorf("cannot find role %s", role.ID) + } + + if existing.Name != role.Name { + if _, nameMatch, err := state.ACLRoleGetByName(nil, role.Name); err != nil { + return fmt.Errorf("acl role lookup by name failed: %v", err) + } else if nameMatch != nil { + return fmt.Errorf("Invalid Role: A role with name %q already exists", role.Name) + } + } + } + + policyIDs := make(map[string]struct{}) + var policies []structs.ACLRolePolicyLink + + // Validate all the policy names and convert them to policy IDs + for _, link := range role.Policies { + if link.ID == "" { + _, policy, err := state.ACLPolicyGetByName(nil, link.Name) + if err != nil { + return fmt.Errorf("Error looking up policy for name %q: %v", link.Name, err) + } + if policy == nil { + return fmt.Errorf("No such ACL policy with name %q", link.Name) + } + link.ID = policy.ID + } + + // Do not store the policy name within raft/memdb as the policy could be renamed in the future. + link.Name = "" + + // dedup policy links by id + if _, ok := policyIDs[link.ID]; !ok { + policies = append(policies, link) + policyIDs[link.ID] = struct{}{} + } + } + role.Policies = policies + + for _, svcid := range role.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Service identity is missing the service name field on this role") + } + if !isValidServiceIdentityName(svcid.ServiceName) { + return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) + } + } + role.ServiceIdentities = dedupeServiceIdentities(role.ServiceIdentities) + + // calculate the hash for this role + role.SetHash(true) + + req := &structs.ACLRoleBatchSetRequest{ + Roles: structs.ACLRoles{role}, + } + + resp, err := a.srv.raftApply(structs.ACLRoleSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply role upsert request: %v", err) + } + + // Remove from the cache to prevent stale cache usage + a.srv.acls.cache.RemoveRole(role.ID) + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, role, err := a.srv.fsm.State().ACLRoleGetByID(nil, role.ID); err == nil && role != nil { + *reply = *role + } + + return nil +} + +func (a *ACL) RoleDelete(args *structs.ACLRoleDeleteRequest, reply *string) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.InACLDatacenter() { + args.Datacenter = a.srv.config.ACLDatacenter + } + + if done, err := a.srv.forward("ACL.RoleDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "role", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, role, err := a.srv.fsm.State().ACLRoleGetByID(nil, args.RoleID) + if err != nil { + return err + } + + if role == nil { + return nil + } + + req := structs.ACLRoleBatchDeleteRequest{ + RoleIDs: []string{args.RoleID}, + } + + resp, err := a.srv.raftApply(structs.ACLRoleDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply role delete request: %v", err) + } + + a.srv.acls.cache.RemoveRole(role.ID) + + if respErr, ok := resp.(error); ok { + return respErr + } + + if role != nil { + *reply = role.Name + } + + return nil +} + +func (a *ACL) RoleList(args *structs.ACLRoleListRequest, reply *structs.ACLRoleListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, roles, err := state.ACLRoleList(ws, args.Policy) + if err != nil { + return err + } + + reply.Index, reply.Roles = index, roles + return nil + }) +} + +// RoleResolve is used to retrieve a subset of the roles associated with a given token +// The role ids in the args simply act as a filter on the role set assigned to the token +func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if done, err := a.srv.forward("ACL.RoleResolve", args, args, reply); done { + return err + } + + // get full list of roles for this token + identity, roles, err := a.srv.acls.resolveTokenToIdentityAndRoles(args.Token) + if err != nil { + return err + } + + idMap := make(map[string]*structs.ACLRole) + for _, roleID := range identity.RoleIDs() { + idMap[roleID] = nil + } + for _, role := range roles { + idMap[role.ID] = role + } + + for _, roleID := range args.RoleIDs { + if role, ok := idMap[roleID]; ok { + // only add non-deleted roles + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } else { + // send a permission denied to indicate that the request included + // role ids not associated with this token + return acl.ErrPermissionDenied + } + } + + a.srv.setQueryMeta(&reply.QueryMeta) + + return nil +} + +var errAuthMethodsRequireTokenReplication = errors.New("Token replication is required for auth methods to function") + +func (a *ACL) BindingRuleRead(args *structs.ACLBindingRuleGetRequest, reply *structs.ACLBindingRuleResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rule, err := state.ACLBindingRuleGetByID(ws, args.BindingRuleID) + + if err != nil { + return err + } + + reply.Index, reply.BindingRule = index, rule + return nil + }) +} + +func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *structs.ACLBindingRule) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + rule := &args.BindingRule + state := a.srv.fsm.State() + + if rule.ID == "" { + // with no binding rule ID one will be generated + var err error + + rule.ID, err = lib.GenerateUUID(a.srv.checkBindingRuleUUID) + if err != nil { + return err + } + } else { + if _, err := uuid.ParseUUID(rule.ID); err != nil { + return fmt.Errorf("Binding Rule ID invalid UUID") + } + + // Verify the role exists + _, existing, err := state.ACLBindingRuleGetByID(nil, rule.ID) + if err != nil { + return fmt.Errorf("acl binding rule lookup failed: %v", err) + } else if existing == nil { + return fmt.Errorf("cannot find binding rule %s", rule.ID) + } + + if rule.AuthMethod == "" { + rule.AuthMethod = existing.AuthMethod + } else if existing.AuthMethod != rule.AuthMethod { + return fmt.Errorf("the AuthMethod field of an Binding Rule is immutable") + } + } + + if rule.AuthMethod == "" { + return fmt.Errorf("Invalid Binding Rule: no AuthMethod is set") + } + + methodIdx, method, err := state.ACLAuthMethodGetByName(nil, rule.AuthMethod) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } else if method == nil { + return fmt.Errorf("cannot find auth method with name %q", rule.AuthMethod) + } + validator, err := a.srv.loadAuthMethodValidator(methodIdx, method) + if err != nil { + return err + } + + if rule.Selector != "" { + selectableVars := validator.MakeFieldMapSelectable(map[string]string{}) + _, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return fmt.Errorf("invalid Binding Rule: Selector is invalid: %v", err) + } + } + + if rule.BindType == "" { + return fmt.Errorf("Invalid Binding Rule: no BindType is set") + } + + if rule.BindName == "" { + return fmt.Errorf("Invalid Binding Rule: no BindName is set") + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + case structs.BindingRuleBindTypeRole: + default: + return fmt.Errorf("Invalid Binding Rule: unknown BindType %q", rule.BindType) + } + + if valid, err := validateBindingRuleBindName(rule.BindType, rule.BindName, validator.AvailableFields()); err != nil { + return fmt.Errorf("Invalid Binding Rule: invalid BindName: %v", err) + } else if !valid { + return fmt.Errorf("Invalid Binding Rule: invalid BindName") + } + + req := &structs.ACLBindingRuleBatchSetRequest{ + BindingRules: structs.ACLBindingRules{rule}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, rule.ID); err == nil && rule != nil { + *reply = *rule + } + + return nil +} + +func (a *ACL) BindingRuleDelete(args *structs.ACLBindingRuleDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, args.BindingRuleID) + if err != nil { + return err + } + + if rule == nil { + return nil + } + + req := structs.ACLBindingRuleBatchDeleteRequest{ + BindingRuleIDs: []string{args.BindingRuleID}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) BindingRuleList(args *structs.ACLBindingRuleListRequest, reply *structs.ACLBindingRuleListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rules, err := state.ACLBindingRuleList(ws, args.AuthMethod) + if err != nil { + return err + } + + reply.Index, reply.BindingRules = index, rules + return nil + }) +} + +func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *structs.ACLAuthMethodResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, method, err := state.ACLAuthMethodGetByName(ws, args.AuthMethodName) + + if err != nil { + return err + } + + reply.Index, reply.AuthMethod = index, method + return nil + }) +} + +func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *structs.ACLAuthMethod) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + method := &args.AuthMethod + state := a.srv.fsm.State() + + // ensure a name is set + if method.Name == "" { + return fmt.Errorf("Invalid Auth Method: no Name is set") + } + if !validAuthMethod.MatchString(method.Name) { + return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed") + } + + // Check to see if the method exists first. + _, existing, err := state.ACLAuthMethodGetByName(nil, method.Name) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } + + if existing != nil { + if method.Type == "" { + method.Type = existing.Type + } else if existing.Type != method.Type { + return fmt.Errorf("the Type field of an Auth Method is immutable") + } + } + + if !authmethod.IsRegisteredType(method.Type) { + return fmt.Errorf("Invalid Auth Method: Type should be one of: %v", authmethod.Types()) + } + + // Instantiate a validator but do not cache it yet. This will validate the + // configuration. + if _, err := authmethod.NewValidator(method); err != nil { + return fmt.Errorf("Invalid Auth Method: %v", err) + } + + req := &structs.ACLAuthMethodBatchSetRequest{ + AuthMethods: structs.ACLAuthMethods{method}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply auth method upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, method.Name); err == nil && method != nil { + *reply = *method + } + + return nil +} + +func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, args.AuthMethodName) + if err != nil { + return err + } + + if method == nil { + return nil + } + + req := structs.ACLAuthMethodBatchDeleteRequest{ + AuthMethodNames: []string{args.AuthMethodName}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply auth method delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *structs.ACLAuthMethodListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, methods, err := state.ACLAuthMethodList(ws) + if err != nil { + return err + } + + var stubs structs.ACLAuthMethodListStubs + for _, method := range methods { + stubs = append(stubs, method.Stub()) + } + + reply.Index, reply.AuthMethods = index, stubs + return nil + }) +} + +func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token != "" { // This shouldn't happen. + return errors.New("do not provide a token when logging in") + } + + if done, err := a.srv.forward("ACL.Login", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "login"}, time.Now()) + + auth := args.Auth + + // 1. take args.Data.AuthMethod to get an AuthMethod Validator + idx, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, auth.AuthMethod) + if err != nil { + return err + } else if method == nil { + return acl.ErrNotFound + } + + validator, err := a.srv.loadAuthMethodValidator(idx, method) + if err != nil { + return err + } + + // 2. Send args.Data.BearerToken to method validator and get back a fields map + verifiedFields, err := validator.ValidateLogin(auth.BearerToken) + if err != nil { + return err + } + + // 3. send map through role bindings + serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedFields) + if err != nil { + return err + } + + if len(serviceIdentities) == 0 && len(roleLinks) == 0 { + return acl.ErrPermissionDenied + } + + description := "token created via login" + loginMeta, err := encodeLoginMeta(auth.Meta) + if err != nil { + return err + } + if loginMeta != "" { + description += ": " + loginMeta + } + + // 4. create token + createReq := structs.ACLTokenSetRequest{ + Datacenter: args.Datacenter, + ACLToken: structs.ACLToken{ + Description: description, + Local: true, + AuthMethod: auth.AuthMethod, + ServiceIdentities: serviceIdentities, + Roles: roleLinks, + }, + WriteRequest: args.WriteRequest, + } + + // 5. return token information like a TokenCreate would + return a.tokenSetInternal(&createReq, reply, true) +} + +func encodeLoginMeta(meta map[string]string) (string, error) { + if len(meta) == 0 { + return "", nil + } + + d, err := json.Marshal(meta) + if err != nil { + return "", err + } + return string(d), nil +} + +func (a *ACL) Logout(args *structs.ACLLogoutRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token == "" { + return acl.ErrNotFound + } + + if done, err := a.srv.forward("ACL.Logout", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "logout"}, time.Now()) + + _, token, err := a.srv.fsm.State().ACLTokenGetBySecret(nil, args.Token) + if err != nil { + return err + + } else if token == nil { + return acl.ErrNotFound + + } else if token.AuthMethod == "" { + // Can't "logout" of a token that wasn't a result of login. + return acl.ErrPermissionDenied + + } else if !a.srv.InACLDatacenter() && !token.Local { + // global token writes must be forwarded to the primary DC + args.Datacenter = a.srv.config.ACLDatacenter + return a.srv.forwardDC("ACL.Logout", a.srv.config.ACLDatacenter, args, reply) + } + + // No need to check expiration time because it's being deleted. + + req := &structs.ACLTokenBatchDeleteRequest{ + TokenIDs: []string{token.AccessorID}, + } + + resp, err := a.srv.raftApply(structs.ACLTokenDeleteRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply token delete request: %v", err) + } + + // Purge the identity from the cache to prevent using the previous definition of the identity + if token != nil { + a.srv.acls.cache.RemoveIdentity(token.SecretID) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} diff --git a/agent/consul/acl_endpoint_legacy.go b/agent/consul/acl_endpoint_legacy.go index 48867fdb3a..16379faa2e 100644 --- a/agent/consul/acl_endpoint_legacy.go +++ b/agent/consul/acl_endpoint_legacy.go @@ -93,8 +93,9 @@ func aclApplyInternal(srv *Server, args *structs.ACLRequest, reply *string) erro return fmt.Errorf("Invalid ACL Type") } + // No need to check expiration times as those did not exist in legacy tokens. _, existing, _ := srv.fsm.State().ACLTokenGetBySecret(nil, args.ACL.ID) - if existing != nil && len(existing.Policies) > 0 { + if existing != nil && existing.UsesNonLegacyFields() { return fmt.Errorf("Cannot use legacy endpoint to modify a non-legacy token") } @@ -210,8 +211,13 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, return err } - // converting an ACLToken to an ACL will return nil and an error + // Converting an ACLToken to an ACL will return nil and an error // (which we ignore) when it is unconvertible. + // + // This also means we won't have to check expiration times since + // any legacy tokens never had expiration times and no non-legacy + // tokens can be converted. + var acl *structs.ACL if token != nil { acl, _ = token.Convert() @@ -249,13 +255,18 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, false, true, "") + index, tokens, err := state.ACLTokenList(ws, false, true, "", "", "") if err != nil { return err } + now := time.Now() + var acls structs.ACLs for _, token := range tokens { + if token.IsExpired(now) { + continue + } if acl, err := token.Convert(); err == nil && acl != nil { acls = append(acls, acl) } diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index c88ba1563f..bfdecafea9 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -12,6 +12,8 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/lib" @@ -630,6 +632,8 @@ func TestACLEndpoint_TokenRead(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -638,14 +642,12 @@ func TestACLEndpoint_TokenRead(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - token, err := upsertTestToken(codec, "root", "dc1") - if err != nil { - t.Fatalf("err: %v", err) - } - acl := ACL{srv: s1} t.Run("exists and matches what we created", func(t *testing.T) { + token, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) + req := structs.ACLTokenGetRequest{ Datacenter: "dc1", TokenID: token.AccessorID, @@ -655,7 +657,7 @@ func TestACLEndpoint_TokenRead(t *testing.T) { resp := structs.ACLTokenResponse{} - err := acl.TokenRead(&req, &resp) + err = acl.TokenRead(&req, &resp) require.NoError(t, err) if !reflect.DeepEqual(resp.Token, token) { @@ -663,6 +665,44 @@ func TestACLEndpoint_TokenRead(t *testing.T) { } }) + t.Run("expired tokens are filtered", func(t *testing.T) { + // insert a token that will expire + token, err := upsertTestToken(codec, "root", "dc1", func(t *structs.ACLToken) { + t.ExpirationTTL = 20 * time.Millisecond + }) + require.NoError(t, err) + + t.Run("readable until expiration", func(t *testing.T) { + req := structs.ACLTokenGetRequest{ + Datacenter: "dc1", + TokenID: token.AccessorID, + TokenIDType: structs.ACLTokenAccessor, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenResponse{} + + require.NoError(t, acl.TokenRead(&req, &resp)) + require.Equal(t, token, resp.Token) + }) + + time.Sleep(50 * time.Millisecond) + + t.Run("not returned when expired", func(t *testing.T) { + req := structs.ACLTokenGetRequest{ + Datacenter: "dc1", + TokenID: token.AccessorID, + TokenIDType: structs.ACLTokenAccessor, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenResponse{} + + require.NoError(t, acl.TokenRead(&req, &resp)) + require.Nil(t, resp.Token) + }) + }) + t.Run("nil when token does not exist", func(t *testing.T) { fakeID, err := uuid.GenerateUUID() require.NoError(t, err) @@ -704,6 +744,8 @@ func TestACLEndpoint_TokenClone(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -712,28 +754,52 @@ func TestACLEndpoint_TokenClone(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - acl := ACL{srv: s1} + endpoint := ACL{srv: s1} - req := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, - WriteRequest: structs.WriteRequest{Token: "root"}, - } + t.Run("normal", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, + WriteRequest: structs.WriteRequest{Token: "root"}, + } - t2 := structs.ACLToken{} + t2 := structs.ACLToken{} - err = acl.TokenClone(&req, &t2) - require.NoError(t, err) + err = endpoint.TokenClone(&req, &t2) + require.NoError(t, err) - require.Equal(t, t1.Description, t2.Description) - require.Equal(t, t1.Policies, t2.Policies) - require.Equal(t, t1.Rules, t2.Rules) - require.Equal(t, t1.Local, t2.Local) - require.NotEqual(t, t1.AccessorID, t2.AccessorID) - require.NotEqual(t, t1.SecretID, t2.SecretID) + require.Equal(t, t1.Description, t2.Description) + require.Equal(t, t1.Policies, t2.Policies) + require.Equal(t, t1.Rules, t2.Rules) + require.Equal(t, t1.Local, t2.Local) + require.NotEqual(t, t1.AccessorID, t2.AccessorID) + require.NotEqual(t, t1.SecretID, t2.SecretID) + }) + + t.Run("can't clone expired token", func(t *testing.T) { + // insert a token that will expire + t1, err := upsertTestToken(codec, "root", "dc1", func(t *structs.ACLToken) { + t.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(30 * time.Millisecond) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{AccessorID: t1.AccessorID}, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + t2 := structs.ACLToken{} + + err = endpoint.TokenClone(&req, &t2) + require.Error(t, err) + require.Equal(t, acl.ErrNotFound, err) + }) } func TestACLEndpoint_TokenSet(t *testing.T) { @@ -743,6 +809,8 @@ func TestACLEndpoint_TokenSet(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -752,6 +820,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") acl := ACL{srv: s1} + var tokenID string t.Run("Create it", func(t *testing.T) { @@ -806,6 +875,625 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Equal(t, token.Description, "new-description") require.Equal(t, token.AccessorID, resp.AccessorID) }) + + t.Run("Create it using Policies linked by id and name", func(t *testing.T) { + policy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + policy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy1.ID, + }, + structs.ACLTokenPolicyLink{ + Name: policy2.Name, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Delete both policies to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy1.ID)) + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy2.ID)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + + require.Len(t, token.Policies, 0) + }) + + t.Run("Create it using Roles linked by id and name", func(t *testing.T) { + role1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + role2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: role1.ID, + }, + structs.ACLTokenRoleLink{ + Name: role2.Name, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Delete both roles to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestRole(codec, "root", "dc1", role1.ID)) + require.NoError(t, deleteTestRole(codec, "root", "dc1", role2.ID)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + + require.Len(t, token.Roles, 0) + }) + + t.Run("Create it with AuthMethod set outside of login", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + AuthMethod: "fakemethod", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "AuthMethod field is disallowed outside of Login") + }) + + t.Run("Update auth method linked token and try to change auth method", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method1.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + // create a token in one method + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method1.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + // try to update the token and change the method + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + SecretID: methodToken.SecretID, + AuthMethod: method2.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot change AuthMethod") + }) + + t.Run("Update auth method linked token and let the SecretID and AuthMethod be defaulted", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + // SecretID: methodToken.SecretID, + // AuthMethod: method.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + require.NoError(t, acl.TokenSet(&req, &resp)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.Len(t, token.Roles, 0) + require.Equal(t, "updated token", token.Description) + require.True(t, token.Local) + require.Equal(t, methodToken.SecretID, token.SecretID) + require.Equal(t, methodToken.AuthMethod, token.AuthMethod) + }) + + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: ""}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Service identity is missing the service name field") + }) + + t.Run("Create it with invalid service identity (too large)", func(t *testing.T) { + long := strings.Repeat("x", serviceIdentityNameMaxLength+1) + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: long}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NotNil(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {"-abc", false}, + {"abc-", false}, + {"a-bc", true}, + {"_abc", false}, + {"abc_", false}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", false}, + {"aBc", false}, + {"abC", false}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create it with valid service identity (by regex): " + test.name + } else { + testName = "Create it with invalid service identity (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: test.name}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.ElementsMatch(t, req.ACLToken.ServiceIdentities, token.ServiceIdentities) + } else { + require.NotNil(t, err) + } + }) + } + + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + svcid := token.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) + + t.Run("Create it with invalid service identity (datacenters set on local token)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: true, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "foo", Datacenters: []string{"dc2"}}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "cannot specify a list of datacenters on a local token") + }) + + for _, test := range []struct { + name string + offset time.Duration + errString string + errStringTTL string + }{ + {"before create time", -5 * time.Minute, "ExpirationTime cannot be before CreateTime", ""}, + {"too soon", 1 * time.Millisecond, "ExpirationTime cannot be less than", "ExpirationTime cannot be less than"}, + {"too distant", 25 * time.Hour, "ExpirationTime cannot be more than", "ExpirationTime cannot be more than"}, + } { + t.Run("Create it with an expiration time that is "+test.name, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: timePointer(time.Now().Add(test.offset)), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.errString != "" { + requireErrorContains(t, err, test.errString) + } else { + require.NotNil(t, err) + } + }) + + t.Run("Create it with an expiration TTL that is "+test.name, func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTTL: test.offset, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + if test.errString != "" { + requireErrorContains(t, err, test.errStringTTL) + } else { + require.NotNil(t, err) + } + }) + } + + t.Run("Create it with expiration time AND expiration TTL set (error)", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: timePointer(time.Now().Add(4 * time.Second)), + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Expiration TTL and Expiration Time cannot both be set") + }) + + t.Run("Create it with expiration time using TTLs", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + expectExpTime := resp.CreateTime.Add(4 * time.Second) + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expectExpTime, resp.ExpirationTime) + + tokenID = token.AccessorID + }) + + var expTime time.Time + t.Run("Create it with expiration time", func(t *testing.T) { + expTime = time.Now().Add(4 * time.Second) + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ExpirationTime: &expTime, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "foobar") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expTime, resp.ExpirationTime) + + tokenID = token.AccessorID + }) + + // do not insert another test at this point: these tests need to be serial + + t.Run("Update expiration time is not allowed", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description", + AccessorID: tokenID, + ExpirationTime: timePointer(expTime.Add(-1 * time.Second)), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot change expiration time") + }) + + // do not insert another test at this point: these tests need to be serial + + t.Run("Update anything except expiration time is ok - omit expiration time and let it default", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description-1", + AccessorID: tokenID, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "new-description-1") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expTime, resp.ExpirationTime) + }) + + t.Run("Update anything except expiration time is ok", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description-2", + AccessorID: tokenID, + ExpirationTime: &expTime, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "new-description-2") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expTime, resp.ExpirationTime) + }) + + t.Run("cannot update a token that is past its expiration time", func(t *testing.T) { + // create a token that will expire + expiringToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(20 * time.Millisecond) // now 'expiringToken' is expired + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description", + AccessorID: expiringToken.AccessorID, + ExpirationTTL: 4 * time.Second, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot find token") + }) } func TestACLEndpoint_TokenSet_anon(t *testing.T) { @@ -857,6 +1545,8 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -869,6 +1559,8 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second // token replication is required to test deleting non-local tokens in secondary dc c.ACLTokenReplication = true }) @@ -885,12 +1577,74 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { // Try to join joinWAN(t, s2, s1) - existingToken, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} acl2 := ACL{srv: s2} + existingToken, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) + + t.Run("deletes a token that has an expiration time in the future", func(t *testing.T) { + // create a token that will expire + testToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 4 * time.Second + }) + require.NoError(t, err) + + // Make sure the token is listable + tokenResp, err := retrieveTestToken(codec, "root", "dc1", testToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, tokenResp.Token) + + // Now try to delete it (this should work). + req := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: testToken.AccessorID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.TokenDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the token is gone + tokenResp, err = retrieveTestToken(codec, "root", "dc1", testToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + }) + + t.Run("deletes a token that is past its expiration time", func(t *testing.T) { + // create a token that will expire + expiringToken, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + time.Sleep(20 * time.Millisecond) // now 'expiringToken' is expired + + // Make sure the token is not listable (filtered due to expiry) + tokenResp, err := retrieveTestToken(codec, "root", "dc1", expiringToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + + // Now try to delete it (this should work). + req := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: expiringToken.AccessorID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.TokenDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the token is still gone (this time it's actually gone) + tokenResp, err = retrieveTestToken(codec, "root", "dc1", expiringToken.AccessorID) + require.NoError(t, err) + require.Nil(t, tokenResp.Token) + }) + t.Run("deletes a token", func(t *testing.T) { req := structs.ACLTokenDeleteRequest{ Datacenter: "dc1", @@ -919,7 +1673,7 @@ func TestACLEndpoint_TokenDelete(t *testing.T) { var out structs.ACLTokenResponse - err := msgpackrpc.CallWithCodec(codec, "ACL.TokenRead", &readReq, &out) + err := acl.TokenRead(&readReq, &out) require.NoError(t, err) @@ -1019,6 +1773,8 @@ func TestACLEndpoint_TokenList(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -1027,31 +1783,64 @@ func TestACLEndpoint_TokenList(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - - t2, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} - req := structs.ACLTokenListRequest{ - Datacenter: "dc1", - QueryOptions: structs.QueryOptions{Token: "root"}, - } - - resp := structs.ACLTokenListResponse{} - - err = acl.TokenList(&req, &resp) + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - tokens := []string{t1.AccessorID, t2.AccessorID} - var retrievedTokens []string + t2, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.Subset(t, retrievedTokens, tokens) + t3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 11 * time.Millisecond + }) + require.NoError(t, err) + + masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") + require.NoError(t, err) + + t.Run("normal", func(t *testing.T) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenListResponse{} + + err = acl.TokenList(&req, &resp) + require.NoError(t, err) + + tokens := []string{ + masterTokenAccessorID, + structs.ACLTokenAnonymousID, + t1.AccessorID, + t2.AccessorID, + t3.AccessorID, + } + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) + }) + + time.Sleep(20 * time.Millisecond) // now 't3' is expired + + t.Run("filter expired", func(t *testing.T) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenListResponse{} + + err = acl.TokenList(&req, &resp) + require.NoError(t, err) + + tokens := []string{ + masterTokenAccessorID, + structs.ACLTokenAnonymousID, + t1.AccessorID, + t2.AccessorID, + } + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) + }) } func TestACLEndpoint_TokenBatchRead(t *testing.T) { @@ -1061,6 +1850,8 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second }) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -1069,32 +1860,52 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - t1, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - - t2, err := upsertTestToken(codec, "root", "dc1") - require.NoError(t, err) - acl := ACL{srv: s1} - tokens := []string{t1.AccessorID, t2.AccessorID} - req := structs.ACLTokenBatchGetRequest{ - Datacenter: "dc1", - AccessorIDs: tokens, - QueryOptions: structs.QueryOptions{Token: "root"}, - } - - resp := structs.ACLTokenBatchResponse{} - - err = acl.TokenBatchRead(&req, &resp) + t1, err := upsertTestToken(codec, "root", "dc1", nil) require.NoError(t, err) - var retrievedTokens []string + t2, err := upsertTestToken(codec, "root", "dc1", nil) + require.NoError(t, err) - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + t3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 4 * time.Second + }) + require.NoError(t, err) + + t.Run("normal", func(t *testing.T) { + tokens := []string{t1.AccessorID, t2.AccessorID, t3.AccessorID} + + req := structs.ACLTokenBatchGetRequest{ + Datacenter: "dc1", + AccessorIDs: tokens, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenBatchResponse{} + + err = acl.TokenBatchRead(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) + }) + + time.Sleep(20 * time.Millisecond) // now 't3' is expired + + t.Run("returns expired tokens", func(t *testing.T) { + tokens := []string{t1.AccessorID, t2.AccessorID, t3.AccessorID} + + req := structs.ACLTokenBatchGetRequest{ + Datacenter: "dc1", + AccessorIDs: tokens, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLTokenBatchResponse{} + + err = acl.TokenBatchRead(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) + }) } func TestACLEndpoint_PolicyRead(t *testing.T) { @@ -1170,13 +1981,7 @@ func TestACLEndpoint_PolicyBatchRead(t *testing.T) { err = acl.PolicyBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.EqualValues(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), []string{p1.ID, p2.ID}) } func TestACLEndpoint_PolicySet(t *testing.T) { @@ -1197,8 +2002,7 @@ func TestACLEndpoint_PolicySet(t *testing.T) { acl := ACL{srv: s1} var policyID string - // Create it - { + t.Run("Create it", func(t *testing.T) { req := structs.ACLPolicySetRequest{ Datacenter: "dc1", Policy: structs.ACLPolicy{ @@ -1225,10 +2029,9 @@ func TestACLEndpoint_PolicySet(t *testing.T) { require.Equal(t, policy.Rules, "service \"\" { policy = \"read\" }") policyID = policy.ID - } + }) - // Update it - { + t.Run("Update it", func(t *testing.T) { req := structs.ACLPolicySetRequest{ Datacenter: "dc1", Policy: structs.ACLPolicy{ @@ -1254,7 +2057,7 @@ func TestACLEndpoint_PolicySet(t *testing.T) { require.Equal(t, policy.Description, "bat") require.Equal(t, policy.Name, "bar") require.Equal(t, policy.Rules, "service \"\" { policy = \"write\" }") - } + }) } func TestACLEndpoint_PolicySet_globalManagement(t *testing.T) { @@ -1419,13 +2222,12 @@ func TestACLEndpoint_PolicyList(t *testing.T) { err = acl.PolicyList(&req, &resp) require.NoError(t, err) - policies := []string{p1.ID, p2.ID} - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) + policies := []string{ + structs.ACLPolicyGlobalManagementID, + p1.ID, + p2.ID, } - require.Subset(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) } func TestACLEndpoint_PolicyResolve(t *testing.T) { @@ -1481,17 +2283,2538 @@ func TestACLEndpoint_PolicyResolve(t *testing.T) { } err = acl.PolicyResolve(&req, &resp) require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) +} - var retrievedPolicies []string +func TestACLEndpoint_RoleRead(t *testing.T) { + t.Parallel() + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + role, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleGetRequest{ + Datacenter: "dc1", + RoleID: role.ID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleResponse{} + + err = acl.RoleRead(&req, &resp) + require.NoError(t, err) + require.Equal(t, role, resp.Role) +} + +func TestACLEndpoint_RoleBatchRead(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + roles := []string{r1.ID, r2.ID} + + req := structs.ACLRoleBatchGetRequest{ + Datacenter: "dc1", + RoleIDs: roles, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleBatchResponse{} + + err = acl.RoleBatchRead(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), roles) +} + +func TestACLEndpoint_RoleSet(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + var roleID string + + testPolicy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + testPolicy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + t.Run("Create it", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: "baz", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy1.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "foobar") + require.Equal(t, role.Name, "baz") + require.Len(t, role.Policies, 1) + require.Equal(t, testPolicy1.ID, role.Policies[0].ID) + + roleID = role.ID + }) + + t.Run("Update it", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + ID: roleID, + Description: "bat", + Name: "bar", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy2.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "bat") + require.Equal(t, role.Name, "bar") + require.Len(t, role.Policies, 1) + require.Equal(t, testPolicy2.ID, role.Policies[0].ID) + }) + + t.Run("Create it using Policies linked by id and name", func(t *testing.T) { + policy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + policy2, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: "baz", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: policy1.ID, + }, + structs.ACLRolePolicyLink{ + Name: policy2.Name, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err = acl.RoleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Delete both policies to ensure that we skip resolving ID->Name + // in the returned data. + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy1.ID)) + require.NoError(t, deleteTestPolicy(codec, "root", "dc1", policy2.ID)) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + + require.NotNil(t, role.ID) + require.Equal(t, role.Description, "foobar") + require.Equal(t, role.Name, "baz") + + require.Len(t, role.Policies, 0) + }) + + roleNameGen := func(t *testing.T) string { + t.Helper() + name, err := uuid.GenerateUUID() + require.NoError(t, err) + return name + } + + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: ""}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + requireErrorContains(t, err, "Service identity is missing the service name field") + }) + + t.Run("Create it with invalid service identity (too large)", func(t *testing.T) { + long := strings.Repeat("x", serviceIdentityNameMaxLength+1) + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: long}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NotNil(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {"-abc", false}, + {"abc-", false}, + {"a-bc", true}, + {"_abc", false}, + {"abc_", false}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", false}, + {"aBc", false}, + {"abC", false}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create it with valid service identity (by regex): " + test.name + } else { + testName = "Create it with invalid service identity (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: test.name}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + // Get the token directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.ElementsMatch(t, req.Role.ServiceIdentities, role.ServiceIdentities) + } else { + require.NotNil(t, err) + } + }) + } + + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + svcid := role.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) +} + +func TestACLEndpoint_RoleSet_names(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testPolicy1, err := upsertTestPolicy(codec, "root", "dc1") + require.NoError(t, err) + + for _, test := range []struct { + name string + ok bool + }{ + {"", false}, + {"-bad", true}, + {"bad-", true}, + {"bad?bad", false}, + {strings.Repeat("x", 257), false}, + {strings.Repeat("x", 256), true}, + {"-abc", true}, + {"abc-", true}, + {"a-bc", true}, + {"_abc", true}, + {"abc_", true}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", true}, + {"aBc", true}, + {"abC", true}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "create with valid name: " + test.name + } else { + testName = "create with invalid name: " + test.name + } + + t.Run(testName, func(t *testing.T) { + // cleanup from a prior insertion that may have succeeded + require.NoError(t, deleteTestRoleByName(codec, "root", "dc1", test.name)) + + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: test.name, + Description: "foobar", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicy1.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + if test.ok { + require.NoError(t, err) + + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Equal(t, test.name, role.Name) + } else { + require.Error(t, err) + } + }) + } +} + +func TestACLEndpoint_RoleDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + existingRole, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleDeleteRequest{ + Datacenter: "dc1", + RoleID: existingRole.ID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var resp string + + err = acl.RoleDelete(&req, &resp) + require.NoError(t, err) + + // Make sure the role is gone + roleResp, err := retrieveTestRole(codec, "root", "dc1", existingRole.ID) + require.Nil(t, roleResp.Role) +} + +func TestACLEndpoint_RoleList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLRoleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLRoleListResponse{} + + err = acl.RoleList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) +} + +func TestACLEndpoint_RoleResolve(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + t.Run("Normal", func(t *testing.T) { + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + // Assign the roles to a token + tokenUpsertReq := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: r1.ID, + }, + structs.ACLTokenRoleLink{ + ID: r2.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + token := structs.ACLToken{} + err = acl.TokenSet(&tokenUpsertReq, &token) + require.NoError(t, err) + require.NotEmpty(t, token.SecretID) + + resp := structs.ACLRoleBatchResponse{} + req := structs.ACLRoleBatchGetRequest{ + Datacenter: "dc1", + RoleIDs: []string{r1.ID, r2.ID}, + QueryOptions: structs.QueryOptions{Token: token.SecretID}, + } + err = acl.RoleResolve(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) + }) +} + +func TestACLEndpoint_AuthMethodSet(t *testing.T) { + t.Parallel() + + tempDir, err := ioutil.TempDir("", "consul") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + newAuthMethod := func(name string) structs.ACLAuthMethod { + return structs.ACLAuthMethod{ + Name: name, + Description: "test", + Type: "testing", + } + } + + t.Run("Create", func(t *testing.T) { + reqMethod := newAuthMethod("test") + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update fails; not allowed to change types", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Type = "invalid" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Update - allow type to default", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 1" + reqMethod.Type = "" // unset + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 1") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update - specify type", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 2" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 2") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Create with no name", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(""), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Create with invalid type", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "invalid", + Description: "invalid test", + Type: "invalid", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {strings.Repeat("x", 129), false}, + {strings.Repeat("x", 128), true}, + {"-abc", true}, + {"abc-", true}, + {"a-bc", true}, + {"_abc", true}, + {"abc_", true}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", true}, + {"aBc", true}, + {"abC", true}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create with valid name (by regex): " + test.name + } else { + testName = "Create with invalid name (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(test.name), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + + if test.ok { + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, test.name) + require.Equal(t, method.Type, "testing") + } else { + require.Error(t, err) + } + }) + } +} + +func TestACLEndpoint_AuthMethodDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + existingMethod, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: existingMethod.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", existingMethod.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: "missing", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestACLEndpoint_AuthMethodDelete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID1) + testauth.InstallSessionToken(testSessionID1, "fake-token1", "default", "abc", "abc123") + + testSessionID2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID2) + testauth.InstallSessionToken(testSessionID2, "fake-token2", "default", "abc", "abc123") + + createToken := func(methodName, bearerToken string) *structs.ACLToken { + acl := ACL{srv: s1} + + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: methodName, + BearerToken: bearerToken, + }, + Datacenter: "dc1", + }, &resp)) + + return &resp + } + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID1) + require.NoError(t, err) + i1_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i1_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i1_t1 := createToken(method1.Name, "fake-token1") + i1_t2 := createToken(method1.Name, "fake-token1") + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID2) + require.NoError(t, err) + i2_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i2_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i2_t1 := createToken(method2.Name, "fake-token2") + i2_t2 := createToken(method2.Name, "fake-token2") + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: method1.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone. + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", method1.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + + // Make sure the rules and tokens are gone. + for _, id := range []string{i1_r1.ID, i1_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + } + for _, id := range []string{i1_t1.AccessorID, i1_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, tokResp.Token) + } + + // Make sure the rules and tokens for the untouched auth method are still there. + for _, id := range []string{i2_r1.ID, i2_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, ruleResp.BindingRule) + } + for _, id := range []string{i2_t1.AccessorID, i2_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, tokResp.Token) + } +} + +func TestACLEndpoint_AuthMethodList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + i1, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + i2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLAuthMethodListResponse{} + + err = acl.AuthMethodList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.AuthMethods), []string{i1.Name, i2.Name}) +} + +func TestACLEndpoint_BindingRuleSet(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + var ruleID string + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + otherTestAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + newRule := func() structs.ACLBindingRule { + return structs.ACLBindingRule{ + Description: "foobar", + AuthMethod: testAuthMethod.Name, + Selector: "serviceaccount.name==abc", + BindType: structs.BindingRuleBindTypeService, + BindName: "abc", + } + } + + requireSetErrors := func(t *testing.T, reqRule structs.ACLBindingRule) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.Error(t, err) + } + + requireOK := func(t *testing.T, reqRule structs.ACLBindingRule) *structs.ACLBindingRule { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.ID) + return &resp + } + + t.Run("Create it", func(t *testing.T) { + reqRule := newRule() + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.name==abc", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "abc", rule.BindName) + + ruleID = rule.ID + }) + + t.Run("Update fails; cannot change method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.AuthMethod = otherTestAuthMethod.Name + requireSetErrors(t, reqRule) + }) + + t.Run("Update it - omit method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 1" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + reqRule.AuthMethod = "" // clear + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 1") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Update it - specify method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 2" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 2") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Create fails; empty method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; unknown method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "unknown" + requireSetErrors(t, reqRule) + }) + + t.Run("Create with no explicit selector", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "" + + rule := requireOK(t, reqRule) + require.Empty(t, rule.Selector, 0) + }) + + t.Run("Create fails; match selector with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name==a and serviceaccount.bizarroname==b" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; match selector invalid", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind name", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "invalid" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; bind name with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.bizarroname}" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name no template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "-abc:" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name with template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name" + requireSetErrors(t, reqRule) + }) + t.Run("Create fails; invalid bind name after template computed", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name}:blah-" + requireSetErrors(t, reqRule) + }) +} + +func TestACLEndpoint_BindingRuleDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + existingRule, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: existingRule.ID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the rule is gone + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", existingRule.ID) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: fakeID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +func TestACLEndpoint_BindingRuleList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + r1, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + r2, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLBindingRuleListResponse{} + + err = acl.BindingRuleList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.BindingRules), []string{r1.ID, r2.ID}) +} + +func TestACLEndpoint_SecureIntroEndpoints_LocalTokensDisabled(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // disable local tokens + c.ACLTokenReplication = false + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl2 := ACL{srv: s2} + var ignored bool + + errString := errAuthMethodsRequireTokenReplication.Error() + + t.Run("AuthMethodRead", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodRead(&structs.ACLAuthMethodGetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodResponse{}), + errString, + ) + }) + t.Run("AuthMethodSet", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodSet(&structs.ACLAuthMethodSetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethod{}), + errString, + ) + }) + t.Run("AuthMethodDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodDelete(&structs.ACLAuthMethodDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("AuthMethodList", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodList(&structs.ACLAuthMethodListRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodListResponse{}), + errString, + ) + }) + + t.Run("BindingRuleRead", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleRead(&structs.ACLBindingRuleGetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleResponse{}), + errString, + ) + }) + t.Run("BindingRuleSet", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleSet(&structs.ACLBindingRuleSetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRule{}), + errString, + ) + }) + t.Run("BindingRuleDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleDelete(&structs.ACLBindingRuleDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("BindingRuleList", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleList(&structs.ACLBindingRuleListRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleListResponse{}), + errString, + ) + }) + + t.Run("Login", func(t *testing.T) { + requireErrorContains(t, + acl2.Login(&structs.ACLLoginRequest{Datacenter: "dc2"}, + &structs.ACLToken{}), + errString, + ) + }) + t.Run("Logout", func(t *testing.T) { + requireErrorContains(t, + acl2.Logout(&structs.ACLLogoutRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) +} + +func TestACLEndpoint_SecureIntroEndpoints_OnlyCreateLocalData(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec1 := rpcClient(t, s1) + defer codec1.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // enable token replication so secure intro works + c.ACLTokenReplication = true + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl := ACL{srv: s1} + acl2 := ACL{srv: s2} + + // + // this order is specific so that we can do it in one pass + // + + testSessionID_1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_1) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + + testauth.InstallSessionToken( + testSessionID_1, + "fake-web1-token", + "default", "web1", "abc123", + ) + testauth.InstallSessionToken( + testSessionID_2, + "fake-web2-token", + "default", "web2", "def456", + ) + + t.Run("create auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test original", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test original", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("update auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test updated", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test updated", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("read auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodGetRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodResponse{} + require.NoError(t, acl2.AuthMethodRead(&req, &resp)) + require.NotNil(t, resp.AuthMethod) + require.Equal(t, "test updated", resp.AuthMethod.Description) + + // absent in dc1 + req = structs.ACLAuthMethodGetRequest{ + Datacenter: "dc1", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodResponse{} + require.NoError(t, acl.AuthMethodRead(&req, &resp)) + require.Nil(t, resp.AuthMethod) + }) + + t.Run("list auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodListResponse{} + require.NoError(t, acl2.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 1) + + // absent in dc1 + req = structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodListResponse{} + require.NoError(t, acl.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 0) + }) + + var ruleID string + t.Run("create binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + Description: "test original", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test original", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("update binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + ID: ruleID, + Description: "test updated", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test updated", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("read binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleGetRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleResponse{} + require.NoError(t, acl2.BindingRuleRead(&req, &resp)) + require.NotNil(t, resp.BindingRule) + require.Equal(t, "test updated", resp.BindingRule.Description) + + // absent in dc1 + req = structs.ACLBindingRuleGetRequest{ + Datacenter: "dc1", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleResponse{} + require.NoError(t, acl.BindingRuleRead(&req, &resp)) + require.Nil(t, resp.BindingRule) + }) + + t.Run("list binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleListResponse{} + require.NoError(t, acl2.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 1) + + // absent in dc1 + req = structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleListResponse{} + require.NoError(t, acl.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 0) + }) + + var remoteToken *structs.ACLToken + t.Run("login in remote", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc2", + Auth: &structs.ACLLoginParams{ + AuthMethod: "testmethod", + BearerToken: "fake-web2-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + remoteToken = &resp + + // present in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web2", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // We delay until now to setup an auth method and binding rule in the + // primary so our earlier listing tests were sane. We need to be able to + // use auth methods in both datacenters in order to verify Logout is + // properly scoped. + t.Run("initialize primary so we can test logout", func(t *testing.T) { + reqAM := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "primarymethod", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_1, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + respAM := structs.ACLAuthMethod{} + require.NoError(t, acl.AuthMethodSet(&reqAM, &respAM)) + + reqBR := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: "primarymethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + respBR := structs.ACLBindingRule{} + require.NoError(t, acl.BindingRuleSet(&reqBR, &respBR)) + }) + + var primaryToken *structs.ACLToken + t.Run("login in primary", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc1", + Auth: &structs.ACLLoginParams{ + AuthMethod: "primarymethod", + BearerToken: "fake-web1-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + primaryToken = &resp + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of remote token in remote dc", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: remoteToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of primary token in remote dc should not work", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: primaryToken.SecretID}, + } + + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // Don't trigger the auth method delete cascade so we know the individual + // endpoints follow the rules. + + t.Run("delete binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.BindingRuleDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("delete auth method", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.AuthMethodDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) +} + +func TestACLEndpoint_Login(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "fake-web", // no rules + "default", "web", "abc123", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-monolith", // 1 rule, must exist + "default", "monolith", "ghi789", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + ruleDB, err := upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==db", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==monolith", + structs.BindingRuleBindTypeRole, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("do not provide a token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + req.Token = "nope" + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "do not provide a token") + }) + + t.Run("unknown method", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name + "-notexist", + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) + + t.Run("invalid method token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid method token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + t.Run("valid method token 1 role binding must exist and does not exist", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + // create the role so that the bindtype=existing login works + var monolithRoleID string + { + arg := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: "method-monolith", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLRole + require.NoError(t, acl.RoleSet(&arg, &out)) + + monolithRoleID = out.ID + } + s1.purgeAuthMethodValidators() + + t.Run("valid bearer token 1 role binding must exist and now exists", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.ServiceIdentities, 0) + require.Len(t, resp.Roles, 1) + role := resp.Roles[0] + require.Equal(t, monolithRoleID, role.ID) + require.Equal(t, "method-monolith", role.Name) + }) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: ruleDB.AuthMethod, + BindType: structs.BindingRuleBindTypeService, + BindName: ruleDB.BindName, + Selector: "", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLBindingRule + require.NoError(t, acl.BindingRuleSet(&req, &out)) + } + + t.Run("valid bearer token 1 binding (no selectors this time)", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + { + // Update the method to force the cache to invalidate for the next + // subtest. + updated := *method + updated.Description = "updated for the test" + updated.Config = map[string]interface{}{ + "SessionID": testSessionID_2, + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: updated, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored structs.ACLAuthMethod + require.NoError(t, acl.AuthMethodSet(&req, &ignored)) + } + + t.Run("updating the method invalidates the cache", func(t *testing.T) { + // We'll try to login with the 'fake-db' cred which DOES exist in the + // old fake validator, but no longer exists in the new fake validator. + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) +} + +func TestACLEndpoint_Login_k8s(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method, err := upsertTestKubernetesAuthMethod( + codec, "root", "dc1", + testSrv.CACert(), + testSrv.Addr(), + goodJWT_A, + ) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid bearer token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default", + structs.BindingRuleBindTypeService, + "${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "demo", svcid.ServiceName) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token 1 service binding - with annotation", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "alternate-name", svcid.ServiceName) + }) +} + +func TestACLEndpoint_Logout(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("you must provide a token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + // WriteRequest: structs.WriteRequest{Token: "root"}, + } + req.Token = "" + var ignored bool + + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from deleted token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "not-found"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from non-auth method-linked token should fail", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "Permission denied") + }) + + t.Run("login then logout", func(t *testing.T) { + // Create a totally legit Login token. + loginReq := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + }, + Datacenter: "dc1", + } + loginToken := structs.ACLToken{} + + require.NoError(t, acl.Login(&loginReq, &loginToken)) + require.NotEmpty(t, loginToken.SecretID) + + // Now turn around and nuke it. + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: loginToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + }) +} + +func gatherIDs(t *testing.T, v interface{}) []string { + t.Helper() + + var out []string + switch x := v.(type) { + case []*structs.ACLRole: + for _, r := range x { + out = append(out, r.ID) + } + case structs.ACLRoles: + for _, r := range x { + out = append(out, r.ID) + } + case []*structs.ACLPolicy: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLPolicyListStubs: + for _, p := range x { + out = append(out, p.ID) + } + case []*structs.ACLToken: + for _, p := range x { + out = append(out, p.AccessorID) + } + case structs.ACLTokenListStubs: + for _, p := range x { + out = append(out, p.AccessorID) + } + case []*structs.ACLAuthMethod: + for _, p := range x { + out = append(out, p.Name) + } + case structs.ACLAuthMethodListStubs: + for _, p := range x { + out = append(out, p.Name) + } + case []*structs.ACLBindingRule: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLBindingRules: + for _, p := range x { + out = append(out, p.ID) + } + default: + t.Fatalf("unknown type: %T", x) + } + return out +} + +func TestValidateBindingRuleBindName(t *testing.T) { + t.Parallel() + + type testcase struct { + name string + bindType string + bindName string + fields string + valid bool // valid HIL, invalid contents + err bool // invalid HIL + } + + for _, test := range []testcase{ + {"no bind type", + "", "", "", false, false}, + {"bad bind type", + "invalid", "blah", "", false, true}, + // valid HIL, invalid name + {"empty", + "both", "", "", false, false}, + {"just end", + "both", "}", "", false, false}, + {"var without start", + "both", " item }", "item", false, false}, + {"two vars missing second start", + "both", "before-${ item }after--more }", "item,more", false, false}, + // names for the two types are validated differently + {"@ is disallowed", + "both", "bad@name", "", false, false}, + {"leading dash", + "role", "-name", "", true, false}, + {"leading dash", + "service", "-name", "", false, false}, + {"trailing dash", + "role", "name-", "", true, false}, + {"trailing dash", + "service", "name-", "", false, false}, + {"inner dash", + "both", "name-end", "", true, false}, + {"upper case", + "role", "NAME", "", true, false}, + {"upper case", + "service", "NAME", "", false, false}, + // valid HIL, valid name + {"no vars", + "both", "nothing", "", true, false}, + {"just var", + "both", "${item}", "item", true, false}, + {"var in middle", + "both", "before-${item}after", "item", true, false}, + {"two vars", + "both", "before-${item}after-${more}", "item,more", true, false}, + // bad + {"no bind name", + "both", "", "", false, false}, + {"just start", + "both", "${", "", false, true}, + {"backwards", + "both", "}${", "", false, true}, + {"no varname", + "both", "${}", "", false, true}, + {"missing map key", + "both", "${item}", "", false, true}, + {"var without end", + "both", "${ item ", "item", false, true}, + {"two vars missing first end", + "both", "before-${ item after-${ more }", "item,more", false, true}, + } { + var cases []testcase + if test.bindType == "both" { + test1 := test + test1.bindType = "role" + test2 := test + test2.bindType = "service" + cases = []testcase{test1, test2} + } else { + cases = []testcase{test} + } + + for _, test := range cases { + test := test + t.Run(test.bindType+"--"+test.name, func(t *testing.T) { + t.Parallel() + valid, err := validateBindingRuleBindName( + test.bindType, + test.bindName, + strings.Split(test.fields, ","), + ) + if test.err { + require.NotNil(t, err) + require.False(t, valid) + } else { + require.NoError(t, err) + require.Equal(t, test.valid, valid) + } + }) + } } - require.EqualValues(t, retrievedPolicies, policies) } // upsertTestToken creates a token for testing purposes -func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLToken, error) { +func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter string, + tokenModificationFn func(token *structs.ACLToken)) (*structs.ACLToken, error) { arg := structs.ACLTokenSetRequest{ Datacenter: datacenter, ACLToken: structs.ACLToken{ @@ -1502,6 +4825,10 @@ func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter strin WriteRequest: structs.WriteRequest{Token: masterToken}, } + if tokenModificationFn != nil { + tokenModificationFn(&arg.ACLToken) + } + var out structs.ACLToken err := msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &out) @@ -1517,6 +4844,29 @@ func upsertTestToken(codec rpc.ClientCodec, masterToken string, datacenter strin return &out, nil } +func retrieveTestTokenAccessorForSecret(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (string, error) { + arg := structs.ACLTokenGetRequest{ + TokenID: "root", + TokenIDType: structs.ACLTokenSecret, + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + var out structs.ACLTokenResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.TokenRead", &arg, &out) + + if err != nil { + return "", err + } + + if out.Token == nil { + return "", nil + } + + return out.Token.AccessorID, nil +} + // retrieveTestToken returns a policy for testing purposes func retrieveTestToken(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (*structs.ACLTokenResponse, error) { arg := structs.ACLTokenGetRequest{ @@ -1537,6 +4887,18 @@ func retrieveTestToken(codec rpc.ClientCodec, masterToken string, datacenter str return &out, nil } +func deleteTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter string, policyID string) error { + arg := structs.ACLPolicyDeleteRequest{ + Datacenter: datacenter, + PolicyID: policyID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.PolicyDelete", &arg, &ignored) + return err +} + // upsertTestPolicy creates a policy for testing purposes func upsertTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLPolicy, error) { // Make sure test policies can't collide @@ -1586,3 +4948,292 @@ func retrieveTestPolicy(codec rpc.ClientCodec, masterToken string, datacenter st return &out, nil } + +func deleteTestRole(codec rpc.ClientCodec, masterToken string, datacenter string, roleID string) error { + arg := structs.ACLRoleDeleteRequest{ + Datacenter: datacenter, + RoleID: roleID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleDelete", &arg, &ignored) + return err +} + +func deleteTestRoleByName(codec rpc.ClientCodec, masterToken string, datacenter string, roleName string) error { + resp, err := retrieveTestRoleByName(codec, masterToken, datacenter, roleName) + if err != nil { + return err + } + if resp.Role == nil { + return nil + } + + return deleteTestRole(codec, masterToken, datacenter, resp.Role.ID) +} + +// upsertTestRole creates a role for testing purposes +func upsertTestRole(codec rpc.ClientCodec, masterToken string, datacenter string) (*structs.ACLRole, error) { + // Make sure test roles can't collide + roleUnq, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + policyID, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + arg := structs.ACLRoleSetRequest{ + Datacenter: datacenter, + Role: structs.ACLRole{ + Name: fmt.Sprintf("test-role-%s", roleUnq), + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: policyID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLRole + + err = msgpackrpc.CallWithCodec(codec, "ACL.RoleSet", &arg, &out) + + if err != nil { + return nil, err + } + + if out.ID == "" { + return nil, fmt.Errorf("ID is nil: %v", out) + } + + return &out, nil +} + +func retrieveTestRole(codec rpc.ClientCodec, masterToken string, datacenter string, id string) (*structs.ACLRoleResponse, error) { + arg := structs.ACLRoleGetRequest{ + Datacenter: datacenter, + RoleID: id, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLRoleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestRoleByName(codec rpc.ClientCodec, masterToken string, datacenter string, name string) (*structs.ACLRoleResponse, error) { + arg := structs.ACLRoleGetRequest{ + Datacenter: datacenter, + RoleName: name, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLRoleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.RoleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func deleteTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, methodName string) error { + arg := structs.ACLAuthMethodDeleteRequest{ + Datacenter: datacenter, + AuthMethodName: methodName, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodDelete", &arg, &ignored) + return err +} +func upsertTestAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + sessionID string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "testing", + Config: map[string]interface{}{ + "SessionID": sessionID, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func upsertTestKubernetesAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + caCert, kubeHost, kubeJWT string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + if kubeHost == "" { + kubeHost = "https://abc:8443" + } + if kubeJWT == "" { + kubeJWT = goodJWT_A + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": kubeHost, + "CACert": caCert, + "ServiceAccountJWT": kubeJWT, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, name string) (*structs.ACLAuthMethodResponse, error) { + arg := structs.ACLAuthMethodGetRequest{ + Datacenter: datacenter, + AuthMethodName: name, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLAuthMethodResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func deleteTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) error { + arg := structs.ACLBindingRuleDeleteRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleDelete", &arg, &ignored) + return err +} + +func upsertTestBindingRule( + codec rpc.ClientCodec, + masterToken string, + datacenter string, + methodName string, + selector string, + bindType string, + bindName string, +) (*structs.ACLBindingRule, error) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: datacenter, + BindingRule: structs.ACLBindingRule{ + AuthMethod: methodName, + BindType: bindType, + BindName: bindName, + Selector: selector, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLBindingRule + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) (*structs.ACLBindingRuleResponse, error) { + arg := structs.ACLBindingRuleGetRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLBindingRuleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func requireTimeEquals(t *testing.T, expect, got *time.Time) { + t.Helper() + if expect == nil && got == nil { + return + } else if expect == nil && got != nil { + t.Fatalf("expected=NIL != got=%q", *got) + } else if expect != nil && got == nil { + t.Fatalf("expected=%q != got=NIL", *expect) + } else if !expect.Equal(*got) { + t.Fatalf("expected=%q != got=%q", *expect, *got) + } +} + +func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) { + t.Helper() + if err == nil { + t.Fatal("An error is expected but got nil.") + } + if !strings.Contains(err.Error(), expectedErrorMessage) { + t.Fatalf("unexpected error: %v", err) + } +} + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/acl_replication.go b/agent/consul/acl_replication.go index d691895b67..4cec1d81a3 100644 --- a/agent/consul/acl_replication.go +++ b/agent/consul/acl_replication.go @@ -3,10 +3,11 @@ package consul import ( "bytes" "context" + "errors" "fmt" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" ) @@ -15,123 +16,108 @@ const ( aclReplicationMaxRetryBackoff = 64 ) -func diffACLPolicies(local structs.ACLPolicies, remote structs.ACLPolicyListStubs, lastRemoteIndex uint64) ([]string, []string) { - local.Sort() - remote.Sort() +// aclTypeReplicator allows the machinery of acl replication to be shared between +// types with minimal code duplication (barring generics magically popping into +// existence). +// +// Concrete implementations of this interface should internally contain a +// pointer to the server so that data lookups can occur, and they should +// maintain the smallest quantity of type-specific state they can. +// +// Implementations of this interface are short-lived and recreated on every +// iteration. +type aclTypeReplicator interface { + // Type is variant of replication in use. Used for updating the replication + // status tracker. + Type() structs.ACLReplicationType - var deletions []string - var updates []string - var localIdx int - var remoteIdx int - for localIdx, remoteIdx = 0, 0; localIdx < len(local) && remoteIdx < len(remote); { - if local[localIdx].ID == remote[remoteIdx].ID { - // policy is in both the local and remote state - need to check raft indices and the Hash - if remote[remoteIdx].ModifyIndex > lastRemoteIndex && !bytes.Equal(remote[remoteIdx].Hash, local[localIdx].Hash) { - updates = append(updates, remote[remoteIdx].ID) - } - // increment both indices when equal - localIdx += 1 - remoteIdx += 1 - } else if local[localIdx].ID < remote[remoteIdx].ID { - // policy no longer in remoted state - needs deleting - deletions = append(deletions, local[localIdx].ID) + // SingularNoun is the singular form of the item being replicated. + SingularNoun() string - // increment just the local index - localIdx += 1 - } else { - // local state doesn't have this policy - needs updating - updates = append(updates, remote[remoteIdx].ID) + // PluralNoun is the plural form of the item being replicated. + PluralNoun() string - // increment just the remote index - remoteIdx += 1 - } - } + // FetchRemote retrieves items newer than the provided index from the + // remote datacenter (for diffing purposes). + FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) - for ; localIdx < len(local); localIdx += 1 { - deletions = append(deletions, local[localIdx].ID) - } + // FetchLocal retrieves items from the current datacenter (for diffing + // purposes). + FetchLocal(srv *Server) (int, uint64, error) - for ; remoteIdx < len(remote); remoteIdx += 1 { - updates = append(updates, remote[remoteIdx].ID) - } + // SortState sorts the internal working state output of FetchRemote and + // FetchLocal so that a sane diff can be performed. + SortState() (lenLocal, lenRemote int) - return deletions, updates + // LocalMeta allows for type-agnostic metadata from the sorted local state + // can be retrieved for the purposes of diffing. + LocalMeta(i int) (id string, modIndex uint64, hash []byte) + + // RemoteMeta allows for type-agnostic metadata from the sorted remote + // state can be retrieved for the purposes of diffing. + RemoteMeta(i int) (id string, modIndex uint64, hash []byte) + + // FetchUpdated retrieves the specific items from the remote (during the + // correction phase). + FetchUpdated(srv *Server, updates []string) (int, error) + + // LenPendingUpdates should be the size of the data retrieved in + // FetchUpdated. + LenPendingUpdates() int + + // PendingUpdateIsRedacted returns true if the update contains redacted + // data. Really only valid for tokens. + PendingUpdateIsRedacted(i int) bool + + // PendingUpdateEstimatedSize is the item's EstimatedSize in the state + // populated by FetchUpdated. + PendingUpdateEstimatedSize(i int) int + + // UpdateLocalBatch applies a portion of the state populated by + // FetchUpdated to the current datacenter. + UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error + + // DeleteLocalBatch removes items from the current datacenter. + DeleteLocalBatch(srv *Server, batch []string) error } -func (s *Server) deleteLocalACLPolicies(deletions []string, ctx context.Context) (bool, error) { - ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) - defer ticker.Stop() +var errContainsRedactedData = errors.New("replication results contain redacted data") - for i := 0; i < len(deletions); i += aclBatchDeleteSize { - req := structs.ACLPolicyBatchDeleteRequest{} - - if i+aclBatchDeleteSize > len(deletions) { - req.PolicyIDs = deletions[i:] - } else { - req.PolicyIDs = deletions[i : i+aclBatchDeleteSize] - } - - resp, err := s.raftApply(structs.ACLPolicyDeleteRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply policy deletions: %v", err) - } - if respErr, ok := resp.(error); ok && err != nil { - return false, fmt.Errorf("Failed to apply policy deletions: %v", respErr) - } - - if i+aclBatchDeleteSize < len(deletions) { - select { - case <-ctx.Done(): - return true, nil - case <-ticker.C: - // do nothing - ready for the next batch - } - } +func (s *Server) fetchACLRolesBatch(roleIDs []string) (*structs.ACLRoleBatchResponse, error) { + req := structs.ACLRoleBatchGetRequest{ + Datacenter: s.config.ACLDatacenter, + RoleIDs: roleIDs, + QueryOptions: structs.QueryOptions{ + AllowStale: true, + Token: s.tokens.ReplicationToken(), + }, } - return false, nil + var response structs.ACLRoleBatchResponse + if err := s.RPC("ACL.RoleBatchRead", &req, &response); err != nil { + return nil, err + } + + return &response, nil } -func (s *Server) updateLocalACLPolicies(policies structs.ACLPolicies, ctx context.Context) (bool, error) { - ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) - defer ticker.Stop() +func (s *Server) fetchACLRoles(lastRemoteIndex uint64) (*structs.ACLRoleListResponse, error) { + defer metrics.MeasureSince([]string{"leader", "replication", "acl", "role", "fetch"}, time.Now()) - // outer loop handles submitting a batch - for batchStart := 0; batchStart < len(policies); { - // inner loop finds the last element to include in this batch. - batchSize := 0 - batchEnd := batchStart - for ; batchEnd < len(policies) && batchSize < aclBatchUpsertSize; batchEnd += 1 { - batchSize += policies[batchEnd].EstimateSize() - } - - req := structs.ACLPolicyBatchSetRequest{ - Policies: policies[batchStart:batchEnd], - } - - resp, err := s.raftApply(structs.ACLPolicySetRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply policy upserts: %v", err) - } - if respErr, ok := resp.(error); ok && respErr != nil { - return false, fmt.Errorf("Failed to apply policy upsert: %v", respErr) - } - s.logger.Printf("[DEBUG] acl: policy replication - upserted 1 batch with %d policies of size %d", batchEnd-batchStart, batchSize) - - // policies[batchEnd] wasn't include as the slicing doesn't include the element at the stop index - batchStart = batchEnd - - // prevent waiting if we are done - if batchEnd < len(policies) { - select { - case <-ctx.Done(): - return true, nil - case <-ticker.C: - // nothing to do - just rate limiting - } - } + req := structs.ACLRoleListRequest{ + Datacenter: s.config.ACLDatacenter, + QueryOptions: structs.QueryOptions{ + AllowStale: true, + MinQueryIndex: lastRemoteIndex, + Token: s.tokens.ReplicationToken(), + }, } - return false, nil + + var response structs.ACLRoleListResponse + if err := s.RPC("ACL.RoleList", &req, &response); err != nil { + return nil, err + } + return &response, nil } func (s *Server) fetchACLPoliciesBatch(policyIDs []string) (*structs.ACLPolicyBatchResponse, error) { @@ -171,66 +157,72 @@ func (s *Server) fetchACLPolicies(lastRemoteIndex uint64) (*structs.ACLPolicyLis return &response, nil } -type tokenDiffResults struct { +type itemDiffResults struct { LocalDeletes []string LocalUpserts []string LocalSkipped int RemoteSkipped int } -func diffACLTokens(local structs.ACLTokens, remote structs.ACLTokenListStubs, lastRemoteIndex uint64) tokenDiffResults { - // Note: items with empty AccessorIDs will bubble up to the top. - local.Sort() - remote.Sort() +func diffACLType(tr aclTypeReplicator, lastRemoteIndex uint64) itemDiffResults { + // Note: items with empty IDs will bubble up to the top (like legacy, unmigrated Tokens) - var res tokenDiffResults + lenLocal, lenRemote := tr.SortState() + + var res itemDiffResults var localIdx int var remoteIdx int - for localIdx, remoteIdx = 0, 0; localIdx < len(local) && remoteIdx < len(remote); { - if local[localIdx].AccessorID == "" { + for localIdx, remoteIdx = 0, 0; localIdx < lenLocal && remoteIdx < lenRemote; { + localID, _, localHash := tr.LocalMeta(localIdx) + remoteID, remoteMod, remoteHash := tr.RemoteMeta(remoteIdx) + + if localID == "" { res.LocalSkipped++ localIdx += 1 continue } - if remote[remoteIdx].AccessorID == "" { + if remoteID == "" { res.RemoteSkipped++ remoteIdx += 1 continue } - if local[localIdx].AccessorID == remote[remoteIdx].AccessorID { - // policy is in both the local and remote state - need to check raft indices and Hash - if remote[remoteIdx].ModifyIndex > lastRemoteIndex && !bytes.Equal(remote[remoteIdx].Hash, local[localIdx].Hash) { - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + + if localID == remoteID { + // item is in both the local and remote state - need to check raft indices and the Hash + if remoteMod > lastRemoteIndex && !bytes.Equal(remoteHash, localHash) { + res.LocalUpserts = append(res.LocalUpserts, remoteID) } // increment both indices when equal localIdx += 1 remoteIdx += 1 - } else if local[localIdx].AccessorID < remote[remoteIdx].AccessorID { - // policy no longer in remoted state - needs deleting - res.LocalDeletes = append(res.LocalDeletes, local[localIdx].AccessorID) + } else if localID < remoteID { + // item no longer in remote state - needs deleting + res.LocalDeletes = append(res.LocalDeletes, localID) // increment just the local index localIdx += 1 } else { - // local state doesn't have this policy - needs updating - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + // local state doesn't have this item - needs updating + res.LocalUpserts = append(res.LocalUpserts, remoteID) // increment just the remote index remoteIdx += 1 } } - for ; localIdx < len(local); localIdx += 1 { - if local[localIdx].AccessorID != "" { - res.LocalDeletes = append(res.LocalDeletes, local[localIdx].AccessorID) + for ; localIdx < lenLocal; localIdx += 1 { + localID, _, _ := tr.LocalMeta(localIdx) + if localID != "" { + res.LocalDeletes = append(res.LocalDeletes, localID) } else { res.LocalSkipped++ } } - for ; remoteIdx < len(remote); remoteIdx += 1 { - if remote[remoteIdx].AccessorID != "" { - res.LocalUpserts = append(res.LocalUpserts, remote[remoteIdx].AccessorID) + for ; remoteIdx < lenRemote; remoteIdx += 1 { + remoteID, _, _ := tr.RemoteMeta(remoteIdx) + if remoteID != "" { + res.LocalUpserts = append(res.LocalUpserts, remoteID) } else { res.RemoteSkipped++ } @@ -239,25 +231,21 @@ func diffACLTokens(local structs.ACLTokens, remote structs.ACLTokenListStubs, la return res } -func (s *Server) deleteLocalACLTokens(deletions []string, ctx context.Context) (bool, error) { +func (s *Server) deleteLocalACLType(ctx context.Context, tr aclTypeReplicator, deletions []string) (bool, error) { ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) defer ticker.Stop() for i := 0; i < len(deletions); i += aclBatchDeleteSize { - req := structs.ACLTokenBatchDeleteRequest{} + var batch []string if i+aclBatchDeleteSize > len(deletions) { - req.TokenIDs = deletions[i:] + batch = deletions[i:] } else { - req.TokenIDs = deletions[i : i+aclBatchDeleteSize] + batch = deletions[i : i+aclBatchDeleteSize] } - resp, err := s.raftApply(structs.ACLTokenDeleteRequestType, &req) - if err != nil { - return false, fmt.Errorf("Failed to apply token deletions: %v", err) - } - if respErr, ok := resp.(error); ok && err != nil { - return false, fmt.Errorf("Failed to apply token deletions: %v", respErr) + if err := tr.DeleteLocalBatch(s, batch); err != nil { + return false, fmt.Errorf("Failed to apply %s deletions: %v", tr.SingularNoun(), err) } if i+aclBatchDeleteSize < len(deletions) { @@ -273,47 +261,50 @@ func (s *Server) deleteLocalACLTokens(deletions []string, ctx context.Context) ( return false, nil } -func (s *Server) updateLocalACLTokens(tokens structs.ACLTokens, ctx context.Context) (bool, error) { +func (s *Server) updateLocalACLType(ctx context.Context, tr aclTypeReplicator) (bool, error) { ticker := time.NewTicker(time.Second / time.Duration(s.config.ACLReplicationApplyLimit)) defer ticker.Stop() + lenPending := tr.LenPendingUpdates() + // outer loop handles submitting a batch - for batchStart := 0; batchStart < len(tokens); { + for batchStart := 0; batchStart < lenPending; { // inner loop finds the last element to include in this batch. batchSize := 0 batchEnd := batchStart - for ; batchEnd < len(tokens) && batchSize < aclBatchUpsertSize; batchEnd += 1 { - if tokens[batchEnd].SecretID == redactedToken { - return false, fmt.Errorf("Detected redacted token secrets: stopping token update round - verify that the replication token in use has acl:write permissions.") + for ; batchEnd < lenPending && batchSize < aclBatchUpsertSize; batchEnd += 1 { + if tr.PendingUpdateIsRedacted(batchEnd) { + return false, fmt.Errorf( + "Detected redacted %s secrets: stopping %s update round - verify that the replication token in use has acl:write permissions.", + tr.SingularNoun(), + tr.SingularNoun(), + ) } - batchSize += tokens[batchEnd].EstimateSize() + batchSize += tr.PendingUpdateEstimatedSize(batchEnd) } - req := structs.ACLTokenBatchSetRequest{ - Tokens: tokens[batchStart:batchEnd], - CAS: false, - } - - resp, err := s.raftApply(structs.ACLTokenSetRequestType, &req) + err := tr.UpdateLocalBatch(ctx, s, batchStart, batchEnd) if err != nil { - return false, fmt.Errorf("Failed to apply token upserts: %v", err) - } - if respErr, ok := resp.(error); ok && respErr != nil { - return false, fmt.Errorf("Failed to apply token upserts: %v", respErr) + return false, fmt.Errorf("Failed to apply %s upserts: %v", tr.SingularNoun(), err) } + s.logger.Printf( + "[DEBUG] acl: %s replication - upserted 1 batch with %d %s of size %d", + tr.SingularNoun(), + batchEnd-batchStart, + tr.PluralNoun(), + batchSize, + ) - s.logger.Printf("[DEBUG] acl: token replication - upserted 1 batch with %d tokens of size %d", batchEnd-batchStart, batchSize) - - // tokens[batchEnd] wasn't include as the slicing doesn't include the element at the stop index + // items[batchEnd] wasn't include as the slicing doesn't include the element at the stop index batchStart = batchEnd // prevent waiting if we are done - if batchEnd < len(tokens) { + if batchEnd < lenPending { select { case <-ctx.Done(): return true, nil case <-ticker.C: - // nothing to do - just rate limiting here + // nothing to do - just rate limiting } } } @@ -359,95 +350,28 @@ func (s *Server) fetchACLTokens(lastRemoteIndex uint64) (*structs.ACLTokenListRe return &response, nil } -func (s *Server) replicateACLPolicies(lastRemoteIndex uint64, ctx context.Context) (uint64, bool, error) { - remote, err := s.fetchACLPolicies(lastRemoteIndex) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve remote ACL policies: %v", err) - } - - s.logger.Printf("[DEBUG] acl: finished fetching policies tokens: %d", len(remote.Policies)) - - // Need to check if we should be stopping. This will be common as the fetching process is a blocking - // RPC which could have been hanging around for a long time and during that time leadership could - // have been lost. - select { - case <-ctx.Done(): - return 0, true, nil - default: - // do nothing - } - - // Measure everything after the remote query, which can block for long - // periods of time. This metric is a good measure of how expensive the - // replication process is. - defer metrics.MeasureSince([]string{"leader", "replication", "acl", "policy", "apply"}, time.Now()) - - _, local, err := s.fsm.State().ACLPolicyList(nil) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve local ACL policies: %v", err) - } - - // If the remote index ever goes backwards, it's a good indication that - // the remote side was rebuilt and we should do a full sync since we - // can't make any assumptions about what's going on. - if remote.QueryMeta.Index < lastRemoteIndex { - s.logger.Printf("[WARN] consul: ACL policy replication remote index moved backwards (%d to %d), forcing a full ACL policy sync", lastRemoteIndex, remote.QueryMeta.Index) - lastRemoteIndex = 0 - } - - s.logger.Printf("[DEBUG] acl: policy replication - local: %d, remote: %d", len(local), len(remote.Policies)) - // Calculate the changes required to bring the state into sync and then - // apply them. - deletions, updates := diffACLPolicies(local, remote.Policies, lastRemoteIndex) - - s.logger.Printf("[DEBUG] acl: policy replication - deletions: %d, updates: %d", len(deletions), len(updates)) - - var policies *structs.ACLPolicyBatchResponse - if len(updates) > 0 { - policies, err = s.fetchACLPoliciesBatch(updates) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve ACL policy updates: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - downloaded %d policies", len(policies.Policies)) - } - - if len(deletions) > 0 { - s.logger.Printf("[DEBUG] acl: policy replication - performing deletions") - - exit, err := s.deleteLocalACLPolicies(deletions, ctx) - if exit { - return 0, true, nil - } - if err != nil { - return 0, false, fmt.Errorf("failed to delete local ACL policies: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - finished deletions") - } - - if len(updates) > 0 { - s.logger.Printf("[DEBUG] acl: policy replication - performing updates") - exit, err := s.updateLocalACLPolicies(policies.Policies, ctx) - if exit { - return 0, true, nil - } - if err != nil { - return 0, false, fmt.Errorf("failed to update local ACL policies: %v", err) - } - s.logger.Printf("[DEBUG] acl: policy replication - finished updates") - } - - // Return the index we got back from the remote side, since we've synced - // up with the remote state as of that index. - return remote.QueryMeta.Index, false, nil +func (s *Server) replicateACLTokens(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclTokenReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) } -func (s *Server) replicateACLTokens(lastRemoteIndex uint64, ctx context.Context) (uint64, bool, error) { - remote, err := s.fetchACLTokens(lastRemoteIndex) +func (s *Server) replicateACLPolicies(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclPolicyReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) +} + +func (s *Server) replicateACLRoles(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + tr := &aclRoleReplicator{} + return s.replicateACLType(ctx, tr, lastRemoteIndex) +} + +func (s *Server) replicateACLType(ctx context.Context, tr aclTypeReplicator, lastRemoteIndex uint64) (uint64, bool, error) { + lenRemote, remoteIndex, err := tr.FetchRemote(s, lastRemoteIndex) if err != nil { - return 0, false, fmt.Errorf("failed to retrieve remote ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to retrieve remote ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: finished fetching remote tokens: %d", len(remote.Tokens)) + s.logger.Printf("[DEBUG] acl: finished fetching %s: %d", tr.PluralNoun(), lenRemote) // Need to check if we should be stopping. This will be common as the fetching process is a blocking // RPC which could have been hanging around for a long time and during that time leadership could @@ -462,73 +386,99 @@ func (s *Server) replicateACLTokens(lastRemoteIndex uint64, ctx context.Context) // Measure everything after the remote query, which can block for long // periods of time. This metric is a good measure of how expensive the // replication process is. - defer metrics.MeasureSince([]string{"leader", "replication", "acl", "token", "apply"}, time.Now()) + defer metrics.MeasureSince([]string{"leader", "replication", "acl", tr.SingularNoun(), "apply"}, time.Now()) - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "") + lenLocal, _, err := tr.FetchLocal(s) if err != nil { - return 0, false, fmt.Errorf("failed to retrieve local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to retrieve local ACL %s: %v", tr.PluralNoun(), err) } // If the remote index ever goes backwards, it's a good indication that // the remote side was rebuilt and we should do a full sync since we // can't make any assumptions about what's going on. - if remote.QueryMeta.Index < lastRemoteIndex { - s.logger.Printf("[WARN] consul: ACL token replication remote index moved backwards (%d to %d), forcing a full ACL token sync", lastRemoteIndex, remote.QueryMeta.Index) + if remoteIndex < lastRemoteIndex { + s.logger.Printf( + "[WARN] consul: ACL %s replication remote index moved backwards (%d to %d), forcing a full ACL %s sync", + tr.SingularNoun(), + lastRemoteIndex, + remoteIndex, + tr.SingularNoun(), + ) lastRemoteIndex = 0 } - s.logger.Printf("[DEBUG] acl: token replication - local: %d, remote: %d", len(local), len(remote.Tokens)) - - // Calculate the changes required to bring the state into sync and then - // apply them. - res := diffACLTokens(local, remote.Tokens, lastRemoteIndex) + s.logger.Printf( + "[DEBUG] acl: %s replication - local: %d, remote: %d", + tr.SingularNoun(), + lenLocal, + lenRemote, + ) + // Calculate the changes required to bring the state into sync and then apply them. + res := diffACLType(tr, lastRemoteIndex) if res.LocalSkipped > 0 || res.RemoteSkipped > 0 { - s.logger.Printf("[DEBUG] acl: token replication - deletions: %d, updates: %d, skipped: %d, skippedRemote: %d", - len(res.LocalDeletes), len(res.LocalUpserts), res.LocalSkipped, res.RemoteSkipped) + s.logger.Printf( + "[DEBUG] acl: %s replication - deletions: %d, updates: %d, skipped: %d, skippedRemote: %d", + tr.SingularNoun(), + len(res.LocalDeletes), + len(res.LocalUpserts), + res.LocalSkipped, + res.RemoteSkipped, + ) } else { - s.logger.Printf("[DEBUG] acl: token replication - deletions: %d, updates: %d", len(res.LocalDeletes), len(res.LocalUpserts)) + s.logger.Printf( + "[DEBUG] acl: %s replication - deletions: %d, updates: %d", + tr.SingularNoun(), + len(res.LocalDeletes), + len(res.LocalUpserts), + ) } - var tokens *structs.ACLTokenBatchResponse if len(res.LocalUpserts) > 0 { - tokens, err = s.fetchACLTokensBatch(res.LocalUpserts) - if err != nil { - return 0, false, fmt.Errorf("failed to retrieve ACL token updates: %v", err) - } else if tokens.Redacted { - return 0, false, fmt.Errorf("failed to retrieve unredacted tokens - replication token in use does not grant acl:write") + lenUpdated, err := tr.FetchUpdated(s, res.LocalUpserts) + if err == errContainsRedactedData { + return 0, false, fmt.Errorf("failed to retrieve unredacted %s - replication token in use does not grant acl:write", tr.PluralNoun()) + } else if err != nil { + return 0, false, fmt.Errorf("failed to retrieve ACL %s updates: %v", tr.SingularNoun(), err) } - - s.logger.Printf("[DEBUG] acl: token replication - downloaded %d tokens", len(tokens.Tokens)) + s.logger.Printf( + "[DEBUG] acl: %s replication - downloaded %d %s", + tr.SingularNoun(), + lenUpdated, + tr.PluralNoun(), + ) } if len(res.LocalDeletes) > 0 { - s.logger.Printf("[DEBUG] acl: token replication - performing deletions") + s.logger.Printf( + "[DEBUG] acl: %s replication - performing deletions", + tr.SingularNoun(), + ) - exit, err := s.deleteLocalACLTokens(res.LocalDeletes, ctx) + exit, err := s.deleteLocalACLType(ctx, tr, res.LocalDeletes) if exit { return 0, true, nil } if err != nil { - return 0, false, fmt.Errorf("failed to delete local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to delete local ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: token replication - finished deletions") + s.logger.Printf("[DEBUG] acl: %s replication - finished deletions", tr.SingularNoun()) } if len(res.LocalUpserts) > 0 { - s.logger.Printf("[DEBUG] acl: token replication - performing updates") - exit, err := s.updateLocalACLTokens(tokens.Tokens, ctx) + s.logger.Printf("[DEBUG] acl: %s replication - performing updates", tr.SingularNoun()) + exit, err := s.updateLocalACLType(ctx, tr) if exit { return 0, true, nil } if err != nil { - return 0, false, fmt.Errorf("failed to update local ACL tokens: %v", err) + return 0, false, fmt.Errorf("failed to update local ACL %s: %v", tr.PluralNoun(), err) } - s.logger.Printf("[DEBUG] acl: token replication - finished updates") + s.logger.Printf("[DEBUG] acl: %s replication - finished updates", tr.SingularNoun()) } // Return the index we got back from the remote side, since we've synced // up with the remote state as of that index. - return remote.QueryMeta.Index, false, nil + return remoteIndex, false, nil } // IsACLReplicationEnabled returns true if ACL replication is enabled. @@ -546,20 +496,23 @@ func (s *Server) updateACLReplicationStatusError() { s.aclReplicationStatus.LastError = time.Now().Round(time.Second).UTC() } -func (s *Server) updateACLReplicationStatusIndex(index uint64) { +func (s *Server) updateACLReplicationStatusIndex(replicationType structs.ACLReplicationType, index uint64) { s.aclReplicationStatusLock.Lock() defer s.aclReplicationStatusLock.Unlock() s.aclReplicationStatus.LastSuccess = time.Now().Round(time.Second).UTC() - s.aclReplicationStatus.ReplicatedIndex = index -} - -func (s *Server) updateACLReplicationStatusTokenIndex(index uint64) { - s.aclReplicationStatusLock.Lock() - defer s.aclReplicationStatusLock.Unlock() - - s.aclReplicationStatus.LastSuccess = time.Now().Round(time.Second).UTC() - s.aclReplicationStatus.ReplicatedTokenIndex = index + switch replicationType { + case structs.ACLReplicateLegacy: + s.aclReplicationStatus.ReplicatedIndex = index + case structs.ACLReplicateTokens: + s.aclReplicationStatus.ReplicatedTokenIndex = index + case structs.ACLReplicatePolicies: + s.aclReplicationStatus.ReplicatedIndex = index + case structs.ACLReplicateRoles: + s.aclReplicationStatus.ReplicatedRoleIndex = index + default: + panic("unknown replication type: " + replicationType.SingularNoun()) + } } func (s *Server) initReplicationStatus() { @@ -582,6 +535,21 @@ func (s *Server) updateACLReplicationStatusRunning(replicationType structs.ACLRe s.aclReplicationStatusLock.Lock() defer s.aclReplicationStatusLock.Unlock() + // The running state represents which type of overall replication has been + // configured. Though there are various types of internal plumbing for acl + // replication, to the end user there are only 3 distinctly configurable + // variants: legacy, policy, token. Roles replicate with policies so we + // round that up here. + if replicationType == structs.ACLReplicateRoles { + replicationType = structs.ACLReplicatePolicies + } + s.aclReplicationStatus.Running = true s.aclReplicationStatus.ReplicationType = replicationType } + +func (s *Server) getACLReplicationStatusRunningType() (structs.ACLReplicationType, bool) { + s.aclReplicationStatusLock.RLock() + defer s.aclReplicationStatusLock.RUnlock() + return s.aclReplicationStatus.ReplicationType, s.aclReplicationStatus.Running +} diff --git a/agent/consul/acl_replication_legacy.go b/agent/consul/acl_replication_legacy.go index 182e206208..b933f714e9 100644 --- a/agent/consul/acl_replication_legacy.go +++ b/agent/consul/acl_replication_legacy.go @@ -6,7 +6,7 @@ import ( "sort" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" ) @@ -138,13 +138,18 @@ func reconcileLegacyACLs(local, remote structs.ACLs, lastRemoteIndex uint64) str // FetchLocalACLs returns the ACLs in the local state store. func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) { - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "") + _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "", "") if err != nil { return nil, err } + now := time.Now() + var acls structs.ACLs for _, token := range local { + if token.IsExpired(now) { + continue + } if acl, err := token.Convert(); err == nil && acl != nil { acls = append(acls, acl) } diff --git a/agent/consul/acl_replication_legacy_test.go b/agent/consul/acl_replication_legacy_test.go index a1eea646f2..171f71c359 100644 --- a/agent/consul/acl_replication_legacy_test.go +++ b/agent/consul/acl_replication_legacy_test.go @@ -335,6 +335,10 @@ func TestACLReplication_IsACLReplicationEnabled(t *testing.T) { } } +// Note that this test is testing that legacy token data is replicated, NOT +// directly testing the legacy acl replication goroutine code. +// +// Actually testing legacy replication is difficult to do without old binaries. func TestACLReplication_LegacyTokens(t *testing.T) { t.Parallel() dir1, s1 := testServerWithConfig(t, func(c *Config) { @@ -367,6 +371,12 @@ func TestACLReplication_LegacyTokens(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + // Create a bunch of new tokens. var id string for i := 0; i < 50; i++ { @@ -386,14 +396,15 @@ func TestACLReplication_LegacyTokens(t *testing.T) { } checkSame := func() error { - index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } + if got, want := len(remote), len(local); got != want { return fmt.Errorf("got %d remote ACLs want %d", got, want) } diff --git a/agent/consul/acl_replication_test.go b/agent/consul/acl_replication_test.go index 86c939c6ae..e8a6a7d693 100644 --- a/agent/consul/acl_replication_test.go +++ b/agent/consul/acl_replication_test.go @@ -3,6 +3,7 @@ package consul import ( "fmt" "os" + "strconv" "testing" "time" @@ -15,6 +16,11 @@ import ( ) func TestACLReplication_diffACLPolicies(t *testing.T) { + diffACLPolicies := func(local structs.ACLPolicies, remote structs.ACLPolicyListStubs, lastRemoteIndex uint64) ([]string, []string) { + tr := &aclPolicyReplicator{local: local, remote: remote} + res := diffACLType(tr, lastRemoteIndex) + return res.LocalDeletes, res.LocalUpserts + } local := structs.ACLPolicies{ &structs.ACLPolicy{ ID: "44ef9aec-7654-4401-901b-4d4a8b3c80fc", @@ -127,6 +133,15 @@ func TestACLReplication_diffACLPolicies(t *testing.T) { } func TestACLReplication_diffACLTokens(t *testing.T) { + diffACLTokens := func( + local structs.ACLTokens, + remote structs.ACLTokenListStubs, + lastRemoteIndex uint64, + ) itemDiffResults { + tr := &aclTokenReplicator{local: local, remote: remote} + return diffACLType(tr, lastRemoteIndex) + } + local := structs.ACLTokens{ // When a just-upgraded (1.3->1.4+) secondary DC is replicating from an // upgraded primary DC (1.4+), the local state for tokens predating the @@ -307,6 +322,12 @@ func TestACLReplication_Tokens(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + // Create a bunch of new tokens and policies var tokens structs.ACLTokens for i := 0; i < 50; i++ { @@ -328,11 +349,11 @@ func TestACLReplication_Tokens(t *testing.T) { tokens = append(tokens, &token) } - checkSame := func(t *retry.R) error { + checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) - _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) require.Len(t, local, len(remote)) @@ -340,18 +361,15 @@ func TestACLReplication_Tokens(t *testing.T) { require.Equal(t, token.Hash, local[i].Hash) } - var status structs.ACLReplicationStatus s2.aclReplicationStatusLock.RLock() - status = s2.aclReplicationStatus + status := s2.aclReplicationStatus s2.aclReplicationStatusLock.RUnlock() - if !status.Enabled || !status.Running || - status.ReplicationType != structs.ACLReplicateTokens || - status.ReplicatedTokenIndex != index || - status.SourceDatacenter != "dc1" { - return fmt.Errorf("ACL replication status differs") - } - return nil + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedTokenIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") } // Wait for the replica to converge. retry.Run(t, func(r *retry.R) { @@ -426,7 +444,7 @@ func TestACLReplication_Tokens(t *testing.T) { }) // verify dc2 local tokens didn't get blown away - _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "", "") require.NoError(t, err) require.Len(t, local, 50) @@ -479,6 +497,12 @@ func TestACLReplication_Policies(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc2") + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicatePolicies) + // Create a bunch of new policies var policies structs.ACLPolicies for i := 0; i < 50; i++ { @@ -496,7 +520,7 @@ func TestACLReplication_Policies(t *testing.T) { policies = append(policies, &policy) } - checkSame := func(t *retry.R) error { + checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated index, remote, err := s1.fsm.State().ACLPolicyList(nil) require.NoError(t, err) @@ -508,18 +532,15 @@ func TestACLReplication_Policies(t *testing.T) { require.Equal(t, policy.Hash, local[i].Hash) } - var status structs.ACLReplicationStatus s2.aclReplicationStatusLock.RLock() - status = s2.aclReplicationStatus + status := s2.aclReplicationStatus s2.aclReplicationStatusLock.RUnlock() - if !status.Enabled || !status.Running || - status.ReplicationType != structs.ACLReplicatePolicies || - status.ReplicatedIndex != index || - status.SourceDatacenter != "dc1" { - return fmt.Errorf("ACL replication status differs") - } - return nil + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicatePolicies) + require.Equal(t, status.ReplicatedIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") } // Wait for the replica to converge. retry.Run(t, func(r *retry.R) { @@ -709,3 +730,249 @@ func TestACLReplication_TokensRedacted(t *testing.T) { require.True(r, status.LastError.After(minErrorTime), "Replication LastError not after the minErrorTime") }) } + +func TestACLReplication_AllTypes(t *testing.T) { + t.Parallel() + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + testrpc.WaitForLeader(t, s1.RPC, "dc1") + client := rpcClient(t, s1) + defer client.Close() + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLTokenReplication = true + c.ACLReplicationRate = 100 + c.ACLReplicationBurst = 100 + c.ACLReplicationApplyLimit = 1000000 + }) + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + testrpc.WaitForLeader(t, s2.RPC, "dc2") + defer os.RemoveAll(dir2) + defer s2.Shutdown() + + // Try to join. + joinWAN(t, s2, s1) + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s1.RPC, "dc2") + + // Wait for legacy acls to be disabled so we are clear that + // legacy replication isn't meddling. + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + waitForNewACLReplication(t, s2, structs.ACLReplicateTokens) + + const ( + numItems = 50 + numItemsThatAreLocal = 10 + ) + + // Create some data. + policyIDs, roleIDs, tokenIDs := createACLTestData(t, s1, "b1", numItems, numItemsThatAreLocal) + + checkSameTokens := func(t *retry.R) { + // only account for global tokens - local tokens shouldn't be replicated + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") + require.NoError(t, err) + // Query for all of them, so that we can prove that no globals snuck in. + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, token := range remote { + require.Equal(t, token.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedTokenIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSamePolicies := func(t *retry.R) { + index, remote, err := s1.fsm.State().ACLPolicyList(nil) + require.NoError(t, err) + _, local, err := s2.fsm.State().ACLPolicyList(nil) + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, policy := range remote { + require.Equal(t, policy.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSameRoles := func(t *retry.R) { + index, remote, err := s1.fsm.State().ACLRoleList(nil, "") + require.NoError(t, err) + _, local, err := s2.fsm.State().ACLRoleList(nil, "") + require.NoError(t, err) + + require.Len(t, remote, len(local)) + for i, role := range remote { + require.Equal(t, role.Hash, local[i].Hash) + } + + s2.aclReplicationStatusLock.RLock() + status := s2.aclReplicationStatus + s2.aclReplicationStatusLock.RUnlock() + + require.True(t, status.Enabled) + require.True(t, status.Running) + require.Equal(t, status.ReplicationType, structs.ACLReplicateTokens) + require.Equal(t, status.ReplicatedRoleIndex, index) + require.Equal(t, status.SourceDatacenter, "dc1") + } + checkSame := func(t *retry.R) { + checkSameTokens(t) + checkSamePolicies(t) + checkSameRoles(t) + } + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) + + // Create additional data to replicate. + _, _, _ = createACLTestData(t, s1, "b2", numItems, numItemsThatAreLocal) + + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) + + // Delete one piece of each type of data from batch 1. + const itemToDelete = numItems - 1 + { + id := tokenIDs[itemToDelete] + + arg := structs.ACLTokenDeleteRequest{ + Datacenter: "dc1", + TokenID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.TokenDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + { + id := roleIDs[itemToDelete] + + arg := structs.ACLRoleDeleteRequest{ + Datacenter: "dc1", + RoleID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.RoleDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + { + id := policyIDs[itemToDelete] + + arg := structs.ACLPolicyDeleteRequest{ + Datacenter: "dc1", + PolicyID: id, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var dontCare string + if err := s1.RPC("ACL.PolicyDelete", &arg, &dontCare); err != nil { + t.Fatalf("err: %v", err) + } + } + // Wait for the replica to converge. + retry.Run(t, func(r *retry.R) { + checkSame(r) + }) +} + +func createACLTestData(t *testing.T, srv *Server, namePrefix string, numObjects, numItemsThatAreLocal int) (policyIDs, roleIDs, tokenIDs []string) { + require.True(t, numItemsThatAreLocal <= numObjects, 0, "numItemsThatAreLocal <= numObjects") + + // Create some policies. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLPolicySetRequest{ + Datacenter: "dc1", + Policy: structs.ACLPolicy{ + Name: namePrefix + "-policy-" + str, + Description: namePrefix + "-policy " + str, + Rules: testACLPolicyNew, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLPolicy + if err := srv.RPC("ACL.PolicySet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + policyIDs = append(policyIDs, out.ID) + } + + // Create some roles. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: namePrefix + "-role-" + str, + Description: namePrefix + "-role " + str, + Policies: []structs.ACLRolePolicyLink{ + {ID: policyIDs[i]}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLRole + if err := srv.RPC("ACL.RoleSet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + roleIDs = append(roleIDs, out.ID) + } + + // Create a bunch of new tokens. + for i := 0; i < numObjects; i++ { + str := strconv.Itoa(i) + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: namePrefix + "-token " + str, + Policies: []structs.ACLTokenPolicyLink{ + {ID: policyIDs[i]}, + }, + Roles: []structs.ACLTokenRoleLink{ + {ID: roleIDs[i]}, + }, + Local: (i < numItemsThatAreLocal), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out structs.ACLToken + if err := srv.RPC("ACL.TokenSet", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + tokenIDs = append(tokenIDs, out.AccessorID) + } + + return policyIDs, roleIDs, tokenIDs +} diff --git a/agent/consul/acl_replication_types.go b/agent/consul/acl_replication_types.go new file mode 100644 index 0000000000..8efc229632 --- /dev/null +++ b/agent/consul/acl_replication_types.go @@ -0,0 +1,370 @@ +package consul + +import ( + "context" + "fmt" + + "github.com/hashicorp/consul/agent/structs" +) + +type aclTokenReplicator struct { + local structs.ACLTokens + remote structs.ACLTokenListStubs + updated []*structs.ACLToken +} + +var _ aclTypeReplicator = (*aclTokenReplicator)(nil) + +func (r *aclTokenReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicateTokens } +func (r *aclTokenReplicator) SingularNoun() string { return "token" } +func (r *aclTokenReplicator) PluralNoun() string { return "tokens" } + +func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLTokens(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Tokens + return len(remote.Tokens), remote.QueryMeta.Index, nil +} + +func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "") + if err != nil { + return 0, 0, err + } + + // Do not filter by expiration times. Wait until the tokens are explicitly + // deleted. + + r.local = local + return len(local), idx, nil +} + +func (r *aclTokenReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclTokenReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.AccessorID, v.ModifyIndex, v.Hash +} +func (r *aclTokenReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.AccessorID, v.ModifyIndex, v.Hash +} + +func (r *aclTokenReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + tokens, err := srv.fetchACLTokensBatch(updates) + if err != nil { + return 0, err + } else if tokens.Redacted { + return 0, errContainsRedactedData + } + + // Do not filter by expiration times. Wait until the tokens are + // explicitly deleted. + + r.updated = tokens.Tokens + } + + return len(r.updated), nil +} + +func (r *aclTokenReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLTokenBatchDeleteRequest{ + TokenIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLTokenDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclTokenReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclTokenReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclTokenReplicator) PendingUpdateIsRedacted(i int) bool { + return r.updated[i].SecretID == redactedToken +} + +func (r *aclTokenReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLTokenBatchSetRequest{ + Tokens: r.updated[start:end], + CAS: false, + } + + resp, err := srv.raftApply(structs.ACLTokenSetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} + +/////////////////////// + +type aclPolicyReplicator struct { + local structs.ACLPolicies + remote structs.ACLPolicyListStubs + updated []*structs.ACLPolicy +} + +var _ aclTypeReplicator = (*aclPolicyReplicator)(nil) + +func (r *aclPolicyReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicatePolicies } +func (r *aclPolicyReplicator) SingularNoun() string { return "policy" } +func (r *aclPolicyReplicator) PluralNoun() string { return "policies" } + +func (r *aclPolicyReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLPolicies(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Policies + return len(remote.Policies), remote.QueryMeta.Index, nil +} + +func (r *aclPolicyReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLPolicyList(nil) + if err != nil { + return 0, 0, err + } + + r.local = local + return len(local), idx, nil +} + +func (r *aclPolicyReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclPolicyReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.ID, v.ModifyIndex, v.Hash +} +func (r *aclPolicyReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.ID, v.ModifyIndex, v.Hash +} + +func (r *aclPolicyReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + policies, err := srv.fetchACLPoliciesBatch(updates) + if err != nil { + return 0, err + } + r.updated = policies.Policies + } + + return len(r.updated), nil +} + +func (r *aclPolicyReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLPolicyBatchDeleteRequest{ + PolicyIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLPolicyDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclPolicyReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclPolicyReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclPolicyReplicator) PendingUpdateIsRedacted(i int) bool { + return false +} + +func (r *aclPolicyReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLPolicyBatchSetRequest{ + Policies: r.updated[start:end], + } + + resp, err := srv.raftApply(structs.ACLPolicySetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} + +//////////////////////////////// + +type aclRoleReplicator struct { + local structs.ACLRoles + remote structs.ACLRoles + updated []*structs.ACLRole +} + +var _ aclTypeReplicator = (*aclRoleReplicator)(nil) + +func (r *aclRoleReplicator) Type() structs.ACLReplicationType { return structs.ACLReplicateRoles } +func (r *aclRoleReplicator) SingularNoun() string { return "role" } +func (r *aclRoleReplicator) PluralNoun() string { return "roles" } + +func (r *aclRoleReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (int, uint64, error) { + r.remote = nil + + remote, err := srv.fetchACLRoles(lastRemoteIndex) + if err != nil { + return 0, 0, err + } + + r.remote = remote.Roles + return len(remote.Roles), remote.QueryMeta.Index, nil +} + +func (r *aclRoleReplicator) FetchLocal(srv *Server) (int, uint64, error) { + r.local = nil + + idx, local, err := srv.fsm.State().ACLRoleList(nil, "") + if err != nil { + return 0, 0, err + } + + r.local = local + return len(local), idx, nil +} + +func (r *aclRoleReplicator) SortState() (int, int) { + r.local.Sort() + r.remote.Sort() + + return len(r.local), len(r.remote) +} +func (r *aclRoleReplicator) LocalMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.local[i] + return v.ID, v.ModifyIndex, v.Hash +} +func (r *aclRoleReplicator) RemoteMeta(i int) (id string, modIndex uint64, hash []byte) { + v := r.remote[i] + return v.ID, v.ModifyIndex, v.Hash +} + +func (r *aclRoleReplicator) FetchUpdated(srv *Server, updates []string) (int, error) { + r.updated = nil + + if len(updates) > 0 { + // Since ACLRoles do not have a "list entry" variation, all of the data + // to replicate a role is already present in the "r.remote" list. + // + // We avoid a second query by just repurposing the data we already have + // access to in a way that is compatible with the generic ACL type + // replicator. + keep := make(map[string]struct{}) + for _, id := range updates { + keep[id] = struct{}{} + } + + subset := make([]*structs.ACLRole, 0, len(updates)) + for _, role := range r.remote { + if _, ok := keep[role.ID]; ok { + subset = append(subset, role) + } + } + + if len(subset) != len(keep) { // only possible via programming bug + for _, role := range subset { + delete(keep, role.ID) + } + missing := make([]string, 0, len(keep)) + for id, _ := range keep { + missing = append(missing, id) + } + return 0, fmt.Errorf("role replication trying to replicated uncached roles with IDs: %v", missing) + } + r.updated = subset + } + + return len(r.updated), nil +} + +func (r *aclRoleReplicator) DeleteLocalBatch(srv *Server, batch []string) error { + req := structs.ACLRoleBatchDeleteRequest{ + RoleIDs: batch, + } + + resp, err := srv.raftApply(structs.ACLRoleDeleteRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + return nil +} + +func (r *aclRoleReplicator) LenPendingUpdates() int { + return len(r.updated) +} + +func (r *aclRoleReplicator) PendingUpdateEstimatedSize(i int) int { + return r.updated[i].EstimateSize() +} + +func (r *aclRoleReplicator) PendingUpdateIsRedacted(i int) bool { + return false +} + +func (r *aclRoleReplicator) UpdateLocalBatch(ctx context.Context, srv *Server, start, end int) error { + req := structs.ACLRoleBatchSetRequest{ + Roles: r.updated[start:end], + } + + resp, err := srv.raftApply(structs.ACLRoleSetRequestType, &req) + if err != nil { + return err + } + if respErr, ok := resp.(error); ok && err != nil { + return respErr + } + + return nil +} diff --git a/agent/consul/acl_server.go b/agent/consul/acl_server.go index 1eaf474c2b..34ca09584b 100644 --- a/agent/consul/acl_server.go +++ b/agent/consul/acl_server.go @@ -2,6 +2,7 @@ package consul import ( "sync/atomic" + "time" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" @@ -12,7 +13,7 @@ var serverACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ // The server's ACL caching has a few underlying assumptions: // // 1 - All policies can be resolved locally. Hence we do not cache any - // unparsed policies as we have memdb for that. + // unparsed policies/roles as we have memdb for that. // 2 - While there could be many identities being used within a DC the // number of distinct policies and combined multi-policy authorizers // will be much less. @@ -25,10 +26,16 @@ var serverACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{ Policies: 0, ParsedPolicies: 512, Authorizers: 1024, + Roles: 0, } func (s *Server) checkTokenUUID(id string) (bool, error) { state := s.fsm.State() + + // We won't check expiration times here. If we generate a UUID that matches + // a token that hasn't been reaped yet, then we won't be able to insert the + // new token due to a collision. + if _, token, err := state.ACLTokenGetByAccessor(nil, id); err != nil { return false, err } else if token != nil { @@ -55,6 +62,28 @@ func (s *Server) checkPolicyUUID(id string) (bool, error) { return !structs.ACLIDReserved(id), nil } +func (s *Server) checkRoleUUID(id string) (bool, error) { + state := s.fsm.State() + if _, role, err := state.ACLRoleGetByID(nil, id); err != nil { + return false, err + } else if role != nil { + return false, nil + } + + return !structs.ACLIDReserved(id), nil +} + +func (s *Server) checkBindingRuleUUID(id string) (bool, error) { + state := s.fsm.State() + if _, rule, err := state.ACLBindingRuleGetByID(nil, id); err != nil { + return false, err + } else if rule != nil { + return false, nil + } + + return !structs.ACLIDReserved(id), nil +} + func (s *Server) updateACLAdvertisement() { // One thing to note is that once in new ACL mode the server will // never transition to legacy ACL mode. This is not currently a @@ -145,7 +174,7 @@ func (s *Server) ResolveIdentityFromToken(token string) (bool, structs.ACLIdenti index, aclToken, err := s.fsm.State().ACLTokenGetBySecret(nil, token) if err != nil { return true, nil, err - } else if aclToken != nil { + } else if aclToken != nil && !aclToken.IsExpired(time.Now()) { return true, aclToken, nil } @@ -166,6 +195,20 @@ func (s *Server) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, return s.InACLDatacenter() || index > 0, policy, acl.ErrNotFound } +func (s *Server) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + index, role, err := s.fsm.State().ACLRoleGetByID(nil, roleID) + if err != nil { + return true, nil, err + } else if role != nil { + return true, role, nil + } + + // If the max index of the roles table is non-zero then we have acls, until then + // we may need to allow remote resolution. This is particularly useful to allow updating + // the replication token via the API in a non-primary dc. + return s.InACLDatacenter() || index > 0, role, acl.ErrNotFound +} + func (s *Server) ResolveToken(token string) (acl.Authorizer, error) { return s.acls.ResolveToken(token) } diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 38b723cbdb..e2c84afe21 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -2,7 +2,6 @@ package consul import ( "fmt" - "log" "os" "reflect" "strings" @@ -12,6 +11,7 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -60,6 +60,29 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "missing-role": + return true, &structs.ACLToken{ + AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", + SecretID: "b1b6be70-ed2e-4c80-8495-bdb3db110b1e", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "not-found", + }, + structs.ACLTokenRoleLink{ + ID: "acl-ro", + }, + }, + }, nil + case "missing-policy-on-role": + return true, &structs.ACLToken{ + AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", + SecretID: "b1b6be70-ed2e-4c80-8495-bdb3db110b1e", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "missing-policy", + }, + }, + }, nil case "legacy-management": return true, &structs.ACLToken{ AccessorID: "d109a033-99d1-47e2-a711-d6593373a973", @@ -86,6 +109,56 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "found-role": + // This should be permission-wise identical to "found", except it + // gets it's policies indirectly by way of a Role. + return true, &structs.ACLToken{ + AccessorID: "5f57c1f6-6a89-4186-9445-531b316e01df", + SecretID: "a1a54629-5050-4d17-8a4e-560d2423f835", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "found", + }, + }, + }, nil + case "found-policy-and-role": + return true, &structs.ACLToken{ + AccessorID: "5f57c1f6-6a89-4186-9445-531b316e01df", + SecretID: "a1a54629-5050-4d17-8a4e-560d2423f835", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: "node-wr", + }, + structs.ACLTokenPolicyLink{ + ID: "dc2-key-wr", + }, + }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "service-ro", + }, + }, + }, nil + case "found-synthetic-policy-1": + return true, &structs.ACLToken{ + AccessorID: "f6c5a5fb-4da4-422b-9abf-2c942813fc71", + SecretID: "55cb7d69-2bea-42c3-a68f-2a1443d2abbc", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service1", + }, + }, + }, nil + case "found-synthetic-policy-2": + return true, &structs.ACLToken{ + AccessorID: "7c87dfad-be37-446e-8305-299585677cb5", + SecretID: "dfca9676-ac80-453a-837b-4c0cf923473c", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service2", + }, + }, + }, nil case "acl-ro": return true, &structs.ACLToken{ AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", @@ -177,6 +250,24 @@ func testPolicyForID(policyID string) (bool, *structs.ACLPolicy, error) { Syntax: acl.SyntaxCurrent, RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, }, nil + case "service-ro": + return true, &structs.ACLPolicy{ + ID: "service-ro", + Name: "service-ro", + Description: "service-ro", + Rules: `service_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "service-wr": + return true, &structs.ACLPolicy{ + ID: "service-wr", + Name: "service-wr", + Description: "service-wr", + Rules: `service_prefix "" { policy = "write" }`, + Syntax: acl.SyntaxCurrent, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil case "node-wr": return true, &structs.ACLPolicy{ ID: "node-wr", @@ -202,6 +293,141 @@ func testPolicyForID(policyID string) (bool, *structs.ACLPolicy, error) { } } +func testRoleForID(roleID string) (bool, *structs.ACLRole, error) { + switch roleID { + case "service-ro": + return true, &structs.ACLRole{ + ID: "service-ro", + Name: "service-ro", + Description: "service-ro", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "service-ro", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "service-wr": + return true, &structs.ACLRole{ + ID: "service-wr", + Name: "service-wr", + Description: "service-wr", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "service-wr", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "missing-policy": + return true, &structs.ACLRole{ + ID: "missing-policy", + Name: "missing-policy", + Description: "missing-policy", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "not-found", + }, + structs.ACLRolePolicyLink{ + ID: "acl-ro", + }, + }, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, nil + case "found": + return true, &structs.ACLRole{ + ID: "found", + Name: "found", + Description: "found", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "dc2-key-wr", + }, + }, + }, nil + case "acl-ro": + return true, &structs.ACLRole{ + ID: "acl-ro", + Name: "acl-ro", + Description: "acl-ro", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "acl-ro", + }, + }, + }, nil + case "acl-wr": + return true, &structs.ACLRole{ + ID: "acl-rw", + Name: "acl-rw", + Description: "acl-rw", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "racey-unmodified": + return true, &structs.ACLRole{ + ID: "racey-unmodified", + Name: "racey-unmodified", + Description: "racey-unmodified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "racey-modified": + return true, &structs.ACLRole{ + ID: "racey-modified", + Name: "racey-modified", + Description: "racey-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + }, + }, nil + case "concurrent-resolve-1": + return true, &structs.ACLRole{ + ID: "concurrent-resolve-1", + Name: "concurrent-resolve-1", + Description: "concurrent-resolve-1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + case "concurrent-resolve-2": + return true, &structs.ACLRole{ + ID: "concurrent-resolve-2", + Name: "concurrent-resolve-2", + Description: "concurrent-resolve-2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "node-wr", + }, + structs.ACLRolePolicyLink{ + ID: "acl-wr", + }, + }, + }, nil + default: + return true, nil, acl.ErrNotFound + } +} + // ACLResolverTestDelegate is used to test // the ACLResolver without running Agents type ACLResolverTestDelegate struct { @@ -210,9 +436,99 @@ type ACLResolverTestDelegate struct { legacy bool localTokens bool localPolicies bool + localRoles bool getPolicyFn func(*structs.ACLPolicyResolveLegacyRequest, *structs.ACLPolicyResolveLegacyResponse) error tokenReadFn func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error policyResolveFn func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error + roleResolveFn func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error + + // state for the optional default resolver function defaultTokenReadFn + tokenCached bool + // state for the optional default resolver function defaultPolicyResolveFn + policyCached bool + // state for the optional default resolver function defaultRoleResolveFn + roleCached bool +} + +func (d *ACLResolverTestDelegate) Reset() { + d.tokenCached = false + d.policyCached = false + d.roleCached = false +} + +var errRPC = fmt.Errorf("Induced RPC Error") + +func (d *ACLResolverTestDelegate) defaultTokenReadFn(errAfterCached error) func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { + return func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { + if !d.tokenCached { + err := d.plainTokenReadFn(args, reply) + d.tokenCached = true + return err + } + return errAfterCached + } +} + +func (d *ACLResolverTestDelegate) plainTokenReadFn(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { + _, token, err := testIdentityForToken(args.TokenID) + if token != nil { + reply.Token = token.(*structs.ACLToken) + } + return err +} + +func (d *ACLResolverTestDelegate) defaultPolicyResolveFn(errAfterCached error) func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error { + return func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { + if !d.policyCached { + err := d.plainPolicyResolveFn(args, reply) + d.policyCached = true + return err + } + + return errAfterCached + } +} + +func (d *ACLResolverTestDelegate) plainPolicyResolveFn(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, policyID := range args.PolicyIDs { + _, policy, _ := testPolicyForID(policyID) + if policy != nil { + reply.Policies = append(reply.Policies, policy) + } + } + + return nil +} + +func (d *ACLResolverTestDelegate) defaultRoleResolveFn(errAfterCached error) func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error { + return func(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + if !d.roleCached { + err := d.plainRoleResolveFn(args, reply) + d.roleCached = true + return err + } + + return errAfterCached + } +} + +// plainRoleResolveFn tries to follow the normal logic of ACL.RoleResolve using +// the test fixtures. +func (d *ACLResolverTestDelegate) plainRoleResolveFn(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, roleID := range args.RoleIDs { + _, role, _ := testRoleForID(roleID) + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } + + return nil } func (d *ACLResolverTestDelegate) ACLsEnabled() bool { @@ -243,23 +559,36 @@ func (d *ACLResolverTestDelegate) ResolvePolicyFromID(policyID string) (bool, *s return testPolicyForID(policyID) } +func (d *ACLResolverTestDelegate) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) { + if !d.localRoles { + return false, nil, nil + } + + return testRoleForID(roleID) +} + func (d *ACLResolverTestDelegate) RPC(method string, args interface{}, reply interface{}) error { switch method { case "ACL.GetPolicy": if d.getPolicyFn != nil { return d.getPolicyFn(args.(*structs.ACLPolicyResolveLegacyRequest), reply.(*structs.ACLPolicyResolveLegacyResponse)) } - panic("Bad Test Implmentation: should provide a getPolicyFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a getPolicyFn to the ACLResolverTestDelegate") case "ACL.TokenRead": if d.tokenReadFn != nil { return d.tokenReadFn(args.(*structs.ACLTokenGetRequest), reply.(*structs.ACLTokenResponse)) } - panic("Bad Test Implmentation: should provide a tokenReadFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a tokenReadFn to the ACLResolverTestDelegate") case "ACL.PolicyResolve": if d.policyResolveFn != nil { return d.policyResolveFn(args.(*structs.ACLPolicyBatchGetRequest), reply.(*structs.ACLPolicyBatchResponse)) } - panic("Bad Test Implmentation: should provide a policyResolveFn to the ACLResolverTestDelegate") + panic("Bad Test Implementation: should provide a policyResolveFn to the ACLResolverTestDelegate") + case "ACL.RoleResolve": + if d.roleResolveFn != nil { + return d.roleResolveFn(args.(*structs.ACLRoleBatchGetRequest), reply.(*structs.ACLRoleBatchResponse)) + } + panic("Bad Test Implementation: should provide a roleResolveFn to the ACLResolverTestDelegate") } panic("Bad Test Implementation: Was the ACLResolver updated to use new RPC methods") } @@ -270,12 +599,13 @@ func newTestACLResolver(t *testing.T, delegate ACLResolverDelegate, cb func(*ACL config.ACLDownPolicy = "extend-cache" rconf := &ACLResolverConfig{ Config: config, - Logger: log.New(os.Stdout, t.Name()+" - ", log.LstdFlags|log.Lmicroseconds), + Logger: testutil.TestLoggerWithName(t, t.Name()), CacheConfig: &structs.ACLCachesConfig{ Identities: 4, Policies: 4, ParsedPolicies: 4, Authorizers: 4, + Roles: 4, }, AutoDisable: true, Delegate: delegate, @@ -371,8 +701,9 @@ func TestACLResolver_DownPolicy(t *testing.T) { legacy: false, localTokens: false, localPolicies: true, + localRoles: true, tokenReadFn: func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -395,8 +726,9 @@ func TestACLResolver_DownPolicy(t *testing.T) { legacy: false, localTokens: false, localPolicies: true, + localRoles: true, tokenReadFn: func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -413,32 +745,20 @@ func TestACLResolver_DownPolicy(t *testing.T) { t.Run("Expired-Policy", func(t *testing.T) { t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "deny" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -460,75 +780,118 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", false, "expired") // from "found" token }) - t.Run("Extend-Cache", func(t *testing.T) { + t.Run("Expired-Role", func(t *testing.T) { t.Parallel() - cached := false - delegate := &ACLResolverTestDelegate{ - enabled: true, - datacenter: "dc1", - legacy: false, - localTokens: false, - localPolicies: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !cached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - cached = true - return nil - } - return fmt.Errorf("Induced RPC Error") - }, - } - r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { - config.Config.ACLDownPolicy = "extend-cache" - config.Config.ACLTokenTTL = 0 - }) - - authz, err := r.ResolveToken("foo") - require.NoError(t, err) - require.NotNil(t, authz) - require.True(t, authz.NodeWrite("foo", nil)) - - requireIdentityCached(t, r, "foo", true, "cached") - - authz2, err := r.ResolveToken("foo") - require.NoError(t, err) - require.NotNil(t, authz2) - // testing pointer equality - these will be the same object because it is cached. - require.True(t, authz == authz2) - require.True(t, authz.NodeWrite("foo", nil)) - - requireIdentityCached(t, r, "foo", true, "still cached") - }) - - t.Run("Extend-Cache-Expired-Policy", func(t *testing.T) { - t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "deny" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // role cache expired - so we will fail to resolve that role and use the default policy only + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + require.False(t, authz == authz2) + require.False(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: true, + localRoles: true, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + }) + + authz, err := r.ResolveToken("found") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + requireIdentityCached(t, r, "found", true, "cached") + + authz2, err := r.ResolveToken("found") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: true, + localRoles: true, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + requireIdentityCached(t, r, "found-role", true, "still cached") + + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Expired-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "extend-cache" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -550,36 +913,56 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", true, "still cached") // from "found" token }) - t.Run("Async-Cache-Expired-Policy", func(t *testing.T) { + t.Run("Extend-Cache-Expired-Role", func(t *testing.T) { t.Parallel() - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: true, localPolicies: false, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any - // policies not in the response and emit an ACL not found for any not-found within the result set. - return nil - }, + localRoles: false, } + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // Will just use the policy cache + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + require.True(t, authz == authz2) + require.True(t, authz.NodeWrite("foo", nil)) + }) + + t.Run("Async-Cache-Expired-Policy", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any + // policies not in the response and emit an ACL not found for any not-found within the result set. + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(nil) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "async-cache" config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -613,45 +996,67 @@ func TestACLResolver_DownPolicy(t *testing.T) { requirePolicyCached(t, r, "dc2-key-wr", false, "no longer cached") // from "found" token }) - t.Run("Extend-Cache-Client", func(t *testing.T) { + t.Run("Async-Cache-Expired-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: true, + localPolicies: false, + localRoles: false, + } + // We don't need to return acl.ErrNotFound here but we could. The ACLResolver will search for any + // policies not in the response and emit an ACL not found for any not-found within the result set. + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(nil) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(nil) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "async-cache" + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.NodeWrite("foo", nil)) + + // The identity should have been cached so this should still be valid + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2) + require.True(t, authz.NodeWrite("foo", nil)) + + // the go routine spawned will eventually return with a authz that doesn't have the policy + retry.Run(t, func(t *retry.R) { + authz3, err := r.ResolveToken("found-role") + assert.NoError(t, err) + assert.NotNil(t, authz3) + assert.False(t, authz3.NodeWrite("foo", nil)) + }) + }) + + t.Run("Extend-Cache-Client-Policy", func(t *testing.T) { t.Parallel() - tokenCached := false - policyCached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: false, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !tokenCached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - tokenCached = true - return nil - } - return fmt.Errorf("Induced RPC Error") - }, - policyResolveFn: func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { - if !policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - - policyCached = true - return nil - } - - return fmt.Errorf("Induced RPC Error") - }, + localRoles: false, } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "extend-cache" config.Config.ACLTokenTTL = 0 config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 }) authz, err := r.ResolveToken("found") @@ -667,62 +1072,89 @@ func TestACLResolver_DownPolicy(t *testing.T) { require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. require.True(t, authz == authz2) + require.True(t, authz2.NodeWrite("foo", nil)) + }) + + t.Run("Extend-Cache-Client-Role", func(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: false, + localRoles: false, + } + delegate.tokenReadFn = delegate.defaultTokenReadFn(errRPC) + delegate.policyResolveFn = delegate.defaultPolicyResolveFn(errRPC) + delegate.roleResolveFn = delegate.defaultRoleResolveFn(errRPC) + + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLDownPolicy = "extend-cache" + config.Config.ACLTokenTTL = 0 + config.Config.ACLPolicyTTL = 0 + config.Config.ACLRoleTTL = 0 + }) + + authz, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz) require.True(t, authz.NodeWrite("foo", nil)) requirePolicyCached(t, r, "node-wr", true, "still cached") // from "found" token requirePolicyCached(t, r, "dc2-key-wr", true, "still cached") // from "found" token + + authz2, err := r.ResolveToken("found-role") + require.NoError(t, err) + require.NotNil(t, authz2) + // testing pointer equality - these will be the same object because it is cached. + require.True(t, authz == authz2, "\n[1]={%+v} != \n[2]={%+v}", authz, authz2) + require.True(t, authz2.NodeWrite("foo", nil)) }) t.Run("Async-Cache", func(t *testing.T) { t.Parallel() - cached := false delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - if !cached { - _, token, _ := testIdentityForToken("found") - reply.Token = token.(*structs.ACLToken) - cached = true - return nil - } - return acl.ErrNotFound - }, + localRoles: true, } + delegate.tokenReadFn = delegate.defaultTokenReadFn(acl.ErrNotFound) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLDownPolicy = "async-cache" config.Config.ACLTokenTTL = 0 }) - authz, err := r.ResolveToken("foo") + authz, err := r.ResolveToken("found") require.NoError(t, err) require.NotNil(t, authz) require.True(t, authz.NodeWrite("foo", nil)) - requireIdentityCached(t, r, "foo", true, "cached") + requireIdentityCached(t, r, "found", true, "cached") // The identity should have been cached so this should still be valid - authz2, err := r.ResolveToken("foo") + authz2, err := r.ResolveToken("found") require.NoError(t, err) require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. require.True(t, authz == authz2) - require.True(t, authz.NodeWrite("foo", nil)) + require.True(t, authz2.NodeWrite("foo", nil)) - requireIdentityCached(t, r, "foo", true, "cached") + requireIdentityCached(t, r, "found", true, "cached") // the go routine spawned will eventually return and this will be a not found error retry.Run(t, func(t *retry.R) { - authz3, err := r.ResolveToken("foo") + authz3, err := r.ResolveToken("found") assert.Error(t, err) assert.True(t, acl.IsErrNotFound(err)) assert.Nil(t, authz3) }) - requireIdentityCached(t, r, "foo", false, "no longer cached") + requireIdentityCached(t, r, "found", false, "no longer cached") }) t.Run("PolicyResolve-TokenNotFound", func(t *testing.T) { @@ -864,6 +1296,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } r := newTestACLResolver(t, delegate, nil) @@ -883,6 +1316,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -898,6 +1332,7 @@ func TestACLResolver_DatacenterScoping(t *testing.T) { }) } +// TODO(rb): replicate this sort of test but for roles func TestACLResolver_Client(t *testing.T) { t.Parallel() @@ -951,6 +1386,7 @@ func TestACLResolver_Client(t *testing.T) { r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { config.Config.ACLTokenTTL = 600 * time.Second config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond config.Config.ACLDownPolicy = "extend-cache" }) @@ -1039,6 +1475,7 @@ func TestACLResolver_Client(t *testing.T) { // being resolved concurrently config.Config.ACLTokenTTL = 0 * time.Second config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond config.Config.ACLDownPolicy = "extend-cache" }) @@ -1058,7 +1495,24 @@ func TestACLResolver_Client(t *testing.T) { }) } -func TestACLResolver_LocalTokensAndPolicies(t *testing.T) { +func TestACLResolver_Client_TokensPoliciesAndRoles(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: false, + localRoles: false, + } + delegate.tokenReadFn = delegate.plainTokenReadFn + delegate.policyResolveFn = delegate.plainPolicyResolveFn + delegate.roleResolveFn = delegate.plainRoleResolveFn + + testACLResolver_variousTokens(t, delegate) +} + +func TestACLResolver_LocalTokensPoliciesAndRoles(t *testing.T) { t.Parallel() delegate := &ACLResolverTestDelegate{ enabled: true, @@ -1066,85 +1520,60 @@ func TestACLResolver_LocalTokensAndPolicies(t *testing.T) { legacy: false, localTokens: true, localPolicies: true, + localRoles: true, // No need to provide any of the RPC callbacks } - r := newTestACLResolver(t, delegate, nil) - t.Run("Missing Identity", func(t *testing.T) { - authz, err := r.ResolveToken("doesn't exist") - require.Nil(t, authz) - require.Error(t, err) - require.True(t, acl.IsErrNotFound(err)) - }) - - t.Run("Missing Policy", func(t *testing.T) { - authz, err := r.ResolveToken("missing-policy") - require.NoError(t, err) - require.NotNil(t, authz) - require.True(t, authz.ACLRead()) - require.False(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("Normal", func(t *testing.T) { - authz, err := r.ResolveToken("found") - require.NotNil(t, authz) - require.NoError(t, err) - require.False(t, authz.ACLRead()) - require.True(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("Anonymous", func(t *testing.T) { - authz, err := r.ResolveToken("") - require.NotNil(t, authz) - require.NoError(t, err) - require.False(t, authz.ACLRead()) - require.True(t, authz.NodeWrite("foo", nil)) - }) - - t.Run("legacy-management", func(t *testing.T) { - authz, err := r.ResolveToken("legacy-management") - require.NotNil(t, authz) - require.NoError(t, err) - require.True(t, authz.ACLWrite()) - require.True(t, authz.KeyRead("foo")) - }) - - t.Run("legacy-client", func(t *testing.T) { - authz, err := r.ResolveToken("legacy-client") - require.NoError(t, err) - require.NotNil(t, authz) - require.False(t, authz.OperatorRead()) - require.True(t, authz.ServiceRead("foo")) - }) + testACLResolver_variousTokens(t, delegate) } -func TestACLResolver_LocalPolicies(t *testing.T) { +func TestACLResolver_LocalPoliciesAndRoles(t *testing.T) { t.Parallel() + delegate := &ACLResolverTestDelegate{ enabled: true, datacenter: "dc1", legacy: false, localTokens: false, localPolicies: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - _, token, err := testIdentityForToken(args.TokenID) - - if token != nil { - reply.Token = token.(*structs.ACLToken) - } - return err - }, + localRoles: true, } - r := newTestACLResolver(t, delegate, nil) + delegate.tokenReadFn = delegate.plainTokenReadFn - t.Run("Missing Identity", func(t *testing.T) { + testACLResolver_variousTokens(t, delegate) +} + +func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelegate) { + t.Helper() + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLTokenTTL = 600 * time.Second + config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond + config.Config.ACLDownPolicy = "extend-cache" + }) + reset := func() { + // prevent subtest bleedover + r.cache.Purge() + delegate.Reset() + } + + runTwiceAndReset := func(name string, f func(t *testing.T)) { + t.Helper() + defer reset() // reset the stateful resolve AND blow away the cache + + t.Run(name+" (no-cache)", f) + delegate.Reset() // allow the stateful resolve functions to reset + t.Run(name+" (cached)", f) + } + + runTwiceAndReset("Missing Identity", func(t *testing.T) { authz, err := r.ResolveToken("doesn't exist") require.Nil(t, authz) require.Error(t, err) require.True(t, acl.IsErrNotFound(err)) }) - t.Run("Missing Policy", func(t *testing.T) { + runTwiceAndReset("Missing Policy", func(t *testing.T) { authz, err := r.ResolveToken("missing-policy") require.NoError(t, err) require.NotNil(t, authz) @@ -1152,7 +1581,23 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal", func(t *testing.T) { + runTwiceAndReset("Missing Role", func(t *testing.T) { + authz, err := r.ResolveToken("missing-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + }) + + runTwiceAndReset("Missing Policy on Role", func(t *testing.T) { + authz, err := r.ResolveToken("missing-policy-on-role") + require.NoError(t, err) + require.NotNil(t, authz) + require.True(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + }) + + runTwiceAndReset("Normal with Policy", func(t *testing.T) { authz, err := r.ResolveToken("found") require.NotNil(t, authz) require.NoError(t, err) @@ -1160,7 +1605,58 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("Anonymous", func(t *testing.T) { + runTwiceAndReset("Normal with Role", func(t *testing.T) { + authz, err := r.ResolveToken("found-role") + require.NotNil(t, authz) + require.NoError(t, err) + require.False(t, authz.ACLRead()) + require.True(t, authz.NodeWrite("foo", nil)) + }) + + runTwiceAndReset("Normal with Policy and Role", func(t *testing.T) { + authz, err := r.ResolveToken("found-policy-and-role") + require.NotNil(t, authz) + require.NoError(t, err) + require.False(t, authz.ACLRead()) + require.True(t, authz.NodeWrite("foo", nil)) + require.True(t, authz.ServiceRead("bar")) + }) + + runTwiceAndReset("Synthetic Policies Independently Cache", func(t *testing.T) { + // We resolve both of these tokens in the same cache session + // to verify that the keys for caching synthetic policies don't bleed + // over between each other. + { + authz, err := r.ResolveToken("found-synthetic-policy-1") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service2", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service1", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + { + authz, err := r.ResolveToken("found-synthetic-policy-2") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service1", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service2", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + }) + + runTwiceAndReset("Anonymous", func(t *testing.T) { authz, err := r.ResolveToken("") require.NotNil(t, authz) require.NoError(t, err) @@ -1168,7 +1664,7 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("legacy-management", func(t *testing.T) { + runTwiceAndReset("legacy-management", func(t *testing.T) { authz, err := r.ResolveToken("legacy-management") require.NotNil(t, authz) require.NoError(t, err) @@ -1176,7 +1672,7 @@ func TestACLResolver_LocalPolicies(t *testing.T) { require.True(t, authz.KeyRead("foo")) }) - t.Run("legacy-client", func(t *testing.T) { + runTwiceAndReset("legacy-client", func(t *testing.T) { authz, err := r.ResolveToken("legacy-client") require.NoError(t, err) require.NotNil(t, authz) @@ -1214,7 +1710,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, nil) @@ -1263,7 +1759,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -1314,7 +1810,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -1366,7 +1862,7 @@ func TestACLResolver_Legacy(t *testing.T) { cached = true return nil } - return fmt.Errorf("Induced RPC Error") + return errRPC }, } r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { @@ -2861,3 +3357,92 @@ service "service" { t.Fatalf("err: %v", err) } } + +func TestDedupeServiceIdentities(t *testing.T) { + srvid := func(name string, datacenters ...string) *structs.ACLServiceIdentity { + return &structs.ACLServiceIdentity{ + ServiceName: name, + Datacenters: datacenters, + } + } + + tests := []struct { + name string + in []*structs.ACLServiceIdentity + expect []*structs.ACLServiceIdentity + }{ + { + name: "empty", + in: nil, + expect: nil, + }, + { + name: "one", + in: []*structs.ACLServiceIdentity{ + srvid("foo"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("foo"), + }, + }, + { + name: "just names", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ"), + srvid("fooA"), + srvid("fooY"), + srvid("fooB"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + }, + { + name: "just names with dupes", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ"), + srvid("fooA"), + srvid("fooY"), + srvid("fooB"), + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY"), + srvid("fooZ"), + }, + }, + { + name: "names with dupes and datacenters", + in: []*structs.ACLServiceIdentity{ + srvid("fooZ", "dc2", "dc4"), + srvid("fooA"), + srvid("fooY", "dc1"), + srvid("fooB"), + srvid("fooA", "dc9", "dc8"), + srvid("fooB"), + srvid("fooY", "dc1"), + srvid("fooZ", "dc3", "dc4"), + }, + expect: []*structs.ACLServiceIdentity{ + srvid("fooA"), + srvid("fooB"), + srvid("fooY", "dc1"), + srvid("fooZ", "dc2", "dc3", "dc4"), + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := dedupeServiceIdentities(test.in) + require.ElementsMatch(t, test.expect, got) + }) + } +} diff --git a/agent/consul/acl_token_exp.go b/agent/consul/acl_token_exp.go new file mode 100644 index 0000000000..f1d5cb52d5 --- /dev/null +++ b/agent/consul/acl_token_exp.go @@ -0,0 +1,144 @@ +package consul + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/consul/agent/structs" + "golang.org/x/time/rate" +) + +func (s *Server) startACLTokenReaping() { + s.aclTokenReapLock.Lock() + defer s.aclTokenReapLock.Unlock() + + if s.aclTokenReapEnabled { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + s.aclTokenReapCancel = cancel + + // Do a quick check for config settings that would imply the goroutine + // below will just spin forever. + // + // We can only check the config settings here that cannot change without a + // restart, so we omit the check for a non-empty replication token as that + // can be changed at runtime. + if !s.InACLDatacenter() && !s.config.ACLTokenReplication { + return + } + + go func() { + limiter := rate.NewLimiter(aclTokenReapingRateLimit, aclTokenReapingBurst) + + for { + if err := limiter.Wait(ctx); err != nil { + return + } + + if s.LocalTokensEnabled() { + if _, err := s.reapExpiredLocalACLTokens(); err != nil { + s.logger.Printf("[ERR] acl: error reaping expired local ACL tokens: %v", err) + } + } + if s.InACLDatacenter() { + if _, err := s.reapExpiredGlobalACLTokens(); err != nil { + s.logger.Printf("[ERR] acl: error reaping expired global ACL tokens: %v", err) + } + } + } + }() + + s.aclTokenReapEnabled = true +} + +func (s *Server) stopACLTokenReaping() { + s.aclTokenReapLock.Lock() + defer s.aclTokenReapLock.Unlock() + + if !s.aclTokenReapEnabled { + return + } + + s.aclTokenReapCancel() + s.aclTokenReapCancel = nil + s.aclTokenReapEnabled = false +} + +func (s *Server) reapExpiredGlobalACLTokens() (int, error) { + return s.reapExpiredACLTokens(false, true) +} +func (s *Server) reapExpiredLocalACLTokens() (int, error) { + return s.reapExpiredACLTokens(true, false) +} +func (s *Server) reapExpiredACLTokens(local, global bool) (int, error) { + if !s.ACLsEnabled() { + return 0, nil + } + if s.UseLegacyACLs() { + return 0, nil + } + if local == global { + return 0, fmt.Errorf("cannot reap both local and global tokens in the same request") + } + + locality := localityName(local) + + minExpiredTime, err := s.fsm.State().ACLTokenMinExpirationTime(local) + if err != nil { + return 0, err + } + + now := time.Now() + + if minExpiredTime.After(now) { + return 0, nil // nothing to do + } + + tokens, _, err := s.fsm.State().ACLTokenListExpired(local, now, aclBatchDeleteSize) + if err != nil { + return 0, err + } + + if len(tokens) == 0 { + return 0, nil + } + + var ( + secretIDs []string + req structs.ACLTokenBatchDeleteRequest + ) + for _, token := range tokens { + if token.Local != local { + return 0, fmt.Errorf("expired index for local=%v returned a mismatched token with local=%v: %s", local, token.Local, token.AccessorID) + } + req.TokenIDs = append(req.TokenIDs, token.AccessorID) + secretIDs = append(secretIDs, token.SecretID) + } + + s.logger.Printf("[INFO] acl: deleting %d expired %s tokens", len(req.TokenIDs), locality) + resp, err := s.raftApply(structs.ACLTokenDeleteRequestType, &req) + if err != nil { + return 0, fmt.Errorf("Failed to apply token expiration deletions: %v", err) + } + + // Purge the identities from the cache + for _, secretID := range secretIDs { + s.acls.cache.RemoveIdentity(secretID) + } + + if respErr, ok := resp.(error); ok { + return 0, respErr + } + + return len(req.TokenIDs), nil +} + +func localityName(local bool) string { + if local { + return "local" + } + return "global" +} diff --git a/agent/consul/acl_token_exp_test.go b/agent/consul/acl_token_exp_test.go new file mode 100644 index 0000000000..a851b4dc33 --- /dev/null +++ b/agent/consul/acl_token_exp_test.go @@ -0,0 +1,219 @@ +package consul + +import ( + "os" + "testing" + "time" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/testrpc" + "github.com/stretchr/testify/require" +) + +func TestACLTokenReap_Primary(t *testing.T) { + t.Parallel() + + t.Run("global", func(t *testing.T) { + t.Parallel() + testACLTokenReap_Primary(t, false, true) + }) + t.Run("local", func(t *testing.T) { + t.Parallel() + testACLTokenReap_Primary(t, true, false) + }) +} + +func testACLTokenReap_Primary(t *testing.T, local, global bool) { + // ------------------------------------------- + // A word of caution when testing reapExpiredACLTokens(): + // + // The underlying memdb index used for reaping has a minimum granularity of + // 1 second as it delegates to `time.Unix()`. This test will have to be + // deliberately slow to allow for necessary sleeps. If you try to make it + // operate faster (using expiration ttls of milliseconds) it will be flaky. + // ------------------------------------------- + + t.Helper() + require.NotEqual(t, local, global) + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 8 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + codec := rpcClient(t, s1) + defer codec.Close() + + acl := ACL{srv: s1} + + masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") + require.NoError(t, err) + + listTokens := func() (localTokens, globalTokens []string, err error) { + req := structs.ACLTokenListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + var res structs.ACLTokenListResponse + err = acl.TokenList(&req, &res) + if err != nil { + return nil, nil, err + } + + for _, tok := range res.Tokens { + if tok.Local { + localTokens = append(localTokens, tok.AccessorID) + } else { + globalTokens = append(globalTokens, tok.AccessorID) + } + } + + return localTokens, globalTokens, nil + } + + requireTokenMatch := func(t *testing.T, expect []string) { + t.Helper() + + var expectLocal, expectGlobal []string + // The master token and the anonymous token are always going to be + // present and global. + expectGlobal = append(expectGlobal, masterTokenAccessorID) + expectGlobal = append(expectGlobal, structs.ACLTokenAnonymousID) + + if local { + expectLocal = append(expectLocal, expect...) + } else { + expectGlobal = append(expectGlobal, expect...) + } + + localTokens, globalTokens, err := listTokens() + require.NoError(t, err) + require.ElementsMatch(t, expectLocal, localTokens) + require.ElementsMatch(t, expectGlobal, globalTokens) + } + + // initial sanity check + requireTokenMatch(t, []string{}) + + t.Run("no tokens", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{}) + }) + + // 2 normal + token1, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + token2, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + }) + + t.Run("only normal tokens", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + }) + }) + + // 2 expiring + token3, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 1 * time.Second + token.Local = local + }) + require.NoError(t, err) + token4, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.ExpirationTTL = 5 * time.Second + token.Local = local + }) + require.NoError(t, err) + + // 2 more normal + token5, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + token6, err := upsertTestToken(codec, "root", "dc1", func(token *structs.ACLToken) { + token.Local = local + }) + require.NoError(t, err) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + + t.Run("mixed but nothing expired yet", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 0, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) + + time.Sleep(token3.ExpirationTime.Sub(time.Now()) + 10*time.Millisecond) + + t.Run("one should be reaped", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 1, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + // token3.AccessorID, + token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) + + time.Sleep(token4.ExpirationTime.Sub(time.Now()) + 10*time.Millisecond) + + t.Run("two should be reaped", func(t *testing.T) { + n, err := s1.reapExpiredACLTokens(local, global) + require.NoError(t, err) + require.Equal(t, 1, n) + + requireTokenMatch(t, []string{ + token1.AccessorID, + token2.AccessorID, + // token3.AccessorID, + // token4.AccessorID, + token5.AccessorID, + token6.AccessorID, + }) + }) +} diff --git a/agent/consul/authmethod/authmethods.go b/agent/consul/authmethod/authmethods.go new file mode 100644 index 0000000000..8fd477d0f2 --- /dev/null +++ b/agent/consul/authmethod/authmethods.go @@ -0,0 +1,112 @@ +package authmethod + +import ( + "fmt" + "sort" + "sync" + + "github.com/hashicorp/consul/agent/structs" + "github.com/mitchellh/mapstructure" +) + +type ValidatorFactory func(method *structs.ACLAuthMethod) (Validator, error) + +type Validator interface { + // Name returns the name of the auth method backing this validator. + Name() string + + // ValidateLogin takes raw user-provided auth method metadata and ensures + // it is sane, provably correct, and currently valid. Relevant identifying + // data is extracted and returned for immediate use by the role binding + // process. + // + // Depending upon the method, it may make sense to use these calls to + // continue to extend the life of the underlying token. + // + // Returns auth method specific metadata suitable for the Role Binding + // process. + ValidateLogin(loginToken string) (map[string]string, error) + + // AvailableFields returns a slice of all fields that are returned as a + // result of ValidateLogin. These are valid fields for use in any + // BindingRule tied to this auth method. + AvailableFields() []string + + // MakeFieldMapSelectable converts a field map as returned by ValidateLogin + // into a structure suitable for selection with a binding rule. + MakeFieldMapSelectable(fieldMap map[string]string) interface{} +} + +var ( + typesMu sync.RWMutex + types = make(map[string]ValidatorFactory) +) + +// Register makes an auth method with the given type available for use. If +// Register is called twice with the same name or if validator is nil, it +// panics. +func Register(name string, factory ValidatorFactory) { + typesMu.Lock() + defer typesMu.Unlock() + if factory == nil { + panic("authmethod: Register factory is nil for type " + name) + } + if _, dup := types[name]; dup { + panic("authmethod: Register called twice for type " + name) + } + types[name] = factory +} + +func IsRegisteredType(typeName string) bool { + typesMu.RLock() + _, ok := types[typeName] + typesMu.RUnlock() + return ok +} + +// NewValidator instantiates a new Validator for the given auth method +// configuration. If no auth method is registered with the provided type an +// error is returned. +func NewValidator(method *structs.ACLAuthMethod) (Validator, error) { + typesMu.RLock() + factory, ok := types[method.Type] + typesMu.RUnlock() + + if !ok { + return nil, fmt.Errorf("no auth method registered with type: %s", method.Type) + } + + return factory(method) +} + +// Types returns a sorted list of the names of the registered types. +func Types() []string { + typesMu.RLock() + defer typesMu.RUnlock() + var list []string + for name := range types { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// ParseConfig parses the config block for a auth method. +func ParseConfig(rawConfig map[string]interface{}, out interface{}) error { + decodeConf := &mapstructure.DecoderConfig{ + Result: out, + WeaklyTypedInput: true, + ErrorUnused: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return err + } + + if err := decoder.Decode(rawConfig); err != nil { + return fmt.Errorf("error decoding config: %s", err) + } + + return nil +} diff --git a/agent/consul/authmethod/kubeauth/k8s.go b/agent/consul/authmethod/kubeauth/k8s.go new file mode 100644 index 0000000000..88c4b32e3d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s.go @@ -0,0 +1,202 @@ +package kubeauth + +import ( + "errors" + "fmt" + "strings" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + cleanhttp "github.com/hashicorp/go-cleanhttp" + "gopkg.in/square/go-jose.v2/jwt" + authv1 "k8s.io/api/authentication/v1" + client_metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8s "k8s.io/client-go/kubernetes" + client_authv1 "k8s.io/client-go/kubernetes/typed/authentication/v1" + client_corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + client_rest "k8s.io/client-go/rest" + cert "k8s.io/client-go/util/cert" +) + +func init() { + // register this as an available auth method type + authmethod.Register("kubernetes", func(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + v, err := NewValidator(method) + if err != nil { + return nil, err + } + return v, nil + }) +} + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" + + serviceAccountServiceNameAnnotation = "consul.hashicorp.com/service-name" +) + +type Config struct { + // Host must be a host string, a host:port pair, or a URL to the base of + // the Kubernetes API server. + Host string `json:",omitempty"` + + // PEM encoded CA cert for use by the TLS client used to talk with the + // Kubernetes API. Every line must end with a newline: \n + CACert string `json:",omitempty"` + + // A service account JWT used to access the TokenReview API to validate + // other JWTs during login. It also must be able to read ServiceAccount + // annotations. + ServiceAccountJWT string `json:",omitempty"` +} + +// Validator is the wrapper around the relevant portions of the Kubernetes API +// that also conforms to the authmethod.Validator interface. +type Validator struct { + name string + config *Config + saGetter client_corev1.ServiceAccountsGetter + trGetter client_authv1.TokenReviewsGetter +} + +func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) { + if method.Type != "kubernetes" { + return nil, fmt.Errorf("%q is not a kubernetes auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.Host == "" { + return nil, fmt.Errorf("Config.Host is required") + } + + if config.CACert == "" { + return nil, fmt.Errorf("Config.CACert is required") + } + if _, err := cert.ParseCertsPEM([]byte(config.CACert)); err != nil { + return nil, fmt.Errorf("error parsing kubernetes ca cert: %v", err) + } + + // This is the bearer token we give the apiserver to use the API. + if config.ServiceAccountJWT == "" { + return nil, fmt.Errorf("Config.ServiceAccountJWT is required") + } + if _, err := jwt.ParseSigned(config.ServiceAccountJWT); err != nil { + return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err) + } + + transport := cleanhttp.DefaultTransport() + client, err := k8s.NewForConfig(&client_rest.Config{ + Host: config.Host, + BearerToken: config.ServiceAccountJWT, + Dial: transport.DialContext, + TLSClientConfig: client_rest.TLSClientConfig{ + CAData: []byte(config.CACert), + }, + ContentConfig: client_rest.ContentConfig{ + ContentType: "application/json", + }, + }) + if err != nil { + return nil, err + } + + return &Validator{ + name: method.Name, + config: &config, + saGetter: client.CoreV1(), + trGetter: client.AuthenticationV1(), + }, nil +} + +func (v *Validator) Name() string { return v.name } + +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + if _, err := jwt.ParseSigned(loginToken); err != nil { + return nil, fmt.Errorf("failed to parse and validate JWT: %v", err) + } + + // Check TokenReview for the bulk of the work. + trResp, err := v.trGetter.TokenReviews().Create(&authv1.TokenReview{ + Spec: authv1.TokenReviewSpec{ + Token: loginToken, + }, + }) + + if err != nil { + return nil, err + } else if trResp.Status.Error != "" { + return nil, fmt.Errorf("lookup failed: %s", trResp.Status.Error) + } + + if !trResp.Status.Authenticated { + return nil, errors.New("lookup failed: service account jwt not valid") + } + + // The username is of format: system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT) + parts := strings.Split(trResp.Status.User.Username, ":") + if len(parts) != 4 { + return nil, errors.New("lookup failed: unexpected username format") + } + + // Validate the user that comes back from token review is a service account + if parts[0] != "system" || parts[1] != "serviceaccount" { + return nil, errors.New("lookup failed: username returned is not a service account") + } + + var ( + saNamespace = parts[2] + saName = parts[3] + saUID = string(trResp.Status.User.UID) + ) + + // Check to see if there is an override name on the ServiceAccount object. + sa, err := v.saGetter.ServiceAccounts(saNamespace).Get(saName, client_metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("annotation lookup failed: %v", err) + } + + annotations := sa.GetObjectMeta().GetAnnotations() + if serviceNameOverride, ok := annotations[serviceAccountServiceNameAnnotation]; ok { + saName = serviceNameOverride + } + + return map[string]string{ + serviceAccountNamespaceField: saNamespace, + serviceAccountNameField: saName, + serviceAccountUIDField: saUID, + }, nil +} + +func (p *Validator) AvailableFields() []string { + return []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, + } +} + +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &k8sFieldDetails{ + ServiceAccount: k8sFieldDetailsServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type k8sFieldDetails struct { + ServiceAccount k8sFieldDetailsServiceAccount `bexpr:"serviceaccount"` +} + +type k8sFieldDetailsServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/authmethod/kubeauth/k8s_test.go b/agent/consul/authmethod/kubeauth/k8s_test.go new file mode 100644 index 0000000000..614538c40e --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s_test.go @@ -0,0 +1,144 @@ +package kubeauth + +import ( + "testing" + + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestValidateLogin(t *testing.T) { + testSrv := StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + "ServiceAccountJWT": goodJWT_A, + }, + } + validator, err := NewValidator(method) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + _, err := validator.ValidateLogin("invalid") + require.Error(t, err) + }) + + t.Run("valid bearer token", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "demo", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token with annotation", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "alternate-name", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) +} + +func TestNewValidator(t *testing.T) { + ca := connect.TestCA(t, nil) + + type AM = *structs.ACLAuthMethod + + makeAuthMethod := func(f func(method AM)) *structs.ACLAuthMethod { + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": "https://abc:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": goodJWT_A, + }, + } + if f != nil { + f(method) + } + return method + } + + for _, test := range []struct { + name string + method *structs.ACLAuthMethod + ok bool + }{ + // bad + {"wrong type", makeAuthMethod(func(method AM) { + method.Type = "invalid" + }), false}, + {"extra config", makeAuthMethod(func(method AM) { + method.Config["extra"] = "config" + }), false}, + {"wrong type of config", makeAuthMethod(func(method AM) { + method.Config["Host"] = []int{12345} + }), false}, + {"missing host", makeAuthMethod(func(method AM) { + delete(method.Config, "Host") + }), false}, + {"missing ca cert", makeAuthMethod(func(method AM) { + delete(method.Config, "CACert") + }), false}, + {"invalid ca cert", makeAuthMethod(func(method AM) { + method.Config["CACert"] = "invalid" + }), false}, + {"invalid jwt", makeAuthMethod(func(method AM) { + method.Config["ServiceAccountJWT"] = "invalid" + }), false}, + {"garbage host", makeAuthMethod(func(method AM) { + method.Config["Host"] = "://:12345" + }), false}, + // good + {"normal", makeAuthMethod(nil), true}, + } { + t.Run(test.name, func(t *testing.T) { + v, err := NewValidator(test.method) + if test.ok { + require.NoError(t, err) + require.NotNil(t, v) + } else { + require.NotNil(t, err) + require.Nil(t, v) + } + }) + } +} + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/authmethod/kubeauth/testing.go b/agent/consul/authmethod/kubeauth/testing.go new file mode 100644 index 0000000000..7e6340dd9d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/testing.go @@ -0,0 +1,532 @@ +package kubeauth + +import ( + "bytes" + "encoding/json" + "encoding/pem" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + authv1 "k8s.io/api/authentication/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// TestAPIServer is a way to mock the Kubernetes API server as it is used by +// the consul kubernetes auth method. +// +// - POST /apis/authentication.k8s.io/v1/tokenreviews +// - GET /api/v1/namespaces//serviceaccounts/ +// +type TestAPIServer struct { + t *testing.T + srv *httptest.Server + caCert string + + mu sync.Mutex + authorizedJWT string // token review and sa read + allowedServiceAccountJWT string // general service account + replyStatus *authv1.TokenReview // general service account + replyRead *corev1.ServiceAccount // general service account +} + +// StartTestAPIServer creates a disposable TestAPIServer and binds it to a +// random free port. +func StartTestAPIServer(t *testing.T) *TestAPIServer { + s := &TestAPIServer{t: t} + + s.srv = httptest.NewTLSServer(s) + + bs := s.srv.TLS.Certificates[0].Certificate[0] + + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})) + s.caCert = buf.String() + + return s +} + +// AuthorizeJWT whitelists the given JWT as able to use the API server. +func (s *TestAPIServer) AuthorizeJWT(jwt string) { + s.mu.Lock() + defer s.mu.Unlock() + + s.authorizedJWT = jwt +} + +// SetAllowedServiceAccount configures the singular known Service Account +// installed in this API server. If any of namespace/name/uid/jwt are empty +// it removes anything previously configured. +// +// It is up to the caller to ensure that the provided JWT matches the other +// data. +func (s *TestAPIServer) SetAllowedServiceAccount( + namespace, name, uid, overrideAnnotation, jwt string, +) { + s.mu.Lock() + defer s.mu.Unlock() + + if namespace == "" || name == "" || uid == "" || jwt == "" { + s.allowedServiceAccountJWT = "" + s.replyStatus = nil + s.replyRead = nil + return + } + + s.allowedServiceAccountJWT = jwt + s.replyRead = createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt) + s.replyStatus = createTokenReviewFound(namespace, name, uid, jwt) +} + +// Stop stops the running TestAPIServer. +func (s *TestAPIServer) Stop() { + s.srv.Close() +} + +// Addr returns the current base URL for the running webserver. +func (s *TestAPIServer) Addr() string { return s.srv.URL } + +// CACert returns the pem-encoded CA certificate used by the HTTPS server. +func (s *TestAPIServer) CACert() string { return s.caCert } + +var readServiceAccountPathRE = regexp.MustCompile("^/api/v1/namespaces/([^/]+)/serviceaccounts/([^/]+)$") + +func (s *TestAPIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + w.Header().Set("content-type", "application/json") + + if req.URL.Path == "/apis/authentication.k8s.io/v1/tokenreviews" { + s.handleTokenReview(w, req) + return + } + + if m := readServiceAccountPathRE.FindStringSubmatch(req.URL.Path); m != nil { + namespace, err := url.QueryUnescape(m[1]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + name, err := url.QueryUnescape(m[2]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + s.handleReadServiceAccount(namespace, name, w, req) + return + } + + w.WriteHeader(http.StatusNotFound) +} + +func writeJSON(w http.ResponseWriter, out interface{}) error { + enc := json.NewEncoder(w) + return enc.Encode(out) +} + +func (s *TestAPIServer) handleTokenReview(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if auth, anon := s.isAuthenticated(req); !auth { + var out interface{} + if anon { + out = createTokenReviewForbidden_NoAuthz() + } else { + out = createTokenReviewForbidden("default", "fake-account") + } + + w.WriteHeader(http.StatusForbidden) + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + if req.Body == nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer req.Body.Close() + + b, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + var trReq authv1.TokenReview + if err := json.Unmarshal(b, &trReq); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + reviewingJWT := trReq.Spec.Token + + var out interface{} + if s.replyStatus == nil || reviewingJWT != s.allowedServiceAccountJWT { + out = createTokenReviewNotFound(reviewingJWT) + } else { + out = s.replyStatus + } + w.WriteHeader(http.StatusCreated) + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) handleReadServiceAccount( + namespace, name string, + w http.ResponseWriter, + req *http.Request, +) { + if req.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var out interface{} + if auth, anon := s.isAuthenticated(req); !auth { + if anon { + out = createReadServiceAccountForbidden_NoAuthz() + } else { + out = createReadServiceAccountForbidden(namespace, name) + } + w.WriteHeader(http.StatusForbidden) + } else if s.replyRead == nil { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else if s.replyRead.Namespace != namespace || s.replyRead.Name != name { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else { + out = s.replyRead + w.WriteHeader(http.StatusOK) + } + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) isAuthenticated(req *http.Request) (auth, anonymous bool) { + authz := req.Header.Get("Authorization") + if !strings.HasPrefix(authz, "Bearer ") { + return false, true + } + jwt := strings.TrimPrefix(authz, "Bearer ") + + return s.authorizedJWT == jwt, false +} + +func createTokenReviewForbidden_NoAuthz() *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:default:admin\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewNotFound(jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImZha2UtdG9rZW4tano2YnYiLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZmFrZSIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjgxYTY1Mjg2LTU3YzEtMTFlOS1iYzJhLTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmZha2UifQ.DqjUXe34SzCP4NCwbhqV9EuksfzmTSLhJzkE_URyufeGJDn-Gw0_JS-_KmxZSdAO0XXNzB1tJNM1NCVW-V6YbThnPUw5WY4V2J6U1W72c2dzNBx_ipBxGBZ632ZnpViIRu6tL2guT36lWa8YnMDF_OY8sHhl_3kJ6MRxNxY41vAuf45mohi3gri46Kpzc3pf1g6PJ-0oogvUsZ2nBFv1mIdciGBV0zejMKc5Bnxur1L-hEQ9EgZrJ7o0yQRCWYgam_yo_M38EsB8b-suTzQJMA-pRgApOb9dHIV6YAE_b3g_pGkJjrPYzV4IJC1CiPfdz1SAjm7e0ARXtZmaoPltjQ" + }, + "status": { + "user": {}, + "error": "[invalid bearer token, Token has been invalidated]" + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + User: authv1.UserInfo{}, + Error: "[invalid bearer token, Token has been invalidated]", + }, + } +} + +func createTokenReviewFound(namespace, name, uid, jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4tbTljdm4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjlmZjUxZmY0LTU1N2UtMTFlOS05Njg3LTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.UJEphtrN261gy9WCl4ZKjm2PRDLDkc3Xg9VcDGfzyroOqFQ6sog5dVAb9voc5Nc0-H5b1yGwxDViEMucwKvZpA5pi7VEx_OskK-KTWXSmafM0Xg_AvzpU9Ed5TSRno-OhXaAraxdjXoC4myh1ay2DMeHUusJg_ibqcYJrWx-6MO1bH_ObORtAKhoST_8fzkqNAlZmsQ87FinQvYN5mzDXYukl-eeRdBgQUBkWvEb-Ju6cc0-QE4sUQ4IH_fs0fUyX_xc0om0SZGWLP909FTz4V8LxV8kr6L7irxROiS1jn3Fvyc9ur1PamVf3JOPPrOyfmKbaGRiWJM32b3buQw7cg" + }, + "status": { + "authenticated": true, + "user": { + "username": "system:serviceaccount:default:demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "groups": [ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated" + ] + } + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + Authenticated: true, + User: authv1.UserInfo{ + Username: "system:serviceaccount:" + namespace + ":" + name, + UID: uid, + Groups: []string{ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated", + }, + }, + }, + } +} + +func createReadServiceAccountForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:serviceaccount:default:admin\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \""+namespace+"\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 403, + ) +} + +func createReadServiceAccountForbidden_NoAuthz() *metav1.Status { + // missing bearer token header 403 + /* + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \"PLACEHOLDER\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: "PLACEHOLDER", + }, + 403, + ) +} + +func createReadServiceAccountNotFound(namespace, name string) *metav1.Status { + /* + STATUS: 404 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" not found", + "reason": "NotFound", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 404 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" not found", + metav1.StatusReasonNotFound, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 404, + ) +} + +func createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt string) *corev1.ServiceAccount { + /* + STATUS: 200 + { + "kind": "ServiceAccount", + "apiVersion": "v1", + "metadata": { + "name": "demo", + "namespace": "default", + "selfLink": "/api/v1/namespaces/default/serviceaccounts/demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "resourceVersion": "2101", + "creationTimestamp": "2019-04-02T19:36:34Z", + "annotations": { + "consul.hashicorp.com/service-name": "actual", + "kubectl.kubernetes.io/last-applied-configuration": "{\"apiVersion\":\"v1\",\"kind\":\"ServiceAccount\",\"metadata\":{\"annotations\":{\"consul.hashicorp.com/service-name\":\"actual\"},\"name\":\"demo\",\"namespace\":\"default\"}}\n" + } + }, + "secrets": [ + { + "name": "demo-token-m9cvn" + } + ] + } + */ + sa := &corev1.ServiceAccount{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceAccount", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + SelfLink: "/api/v1/namespaces/" + namespace + "/serviceaccounts/" + name, + UID: types.UID(uid), + ResourceVersion: "123", + CreationTimestamp: metav1.Time{Time: time.Now()}, + }, + Secrets: []corev1.ObjectReference{ + corev1.ObjectReference{ + Name: name + "-token-m9cvn", + }, + }, + } + if overrideAnnotation != "" { + sa.ObjectMeta.Annotations = map[string]string{ + "consul.hashicorp.com/service-name": overrideAnnotation, + } + } + + return sa +} + +func createStatus(status, message string, reason metav1.StatusReason, details *metav1.StatusDetails, code int32) *metav1.Status { + return &metav1.Status{ + TypeMeta: metav1.TypeMeta{ + Kind: "Status", + APIVersion: "v1", + }, + ListMeta: metav1.ListMeta{}, + Status: status, + Message: message, + Reason: reason, + Details: details, + Code: code, + } +} diff --git a/agent/consul/authmethod/testauth/testing.go b/agent/consul/authmethod/testauth/testing.go new file mode 100644 index 0000000000..638450d94d --- /dev/null +++ b/agent/consul/authmethod/testauth/testing.go @@ -0,0 +1,166 @@ +package testauth + +import ( + "fmt" + "sync" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-uuid" +) + +func init() { + authmethod.Register("testing", newValidator) +} + +var ( + tokenDatabaseMu sync.Mutex + tokenDatabase map[string]map[string]map[string]string // session => token => fieldmap +) + +func StartSession() string { + sessionID, err := uuid.GenerateUUID() + if err != nil { + panic(err) + } + return sessionID +} + +func ResetSession(sessionID string) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase != nil { + delete(tokenDatabase, sessionID) + } +} + +func InstallSessionToken(sessionID string, token string, namespace, name, uid string) { + fields := map[string]string{ + serviceAccountNamespaceField: namespace, + serviceAccountNameField: name, + serviceAccountUIDField: uid, + } + + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + tokenDatabase = make(map[string]map[string]map[string]string) + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + sdb = make(map[string]map[string]string) + tokenDatabase[sessionID] = sdb + } + sdb[token] = fields +} + +func GetSessionToken(sessionID string, token string) (map[string]string, bool) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + return nil, false + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + return nil, false + } + fields, ok := sdb[token] + if !ok { + return nil, false + } + + fmCopy := make(map[string]string) + for k, v := range fields { + fmCopy[k] = v + } + + return fmCopy, true +} + +type Config struct { + SessionID string // unique identifier for this set of tokens in the database +} + +func newValidator(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if method.Type != "testing" { + return nil, fmt.Errorf("%q is not a testing auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.SessionID == "" { + // If you don't explicitly create one, we create a random one but you + // won't have access to it. Useful if you are testing everything EXCEPT + // ValidateToken(). + config.SessionID = StartSession() + } + + return &Validator{ + name: method.Name, + config: &config, + }, nil +} + +type Validator struct { + name string + config *Config +} + +func (v *Validator) Name() string { return v.name } + +// ValidateLogin takes raw user-provided auth method metadata and ensures it is +// sane, provably correct, and currently valid. Relevant identifying data is +// extracted and returned for immediate use by the role binding process. +// +// Depending upon the method, it may make sense to use these calls to continue +// to extend the life of the underlying token. +// +// Returns auth method specific metadata suitable for the Role Binding process. +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + fields, valid := GetSessionToken(v.config.SessionID, loginToken) + if !valid { + return nil, acl.ErrNotFound + } + + return fields, nil +} + +func (v *Validator) AvailableFields() []string { return availableFields } + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" +) + +var availableFields = []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, +} + +// MakeFieldMapSelectable converts a field map as returned by ValidateLogin +// into a structure suitable for selection with a binding rule. +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &selectableVars{ + ServiceAccount: selectableServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type selectableVars struct { + ServiceAccount selectableServiceAccount `bexpr:"serviceaccount"` +} + +type selectableServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/config.go b/agent/consul/config.go index baf9e63a29..d7e92943bb 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -243,6 +243,11 @@ type Config struct { // a substantial cost. ACLPolicyTTL time.Duration + // ACLRoleTTL controls the time-to-live of cached ACL roles. + // It can be set to zero to disable caching, but this adds + // a substantial cost. + ACLRoleTTL time.Duration + // ACLDisabledTTL is the time between checking if ACLs should be // enabled. This ACLDisabledTTL time.Duration @@ -313,6 +318,16 @@ type Config struct { // Minimum Session TTL SessionTTLMin time.Duration + // maxTokenExpirationDuration is the maximum difference allowed between + // ACLToken CreateTime and ExpirationTime values if ExpirationTime is set + // on a token. + ACLTokenMaxExpirationTTL time.Duration + + // ACLTokenMinExpirationTTL is the minimum difference allowed between + // ACLToken CreateTime and ExpirationTime values if ExpirationTime is set + // on a token. + ACLTokenMinExpirationTTL time.Duration + // ServerUp callback can be used to trigger a notification that // a Consul server is now up and known about. ServerUp func() @@ -460,6 +475,7 @@ func DefaultConfig() *Config { SerfFloodInterval: 60 * time.Second, ReconcileInterval: 60 * time.Second, ProtocolVersion: ProtocolVersion2Compatible, + ACLRoleTTL: 30 * time.Second, ACLPolicyTTL: 30 * time.Second, ACLTokenTTL: 30 * time.Second, ACLDefaultPolicy: "allow", @@ -473,6 +489,8 @@ func DefaultConfig() *Config { TombstoneTTL: 15 * time.Minute, TombstoneTTLGranularity: 30 * time.Second, SessionTTLMin: 10 * time.Second, + ACLTokenMinExpirationTTL: 1 * time.Minute, + ACLTokenMaxExpirationTTL: 24 * time.Hour, // These are tuned to provide a total throughput of 128 updates // per second. If you update these, you should update the client- diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 1e3651cba4..f093aa0abd 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -4,7 +4,7 @@ import ( "fmt" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" ) @@ -30,6 +30,12 @@ func init() { registerCommand(structs.ACLPolicyDeleteRequestType, (*FSM).applyACLPolicyDeleteOperation) registerCommand(structs.ConnectCALeafRequestType, (*FSM).applyConnectCALeafOperation) registerCommand(structs.ConfigEntryRequestType, (*FSM).applyConfigEntryOperation) + registerCommand(structs.ACLRoleSetRequestType, (*FSM).applyACLRoleSetOperation) + registerCommand(structs.ACLRoleDeleteRequestType, (*FSM).applyACLRoleDeleteOperation) + registerCommand(structs.ACLBindingRuleSetRequestType, (*FSM).applyACLBindingRuleSetOperation) + registerCommand(structs.ACLBindingRuleDeleteRequestType, (*FSM).applyACLBindingRuleDeleteOperation) + registerCommand(structs.ACLAuthMethodSetRequestType, (*FSM).applyACLAuthMethodSetOperation) + registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation) } func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { @@ -165,6 +171,7 @@ func (c *FSM) applyACLOperation(buf []byte, index uint64) interface{} { return err } + // No need to check expiration times as those did not exist in legacy tokens. if _, token, err := c.state.ACLTokenGetBySecret(nil, req.ACL.ID); err != nil { return err } else { @@ -451,3 +458,69 @@ func (c *FSM) applyConfigEntryOperation(buf []byte, index uint64) interface{} { return fmt.Errorf("invalid config entry operation type: %v", req.Op) } } + +func (c *FSM) applyACLRoleSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRoleBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "role"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLRoleBatchSet(index, req.Roles) +} + +func (c *FSM) applyACLRoleDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRoleBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "role"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLRoleBatchDelete(index, req.RoleIDs) +} + +func (c *FSM) applyACLBindingRuleSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLBindingRuleBatchSet(index, req.BindingRules) +} + +func (c *FSM) applyACLBindingRuleDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLBindingRuleBatchDelete(index, req.BindingRuleIDs) +} + +func (c *FSM) applyACLAuthMethodSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLAuthMethodBatchSet(index, req.AuthMethods) +} + +func (c *FSM) applyACLAuthMethodDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLAuthMethodBatchDelete(index, req.AuthMethodNames) +} diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 0c7713753e..195e6cf136 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -28,6 +28,9 @@ func init() { registerRestorer(structs.ACLTokenSetRequestType, restoreToken) registerRestorer(structs.ACLPolicySetRequestType, restorePolicy) registerRestorer(structs.ConfigEntryRequestType, restoreConfigEntry) + registerRestorer(structs.ACLRoleSetRequestType, restoreRole) + registerRestorer(structs.ACLBindingRuleSetRequestType, restoreBindingRule) + registerRestorer(structs.ACLAuthMethodSetRequestType, restoreAuthMethod) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -178,6 +181,8 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, return err } + // Don't check expiration times. Wait for explicit deletions. + for token := tokens.Next(); token != nil; token = tokens.Next() { if _, err := sink.Write([]byte{byte(structs.ACLTokenSetRequestType)}); err != nil { return err @@ -201,6 +206,48 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, } } + roles, err := s.state.ACLRoles() + if err != nil { + return err + } + + for role := roles.Next(); role != nil; role = roles.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLRoleSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(role.(*structs.ACLRole)); err != nil { + return err + } + } + + rules, err := s.state.ACLBindingRules() + if err != nil { + return err + } + + for rule := rules.Next(); rule != nil; rule = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLBindingRuleSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(rule.(*structs.ACLBindingRule)); err != nil { + return err + } + } + + methods, err := s.state.ACLAuthMethods() + if err != nil { + return err + } + + for method := methods.Next(); method != nil; method = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLAuthMethodSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(method.(*structs.ACLAuthMethod)); err != nil { + return err + } + } + return nil } @@ -601,3 +648,27 @@ func restoreConfigEntry(header *snapshotHeader, restore *state.Restore, decoder } return restore.ConfigEntry(req.Entry) } + +func restoreRole(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLRole + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLRole(&req) +} + +func restoreBindingRule(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLBindingRule + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLBindingRule(&req) +} + +func restoreAuthMethod(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLAuthMethod + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLAuthMethod(&req) +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index b8b5bb1327..9571b5f1a1 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -86,7 +86,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { }) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) - policy := structs.ACLPolicy{ + policy := &structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, Name: "global-management", Description: "Builtin Policy that grants unlimited access", @@ -94,7 +94,20 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { Syntax: acl.SyntaxCurrent, } policy.SetHash(true) - require.NoError(fsm.state.ACLPolicySet(1, &policy)) + require.NoError(fsm.state.ACLPolicySet(1, policy)) + + role := &structs.ACLRole{ + ID: "86dedd19-8fae-4594-8294-4e6948a81f9a", + Name: "some-role", + Description: "test snapshot role", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + }, + }, + } + role.SetHash(true) + require.NoError(fsm.state.ACLRoleSet(1, role)) token := &structs.ACLToken{ AccessorID: "30fca056-9fbb-4455-b94a-bf0e2bc575d6", @@ -112,6 +125,26 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } require.NoError(fsm.state.ACLBootstrap(10, 0, token, false)) + method := &structs.ACLAuthMethod{ + Name: "some-method", + Type: "testing", + Description: "test snapshot auth method", + Config: map[string]interface{}{ + "SessionID": "952ebfa8-2a42-46f0-bcd3-fd98a842000e", + }, + } + require.NoError(fsm.state.ACLAuthMethodSet(1, method)) + + bindingRule := &structs.ACLBindingRule{ + ID: "85184c52-5997-4a84-9817-5945f2632a17", + Description: "test snapshot binding rule", + AuthMethod: "some-method", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + require.NoError(fsm.state.ACLBindingRuleSet(1, bindingRule)) + fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", Value: []byte("foo"), @@ -314,21 +347,40 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Verify ACL Token is restored - _, a, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + // Verify ACL Binding Rule is restored + _, bindingRule2, err := fsm2.state.ACLBindingRuleGetByID(nil, bindingRule.ID) require.NoError(err) - require.Equal(token.AccessorID, a.AccessorID) - require.Equal(token.ModifyIndex, a.ModifyIndex) + require.Equal(bindingRule, bindingRule2) + + // Verify ACL Auth Method is restored + _, method2, err := fsm2.state.ACLAuthMethodGetByName(nil, method.Name) + require.NoError(err) + require.Equal(method, method2) + + // Verify ACL Token is restored + _, token2, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(err) + { + // time.Time is tricky to compare generically when it takes a ser/deserialization round trip. + require.True(token.CreateTime.Equal(token2.CreateTime)) + token2.CreateTime = token.CreateTime + } + require.Equal(token, token2) // Verify the acl-token-bootstrap index was restored canBootstrap, index, err := fsm2.state.CanBootstrapACLToken() require.False(canBootstrap) require.True(index > 0) + // Verify ACL Role is restored + _, role2, err := fsm2.state.ACLRoleGetByID(nil, role.ID) + require.NoError(err) + require.Equal(role, role2) + // Verify ACL Policy is restored _, policy2, err := fsm2.state.ACLPolicyGetByID(nil, structs.ACLPolicyGlobalManagementID) require.NoError(err) - require.Equal(policy.Name, policy2.Name) + require.Equal(policy, policy2) // Verify tombstones are restored func() { diff --git a/agent/consul/helper_test.go b/agent/consul/helper_test.go index bd84fce5c5..678ef033fd 100644 --- a/agent/consul/helper_test.go +++ b/agent/consul/helper_test.go @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/types" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/raft" "github.com/hashicorp/serf/serf" "github.com/stretchr/testify/require" @@ -166,6 +166,21 @@ func waitForNewACLs(t *testing.T, server *Server) { require.False(t, server.UseLegacyACLs(), "Server cannot use new ACLs") } +func waitForNewACLReplication(t *testing.T, server *Server, expectedReplicationType structs.ACLReplicationType) { + var ( + replTyp structs.ACLReplicationType + running bool + ) + retry.Run(t, func(r *retry.R) { + replTyp, running = server.getACLReplicationStatusRunningType() + require.Equal(r, expectedReplicationType, replTyp, "Server not running new replicator yet") + require.True(r, running, "Server not running new replicator yet") + }) + + require.Equal(t, expectedReplicationType, replTyp, "Server not running new replicator yet") + require.True(t, running, "Server not running new replicator yet") +} + func seeEachOther(a, b []serf.Member, addra, addrb string) bool { return serfMembersContains(a, addrb) && serfMembersContains(b, addra) } diff --git a/agent/consul/leader.go b/agent/consul/leader.go index a0fd46613c..5f4f604699 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/connect" ca "github.com/hashicorp/consul/agent/connect/ca" @@ -22,7 +22,7 @@ import ( "github.com/hashicorp/consul/types" memdb "github.com/hashicorp/go-memdb" uuid "github.com/hashicorp/go-uuid" - "github.com/hashicorp/go-version" + version "github.com/hashicorp/go-version" "github.com/hashicorp/raft" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" @@ -308,6 +308,8 @@ func (s *Server) revokeLeadership() error { s.setCAProvider(nil, nil) + s.stopACLTokenReaping() + s.stopACLUpgrade() s.resetConsistentReadReady() @@ -329,6 +331,7 @@ func (s *Server) initializeLegacyACL() error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { req := structs.ACLRequest{ Datacenter: authDC, @@ -352,6 +355,7 @@ func (s *Server) initializeLegacyACL() error { if err != nil { return fmt.Errorf("failed to get master token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { req := structs.ACLRequest{ Datacenter: authDC, @@ -423,6 +427,10 @@ func (s *Server) initializeACLs(upgrade bool) error { // leader. s.acls.cache.Purge() + // Purge the auth method validators since they could've changed while we + // were not leader. + s.purgeAuthMethodValidators() + // Remove any token affected by CVE-2019-8336 if !s.InACLDatacenter() { _, token, err := s.fsm.State().ACLTokenGetBySecret(nil, redactedToken) @@ -482,6 +490,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get master token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { accessor, err := lib.GenerateUUID(s.checkTokenUUID) if err != nil { @@ -543,6 +552,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. if token == nil { // DEPRECATED (ACL-Legacy-Compat) - Don't need to query for previous "anonymous" token // check for legacy token that needs an upgrade @@ -550,6 +560,7 @@ func (s *Server) initializeACLs(upgrade bool) error { if err != nil { return fmt.Errorf("failed to get anonymous token: %v", err) } + // Ignoring expiration times to avoid an insertion collision. // the token upgrade routine will take care of upgrading the token if a legacy version exists if legacyToken == nil { @@ -572,6 +583,7 @@ func (s *Server) initializeACLs(upgrade bool) error { s.logger.Printf("[INFO] consul: Created ACL anonymous token from configuration") } } + // launch the upgrade go routine to generate accessors for everything s.startACLUpgrade() } else { if s.UseLegacyACLs() && !upgrade { @@ -588,7 +600,7 @@ func (s *Server) initializeACLs(upgrade bool) error { s.startACLReplication() } - // launch the upgrade go routine to generate accessors for everything + s.startACLTokenReaping() return nil } @@ -617,6 +629,7 @@ func (s *Server) startACLUpgrade() { if err != nil { s.logger.Printf("[WARN] acl: encountered an error while searching for tokens without accessor ids: %v", err) } + // No need to check expiration time here, as that only exists for v2 tokens. if len(tokens) == 0 { ws := memdb.NewWatchSet() @@ -649,7 +662,10 @@ func (s *Server) startACLUpgrade() { } // Assign the global-management policy to legacy management tokens - if len(newToken.Policies) == 0 && newToken.Type == structs.ACLTokenTypeManagement { + if len(newToken.Policies) == 0 && + len(newToken.ServiceIdentities) == 0 && + len(newToken.Roles) == 0 && + newToken.Type == structs.ACLTokenTypeManagement { newToken.Policies = append(newToken.Policies, structs.ACLTokenPolicyLink{ID: structs.ACLPolicyGlobalManagementID}) } @@ -727,7 +743,7 @@ func (s *Server) startLegacyACLReplication() { s.logger.Printf("[WARN] consul: Legacy ACL replication error (will retry if still leader): %v", err) } else { lastRemoteIndex = index - s.updateACLReplicationStatusIndex(index) + s.updateACLReplicationStatusIndex(structs.ACLReplicateLegacy, index) s.logger.Printf("[DEBUG] consul: Legacy ACL replication completed through remote index %d", index) } } @@ -749,8 +765,22 @@ func (s *Server) startACLReplication() { ctx, cancel := context.WithCancel(context.Background()) s.aclReplicationCancel = cancel - replicationType := structs.ACLReplicatePolicies + s.startACLReplicator(ctx, structs.ACLReplicatePolicies, s.replicateACLPolicies) + s.startACLReplicator(ctx, structs.ACLReplicateRoles, s.replicateACLRoles) + if s.config.ACLTokenReplication { + s.startACLReplicator(ctx, structs.ACLReplicateTokens, s.replicateACLTokens) + s.updateACLReplicationStatusRunning(structs.ACLReplicateTokens) + } else { + s.updateACLReplicationStatusRunning(structs.ACLReplicatePolicies) + } + + s.aclReplicationEnabled = true +} + +type replicateFunc func(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) + +func (s *Server) startACLReplicator(ctx context.Context, replicationType structs.ACLReplicationType, replicateFunc replicateFunc) { go func() { var failedAttempts uint limiter := rate.NewLimiter(rate.Limit(s.config.ACLReplicationRate), s.config.ACLReplicationBurst) @@ -765,7 +795,7 @@ func (s *Server) startACLReplication() { continue } - index, exit, err := s.replicateACLPolicies(lastRemoteIndex, ctx) + index, exit, err := replicateFunc(ctx, lastRemoteIndex) if exit { return } @@ -773,7 +803,7 @@ func (s *Server) startACLReplication() { if err != nil { lastRemoteIndex = 0 s.updateACLReplicationStatusError() - s.logger.Printf("[WARN] consul: ACL policy replication error (will retry if still leader): %v", err) + s.logger.Printf("[WARN] consul: ACL %s replication error (will retry if still leader): %v", replicationType.SingularNoun(), err) if (1 << failedAttempts) < aclReplicationMaxRetryBackoff { failedAttempts++ } @@ -786,65 +816,14 @@ func (s *Server) startACLReplication() { } } else { lastRemoteIndex = index - s.updateACLReplicationStatusIndex(index) - s.logger.Printf("[DEBUG] consul: ACL policy replication completed through remote index %d", index) + s.updateACLReplicationStatusIndex(replicationType, index) + s.logger.Printf("[DEBUG] consul: ACL %s replication completed through remote index %d", replicationType.SingularNoun(), index) failedAttempts = 0 } } }() - s.logger.Printf("[INFO] acl: started ACL Policy replication") - - if s.config.ACLTokenReplication { - replicationType = structs.ACLReplicateTokens - - go func() { - var failedAttempts uint - limiter := rate.NewLimiter(rate.Limit(s.config.ACLReplicationRate), s.config.ACLReplicationBurst) - var lastRemoteIndex uint64 - for { - if err := limiter.Wait(ctx); err != nil { - return - } - - if s.tokens.ReplicationToken() == "" { - continue - } - - index, exit, err := s.replicateACLTokens(lastRemoteIndex, ctx) - if exit { - return - } - - if err != nil { - lastRemoteIndex = 0 - s.updateACLReplicationStatusError() - s.logger.Printf("[WARN] consul: ACL token replication error (will retry if still leader): %v", err) - if (1 << failedAttempts) < aclReplicationMaxRetryBackoff { - failedAttempts++ - } - - select { - case <-ctx.Done(): - return - case <-time.After((1 << failedAttempts) * time.Second): - // do nothing - } - } else { - lastRemoteIndex = index - s.updateACLReplicationStatusTokenIndex(index) - s.logger.Printf("[DEBUG] consul: ACL token replication completed through remote index %d", index) - failedAttempts = 0 - } - } - }() - - s.logger.Printf("[INFO] acl: started ACL Token replication") - } - - s.updateACLReplicationStatusRunning(replicationType) - - s.aclReplicationEnabled = true + s.logger.Printf("[INFO] acl: started ACL %s replication", replicationType.SingularNoun()) } func (s *Server) stopACLReplication() { diff --git a/agent/consul/server.go b/agent/consul/server.go index f19a907455..d76a46a9b8 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -109,6 +109,15 @@ type Server struct { aclReplicationLock sync.RWMutex aclReplicationEnabled bool + // aclTokenReapCancel is used to shut down the ACL Token expiration reap + // goroutine when we lose leadership. + aclTokenReapCancel context.CancelFunc + aclTokenReapLock sync.RWMutex + aclTokenReapEnabled bool + + aclAuthMethodValidators map[string]*authMethodValidatorEntry + aclAuthMethodValidatorLock sync.RWMutex + // DEPRECATED (ACL-Legacy-Compat) - only needed while we support both // useNewACLs is used to determine whether we can use new ACLs or not useNewACLs int32 diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 5c2e83d57f..1249844c86 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -1,10 +1,12 @@ package state import ( + "encoding/binary" "fmt" + "time" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/go-memdb" + memdb "github.com/hashicorp/go-memdb" ) type TokenPoliciesIndex struct { @@ -58,6 +60,156 @@ func (s *TokenPoliciesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) return val, nil } +type TokenRolesIndex struct { +} + +func (s *TokenRolesIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + token, ok := obj.(*structs.ACLToken) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLToken") + } + + links := token.Roles + + numLinks := len(links) + if numLinks == 0 { + return false, nil, nil + } + + vals := make([][]byte, 0, numLinks) + for _, link := range links { + vals = append(vals, []byte(link.ID+"\x00")) + } + + return true, vals, nil +} + +func (s *TokenRolesIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *TokenRolesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +type RolePoliciesIndex struct { +} + +func (s *RolePoliciesIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + role, ok := obj.(*structs.ACLRole) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLRole") + } + + links := role.Policies + + numLinks := len(links) + if numLinks == 0 { + return false, nil, nil + } + + vals := make([][]byte, 0, numLinks) + for _, link := range links { + vals = append(vals, []byte(link.ID+"\x00")) + } + + return true, vals, nil +} + +func (s *RolePoliciesIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *RolePoliciesIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +type TokenExpirationIndex struct { + LocalFilter bool +} + +func (s *TokenExpirationIndex) encodeTime(t time.Time) []byte { + val := t.Unix() + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(val)) + return buf +} + +func (s *TokenExpirationIndex) FromObject(obj interface{}) (bool, []byte, error) { + token, ok := obj.(*structs.ACLToken) + if !ok { + return false, nil, fmt.Errorf("object is not an ACLToken") + } + if s.LocalFilter != token.Local { + return false, nil, nil + } + if !token.HasExpirationTime() { + return false, nil, nil + } + if token.ExpirationTime.Unix() < 0 { + return false, nil, fmt.Errorf("token expiration time cannot be before the unix epoch: %s", token.ExpirationTime) + } + + buf := s.encodeTime(*token.ExpirationTime) + + return true, buf, nil +} + +func (s *TokenExpirationIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(time.Time) + if !ok { + return nil, fmt.Errorf("argument must be a time.Time: %#v", args[0]) + } + if arg.Unix() < 0 { + return nil, fmt.Errorf("argument must be a time.Time after the unix epoch: %s", args[0]) + } + + buf := s.encodeTime(arg) + + return buf, nil +} + func tokensTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ Name: "acl-tokens", @@ -87,6 +239,21 @@ func tokensTableSchema() *memdb.TableSchema { Unique: false, Indexer: &TokenPoliciesIndex{}, }, + "roles": &memdb.IndexSchema{ + Name: "roles", + AllowMissing: true, + Unique: false, + Indexer: &TokenRolesIndex{}, + }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: false, + }, + }, "local": &memdb.IndexSchema{ Name: "local", AllowMissing: false, @@ -100,6 +267,18 @@ func tokensTableSchema() *memdb.TableSchema { }, }, }, + "expires-global": { + Name: "expires-global", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: false}, + }, + "expires-local": { + Name: "expires-local", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: true}, + }, //DEPRECATED (ACL-Legacy-Compat) - This index is only needed while we support upgrading v1 to v2 acls // This table indexes all the ACL tokens that do not have an AccessorID @@ -146,9 +325,86 @@ func policiesTableSchema() *memdb.TableSchema { } } +func rolesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-roles", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "name": &memdb.IndexSchema{ + Name: "name", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + "policies": &memdb.IndexSchema{ + Name: "policies", + // Need to allow missing for the anonymous token + AllowMissing: true, + Unique: false, + Indexer: &RolePoliciesIndex{}, + }, + }, + } +} + +func bindingRulesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-binding-rules", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: true, + }, + }, + }, + } +} + +func authMethodsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-auth-methods", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + }, + } +} + func init() { registerSchema(tokensTableSchema) registerSchema(policiesTableSchema) + registerSchema(rolesTableSchema) + registerSchema(bindingRulesTableSchema) + registerSchema(authMethodsTableSchema) } // ACLTokens is used when saving a snapshot @@ -193,6 +449,66 @@ func (s *Restore) ACLPolicy(policy *structs.ACLPolicy) error { return nil } +// ACLRoles is used when saving a snapshot +func (s *Snapshot) ACLRoles() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-roles", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLRole(role *structs.ACLRole) error { + if err := s.tx.Insert("acl-roles", role); err != nil { + return fmt.Errorf("failed restoring acl role: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, role.ModifyIndex, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// ACLBindingRules is used when saving a snapshot +func (s *Snapshot) ACLBindingRules() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-binding-rules", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLBindingRule(rule *structs.ACLBindingRule) error { + if err := s.tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed restoring acl binding rule: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, rule.ModifyIndex, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// ACLAuthMethods is used when saving a snapshot +func (s *Snapshot) ACLAuthMethods() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-auth-methods", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLAuthMethod(method *structs.ACLAuthMethod) error { + if err := s.tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed restoring acl auth method: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, method.ModifyIndex, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + // ACLBootstrap is used to perform a one-time ACL bootstrap operation on a // cluster to get the first management token. func (s *Store) ACLBootstrap(idx, resetIndex uint64, token *structs.ACLToken, legacy bool) error { @@ -307,6 +623,7 @@ func (s *Store) fixupTokenPolicyLinks(tx *memdb.Txn, original *structs.ACLToken) // append the corrected policy token.Policies = append(token.Policies, structs.ACLTokenPolicyLink{ID: link.ID, Name: policy.Name}) + } else if owned { token.Policies = append(token.Policies, link) } @@ -315,6 +632,150 @@ func (s *Store) fixupTokenPolicyLinks(tx *memdb.Txn, original *structs.ACLToken) return token, nil } +func (s *Store) resolveTokenRoleLinks(tx *memdb.Txn, token *structs.ACLToken, allowMissing bool) error { + for linkIndex, link := range token.Roles { + if link.ID != "" { + role, err := s.getRoleWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return err + } + + if role != nil { + // the name doesn't matter here + token.Roles[linkIndex].Name = role.Name + } else if !allowMissing { + return fmt.Errorf("No such role with ID: %s", link.ID) + } + } else { + return fmt.Errorf("Encountered a Token with roles linked by Name in the state store") + } + } + return nil +} + +// fixupTokenRoleLinks is to be used when retrieving tokens from memdb. The role links could have gotten +// stale when a linked role was deleted or renamed. This will correct them and generate a newly allocated +// token only when fixes are needed. If the role links are still accurate then we just return the original +// token. +func (s *Store) fixupTokenRoleLinks(tx *memdb.Txn, original *structs.ACLToken) (*structs.ACLToken, error) { + owned := false + token := original + + cloneToken := func(t *structs.ACLToken, copyNumLinks int) *structs.ACLToken { + clone := *t + clone.Roles = make([]structs.ACLTokenRoleLink, copyNumLinks) + copy(clone.Roles, t.Roles[:copyNumLinks]) + return &clone + } + + for linkIndex, link := range original.Roles { + if link.ID == "" { + return nil, fmt.Errorf("Detected corrupted token within the state store - missing role link ID") + } + + role, err := s.getRoleWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return nil, err + } + + if role == nil { + if !owned { + // clone the token as we cannot touch the original + token = cloneToken(original, linkIndex) + owned = true + } + // if already owned then we just don't append it. + } else if role.Name != link.Name { + if !owned { + token = cloneToken(original, linkIndex) + owned = true + } + + // append the corrected policy + token.Roles = append(token.Roles, structs.ACLTokenRoleLink{ID: link.ID, Name: role.Name}) + + } else if owned { + token.Roles = append(token.Roles, link) + } + } + + return token, nil +} + +func (s *Store) resolveRolePolicyLinks(tx *memdb.Txn, role *structs.ACLRole, allowMissing bool) error { + for linkIndex, link := range role.Policies { + if link.ID != "" { + policy, err := s.getPolicyWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return err + } + + if policy != nil { + // the name doesn't matter here + role.Policies[linkIndex].Name = policy.Name + } else if !allowMissing { + return fmt.Errorf("No such policy with ID: %s", link.ID) + } + } else { + return fmt.Errorf("Encountered a Role with policies linked by Name in the state store") + } + } + return nil +} + +// fixupRolePolicyLinks is to be used when retrieving roles from memdb. The policy links could have gotten +// stale when a linked policy was deleted or renamed. This will correct them and generate a newly allocated +// role only when fixes are needed. If the policy links are still accurate then we just return the original +// role. +func (s *Store) fixupRolePolicyLinks(tx *memdb.Txn, original *structs.ACLRole) (*structs.ACLRole, error) { + owned := false + role := original + + cloneRole := func(t *structs.ACLRole, copyNumLinks int) *structs.ACLRole { + clone := *t + clone.Policies = make([]structs.ACLRolePolicyLink, copyNumLinks) + copy(clone.Policies, t.Policies[:copyNumLinks]) + return &clone + } + + for linkIndex, link := range original.Policies { + if link.ID == "" { + return nil, fmt.Errorf("Detected corrupted role within the state store - missing policy link ID") + } + + policy, err := s.getPolicyWithTxn(tx, nil, link.ID, "id") + + if err != nil { + return nil, err + } + + if policy == nil { + if !owned { + // clone the token as we cannot touch the original + role = cloneRole(original, linkIndex) + owned = true + } + // if already owned then we just don't append it. + } else if policy.Name != link.Name { + if !owned { + role = cloneRole(original, linkIndex) + owned = true + } + + // append the corrected policy + role.Policies = append(role.Policies, structs.ACLRolePolicyLink{ID: link.ID, Name: policy.Name}) + + } else if owned { + role.Policies = append(role.Policies, link) + } + } + + return role, nil +} + // ACLTokenSet is used to insert an ACL rule into the state store. func (s *Store) ACLTokenSet(idx uint64, token *structs.ACLToken, legacy bool) error { tx := s.db.Txn(true) @@ -355,7 +816,7 @@ func (s *Store) ACLTokenBatchSet(idx uint64, tokens structs.ACLTokens, cas bool) // aclTokenSetTxn is the inner method used to insert an ACL token with the // proper indexes into the state store. -func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToken, cas, allowMissingPolicyIDs, legacy bool) error { +func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToken, cas, allowMissingPolicyAndRoleIDs, legacy bool) error { // Check that the ID is set if token.SecretID == "" { return ErrMissingACLTokenSecret @@ -405,17 +866,36 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke } if legacy && original != nil { - if len(original.Policies) > 0 || original.Type == "" { + if original.UsesNonLegacyFields() { return fmt.Errorf("failed inserting acl token: cannot use legacy endpoint to modify a non-legacy token") } token.AccessorID = original.AccessorID } - if err := s.resolveTokenPolicyLinks(tx, token, allowMissingPolicyIDs); err != nil { + if err := s.resolveTokenPolicyLinks(tx, token, allowMissingPolicyAndRoleIDs); err != nil { return err } + if err := s.resolveTokenRoleLinks(tx, token, allowMissingPolicyAndRoleIDs); err != nil { + return err + } + + if token.AuthMethod != "" { + method, err := s.getAuthMethodWithTxn(tx, nil, token.AuthMethod, "id") + if err != nil { + return err + } else if method == nil { + return fmt.Errorf("No such auth method with Name: %s", token.AuthMethod) + } + } + + for _, svcid := range token.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Encountered a Token with an empty service identity name in the state store") + } + } + // Set the indexes if original != nil { if original.AccessorID != "" && token.AccessorID != original.AccessorID { @@ -495,7 +975,12 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st ws.Add(watchCh) if rawToken != nil { - token, err := s.fixupTokenPolicyLinks(tx, rawToken.(*structs.ACLToken)) + token := rawToken.(*structs.ACLToken) + token, err := s.fixupTokenPolicyLinks(tx, token) + if err != nil { + return nil, err + } + token, err = s.fixupTokenRoleLinks(tx, token) if err != nil { return nil, err } @@ -506,7 +991,7 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st } // ACLTokenList is used to list out all of the ACLs in the state store. -func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy string) (uint64, structs.ACLTokens, error) { +func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role, methodName string) (uint64, structs.ACLTokens, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -517,41 +1002,63 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy strin // to false but for defaulted structs (zero values for both) we want it to list out // all tokens so our checks just ensure that global == local - if policy != "" { - iter, err = tx.Get("acl-tokens", "policies", policy) - if err == nil && global != local { - iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { - token, ok := raw.(*structs.ACLToken) - if !ok { - return true - } - - if global && !token.Local { - return false - } else if local && token.Local { - return false - } - - return true - }) + needLocalityFilter := false + if policy == "" && role == "" && methodName == "" { + if global == local { + iter, err = tx.Get("acl-tokens", "id") + } else if global { + iter, err = tx.Get("acl-tokens", "local", false) + } else { + iter, err = tx.Get("acl-tokens", "local", true) } - } else if global == local { - iter, err = tx.Get("acl-tokens", "id") - } else if global { - iter, err = tx.Get("acl-tokens", "local", false) + + } else if policy != "" && role == "" && methodName == "" { + iter, err = tx.Get("acl-tokens", "policies", policy) + needLocalityFilter = true + + } else if policy == "" && role != "" && methodName == "" { + iter, err = tx.Get("acl-tokens", "roles", role) + needLocalityFilter = true + + } else if policy == "" && role == "" && methodName != "" { + iter, err = tx.Get("acl-tokens", "authmethod", methodName) + needLocalityFilter = true + } else { - iter, err = tx.Get("acl-tokens", "local", true) + return 0, nil, fmt.Errorf("can only filter by one of policy, role, or methodName at a time") } if err != nil { return 0, nil, fmt.Errorf("failed acl token lookup: %v", err) } + + if needLocalityFilter && global != local { + iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { + token, ok := raw.(*structs.ACLToken) + if !ok { + return true + } + + if global && !token.Local { + return false + } else if local && token.Local { + return false + } + + return true + }) + } + ws.Add(iter.WatchCh()) var result structs.ACLTokens for raw := iter.Next(); raw != nil; raw = iter.Next() { - token, err := s.fixupTokenPolicyLinks(tx, raw.(*structs.ACLToken)) - + token := raw.(*structs.ACLToken) + token, err := s.fixupTokenPolicyLinks(tx, token) + if err != nil { + return 0, nil, err + } + token, err = s.fixupTokenRoleLinks(tx, token) if err != nil { return 0, nil, err } @@ -586,6 +1093,62 @@ func (s *Store) ACLTokenListUpgradeable(max int) (structs.ACLTokens, <-chan stru return tokens, iter.WatchCh(), nil } +func (s *Store) ACLTokenMinExpirationTime(local bool) (time.Time, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + item, err := tx.First("acl-tokens", s.expiresIndexName(local)) + if err != nil { + return time.Time{}, fmt.Errorf("failed acl token listing: %v", err) + } + + if item == nil { + return time.Time{}, nil + } + + token := item.(*structs.ACLToken) + + return *token.ExpirationTime, nil +} + +// ACLTokenListExpires lists tokens that are expired as of the provided time. +// The returned set will be no larger than the max value provided. +func (s *Store) ACLTokenListExpired(local bool, asOf time.Time, max int) (structs.ACLTokens, <-chan struct{}, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("acl-tokens", s.expiresIndexName(local)) + if err != nil { + return nil, nil, fmt.Errorf("failed acl token listing: %v", err) + } + + var ( + tokens structs.ACLTokens + i int + ) + for raw := iter.Next(); raw != nil; raw = iter.Next() { + token := raw.(*structs.ACLToken) + if token.ExpirationTime != nil && !token.ExpirationTime.Before(asOf) { + return tokens, nil, nil + } + + tokens = append(tokens, token) + i += 1 + if i >= max { + return tokens, nil, nil + } + } + + return tokens, iter.WatchCh(), nil +} + +func (s *Store) expiresIndexName(local bool) string { + if local { + return "expires-local" + } + return "expires-global" +} + // ACLTokenDeleteBySecret is used to remove an existing ACL from the state store. If // the ACL does not exist this is a no-op and no error is returned. func (s *Store) ACLTokenDeleteBySecret(idx uint64, secret string) error { @@ -648,6 +1211,35 @@ func (s *Store) aclTokenDeleteTxn(tx *memdb.Txn, idx uint64, value, index string return nil } +func (s *Store) aclTokenDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-tokens", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl token lookup: %v", err) + } + + var tokens structs.ACLTokens + for raw := iter.Next(); raw != nil; raw = iter.Next() { + token := raw.(*structs.ACLToken) + tokens = append(tokens, token) + } + + if len(tokens) > 0 { + // delete them all + for _, token := range tokens { + if err := tx.Delete("acl-tokens", token); err != nil { + return fmt.Errorf("failed deleting acl token: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-tokens"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + func (s *Store) ACLPolicyBatchSet(idx uint64, policies structs.ACLPolicies) error { tx := s.db.Txn(true) defer tx.Abort() @@ -876,3 +1468,621 @@ func (s *Store) aclPolicyDeleteTxn(tx *memdb.Txn, idx uint64, value, index strin } return nil } + +func (s *Store) ACLRoleBatchSet(idx uint64, roles structs.ACLRoles) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, role := range roles { + if err := s.aclRoleSetTxn(tx, idx, role, true); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLRoleSet(idx uint64, role *structs.ACLRole) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclRoleSetTxn(tx, idx, role, false); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclRoleSetTxn(tx *memdb.Txn, idx uint64, role *structs.ACLRole, allowMissing bool) error { + // Check that the ID is set + if role.ID == "" { + return ErrMissingACLRoleID + } + + if role.Name == "" { + return ErrMissingACLRoleName + } + + existing, err := tx.First("acl-roles", "id", role.ID) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + + // ensure the name is unique (cannot conflict with another role with a different ID) + nameMatch, err := tx.First("acl-roles", "name", role.Name) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + if nameMatch != nil && role.ID != nameMatch.(*structs.ACLRole).ID { + return fmt.Errorf("A role with name %q already exists", role.Name) + } + + if err := s.resolveRolePolicyLinks(tx, role, allowMissing); err != nil { + return err + } + + for _, svcid := range role.ServiceIdentities { + if svcid.ServiceName == "" { + return fmt.Errorf("Encountered a Role with an empty service identity name in the state store") + } + } + + // Set the indexes + if existing != nil { + role.CreateIndex = existing.(*structs.ACLRole).CreateIndex + role.ModifyIndex = idx + } else { + role.CreateIndex = idx + role.ModifyIndex = idx + } + + if err := tx.Insert("acl-roles", role); err != nil { + return fmt.Errorf("failed inserting acl role: %v", err) + } + return nil +} + +func (s *Store) ACLRoleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLRole, error) { + return s.aclRoleGet(ws, id, "id") +} + +func (s *Store) ACLRoleGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLRole, error) { + return s.aclRoleGet(ws, name, "name") +} + +func (s *Store) ACLRoleBatchGet(ws memdb.WatchSet, ids []string) (uint64, structs.ACLRoles, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + roles := make(structs.ACLRoles, 0, len(ids)) + for _, rid := range ids { + role, err := s.getRoleWithTxn(tx, ws, rid, "id") + if err != nil { + return 0, nil, err + } + + if role != nil { + roles = append(roles, role) + } + } + + idx := maxIndexTxn(tx, "acl-roles") + + return idx, roles, nil +} + +func (s *Store) getRoleWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLRole, error) { + watchCh, rawRole, err := tx.FirstWatch("acl-roles", index, value) + if err != nil { + return nil, fmt.Errorf("failed acl role lookup: %v", err) + } + ws.Add(watchCh) + + if rawRole != nil { + role := rawRole.(*structs.ACLRole) + role, err := s.fixupRolePolicyLinks(tx, role) + if err != nil { + return nil, err + } + return role, nil + } + + return nil, nil +} + +func (s *Store) aclRoleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLRole, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + role, err := s.getRoleWithTxn(tx, ws, value, index) + if err != nil { + return 0, nil, err + } + + idx := maxIndexTxn(tx, "acl-roles") + + return idx, role, nil +} + +func (s *Store) ACLRoleList(ws memdb.WatchSet, policy string) (uint64, structs.ACLRoles, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var iter memdb.ResultIterator + var err error + + if policy != "" { + iter, err = tx.Get("acl-roles", "policies", policy) + } else { + iter, err = tx.Get("acl-roles", "id") + } + + if err != nil { + return 0, nil, fmt.Errorf("failed acl role lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLRoles + for raw := iter.Next(); raw != nil; raw = iter.Next() { + role := raw.(*structs.ACLRole) + role, err := s.fixupRolePolicyLinks(tx, role) + if err != nil { + return 0, nil, err + } + result = append(result, role) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-roles") + + return idx, result, nil +} + +func (s *Store) ACLRoleDeleteByID(idx uint64, id string) error { + return s.aclRoleDelete(idx, id, "id") +} + +func (s *Store) ACLRoleDeleteByName(idx uint64, name string) error { + return s.aclRoleDelete(idx, name, "name") +} + +func (s *Store) ACLRoleBatchDelete(idx uint64, roleIDs []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, roleID := range roleIDs { + if err := s.aclRoleDeleteTxn(tx, idx, roleID, "id"); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclRoleDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclRoleDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-roles"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclRoleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing role + rawRole, err := tx.First("acl-roles", index, value) + if err != nil { + return fmt.Errorf("failed acl role lookup: %v", err) + } + + if rawRole == nil { + return nil + } + + role := rawRole.(*structs.ACLRole) + + if err := tx.Delete("acl-roles", role); err != nil { + return fmt.Errorf("failed deleting acl role: %v", err) + } + return nil +} + +func (s *Store) ACLBindingRuleBatchSet(idx uint64, rules structs.ACLBindingRules) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, rule := range rules { + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLBindingRuleSet(idx uint64, rule *structs.ACLBindingRule) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleSetTxn(tx *memdb.Txn, idx uint64, rule *structs.ACLBindingRule) error { + // Check that the ID and AuthMethod are set + if rule.ID == "" { + return ErrMissingACLBindingRuleID + } else if rule.AuthMethod == "" { + return ErrMissingACLBindingRuleAuthMethod + } + + existing, err := tx.First("acl-binding-rules", "id", rule.ID) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + // Set the indexes + if existing != nil { + rule.CreateIndex = existing.(*structs.ACLBindingRule).CreateIndex + rule.ModifyIndex = idx + } else { + rule.CreateIndex = idx + rule.ModifyIndex = idx + } + + if method, err := tx.First("acl-auth-methods", "id", rule.AuthMethod); err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } else if method == nil { + return fmt.Errorf("failed inserting acl binding rule: auth method not found") + } + + if err := tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed inserting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) ACLBindingRuleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLBindingRule, error) { + return s.aclBindingRuleGet(ws, id, "id") +} + +func (s *Store) aclBindingRuleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLBindingRule, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + watchCh, rawRule, err := tx.FirstWatch("acl-binding-rules", index, value) + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(watchCh) + + var rule *structs.ACLBindingRule + if rawRule != nil { + rule = rawRule.(*structs.ACLBindingRule) + } + + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, rule, nil +} + +func (s *Store) ACLBindingRuleList(ws memdb.WatchSet, methodName string) (uint64, structs.ACLBindingRules, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var ( + iter memdb.ResultIterator + err error + ) + if methodName != "" { + iter, err = tx.Get("acl-binding-rules", "authmethod", methodName) + } else { + iter, err = tx.Get("acl-binding-rules", "id") + } + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + result = append(result, rule) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, result, nil +} + +func (s *Store) ACLBindingRuleDeleteByID(idx uint64, id string) error { + return s.aclBindingRuleDelete(idx, id, "id") +} + +func (s *Store) ACLBindingRuleBatchDelete(idx uint64, bindingRuleIDs []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, bindingRuleID := range bindingRuleIDs { + s.aclBindingRuleDeleteTxn(tx, idx, bindingRuleID, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing binding rule + rawRule, err := tx.First("acl-binding-rules", index, value) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + if rawRule == nil { + return nil + } + + rule := rawRule.(*structs.ACLBindingRule) + + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) aclBindingRuleDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-binding-rules", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + var rules structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + rules = append(rules, rule) + } + + if len(rules) > 0 { + // delete them all + for _, rule := range rules { + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + +func (s *Store) ACLAuthMethodBatchSet(idx uint64, methods structs.ACLAuthMethods) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, method := range methods { + // this is only used when doing batch insertions for upgrades and replication. Therefore + // we take whatever those said. + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLAuthMethodSet(idx uint64, method *structs.ACLAuthMethod) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodSetTxn(tx *memdb.Txn, idx uint64, method *structs.ACLAuthMethod) error { + // Check that the Name and Type are set + if method.Name == "" { + return ErrMissingACLAuthMethodName + } else if method.Type == "" { + return ErrMissingACLAuthMethodType + } + + existing, err := tx.First("acl-auth-methods", "id", method.Name) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + // Set the indexes + if existing != nil { + method.CreateIndex = existing.(*structs.ACLAuthMethod).CreateIndex + method.ModifyIndex = idx + } else { + method.CreateIndex = idx + method.ModifyIndex = idx + } + + if err := tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed inserting acl auth method: %v", err) + } + return nil +} + +func (s *Store) ACLAuthMethodGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLAuthMethod, error) { + return s.aclAuthMethodGet(ws, name, "id") +} + +func (s *Store) aclAuthMethodGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLAuthMethod, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + method, err := s.getAuthMethodWithTxn(tx, ws, value, index) + if err != nil { + return 0, nil, err + } + + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, method, nil +} + +func (s *Store) getAuthMethodWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLAuthMethod, error) { + watchCh, rawMethod, err := tx.FirstWatch("acl-auth-methods", index, value) + if err != nil { + return nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(watchCh) + + if rawMethod != nil { + return rawMethod.(*structs.ACLAuthMethod), nil + } + + return nil, nil +} + +func (s *Store) ACLAuthMethodList(ws memdb.WatchSet) (uint64, structs.ACLAuthMethods, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("acl-auth-methods", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLAuthMethods + for raw := iter.Next(); raw != nil; raw = iter.Next() { + method := raw.(*structs.ACLAuthMethod) + result = append(result, method) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, result, nil +} + +func (s *Store) ACLAuthMethodDeleteByName(idx uint64, name string) error { + return s.aclAuthMethodDelete(idx, name, "id") +} + +func (s *Store) ACLAuthMethodBatchDelete(idx uint64, names []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, name := range names { + s.aclAuthMethodDeleteTxn(tx, idx, name, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing method + rawMethod, err := tx.First("acl-auth-methods", index, value) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + if rawMethod == nil { + return nil + } + + method := rawMethod.(*structs.ACLAuthMethod) + + if err := s.aclBindingRuleDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := s.aclTokenDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := tx.Delete("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed deleting acl auth method: %v", err) + } + return nil +} diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index b6546b5bda..7dddbaccb7 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -1,14 +1,30 @@ package state import ( + "fmt" + "math/rand" + "strconv" "testing" "time" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/lib" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" ) +const ( + testRoleID_A = "2c74a9b8-271c-4a21-b727-200db397c01c" // from:setupExtraPoliciesAndRoles + testRoleID_B = "aeab6b63-08d1-455a-b85b-3458b462b426" // from:setupExtraPoliciesAndRoles + testPolicyID_A = "a0625e95-9b3e-42de-a8d6-ceef5b6f3286" // from:setupExtraPolicies + testPolicyID_B = "9386ecae-6677-4686-bcd4-5ab9d86cca1d" // from:setupExtraPolicies + testPolicyID_C = "2bf7359d-cfde-4769-a9fa-54ff1bb2ae4c" // from:setupExtraPolicies + testPolicyID_D = "ff807410-2b82-48ae-9a63-6626a90789d0" // from:setupExtraPolicies + testPolicyID_E = "b4635d48-90aa-4a77-8e1b-9004f68bb3df" // from:setupExtraPolicies +) + func setupGlobalManagement(t *testing.T, s *Store) { policy := structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, @@ -38,15 +54,54 @@ func testACLStateStore(t *testing.T) *Store { return s } +func setupExtraAuthMethods(t *testing.T, s *Store) { + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) +} + func setupExtraPolicies(t *testing.T, s *Store) { policies := structs.ACLPolicies{ &structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, Syntax: acl.SyntaxCurrent, }, + &structs.ACLPolicy{ + ID: testPolicyID_B, + Name: "agent-read", + Description: "Allows reading all node information", + Rules: `agent_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: testPolicyID_C, + Name: "acl-read", + Description: "Allows acl read", + Rules: `acl = "read"`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: testPolicyID_D, + Name: "acl-write", + Description: "Allows acl write", + Rules: `acl = "write"`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: testPolicyID_E, + Name: "kv-read", + Description: "Allows kv read", + Rules: `key_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, } for _, policy := range policies { @@ -56,7 +111,46 @@ func setupExtraPolicies(t *testing.T, s *Store) { require.NoError(t, s.ACLPolicyBatchSet(2, policies)) } +func setupExtraPoliciesAndRoles(t *testing.T, s *Store) { + setupExtraPolicies(t, s) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "agent-read-role", + Description: "Allows reading all agent information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + for _, role := range roles { + role.SetHash(true) + } + + require.NoError(t, s.ACLRoleBatchSet(3, roles)) +} + func testACLTokensStateStore(t *testing.T) *Store { + s := testACLStateStore(t) + setupExtraPoliciesAndRoles(t, s) + return s +} + +func testACLRolesStateStore(t *testing.T) *Store { s := testACLStateStore(t) setupExtraPolicies(t, s) return s @@ -94,23 +188,6 @@ func TestStateStore_ACLBootstrap(t *testing.T) { Type: structs.ACLTokenTypeManagement, } - stripIrrelevantFields := func(token *structs.ACLToken) *structs.ACLToken { - tokenCopy := token.Clone() - // When comparing the tokens disregard the policy link names. This - // data is not cleanly updated in a variety of scenarios and should not - // be relied upon. - for i, _ := range tokenCopy.Policies { - tokenCopy.Policies[i].Name = "" - } - // The raft indexes won't match either because the requester will not - // have access to that. - tokenCopy.RaftIndex = structs.RaftIndex{} - return tokenCopy - } - compareTokens := func(expected, actual *structs.ACLToken) { - require.Equal(t, stripIrrelevantFields(expected), stripIrrelevantFields(actual)) - } - s := testStateStore(t) setupGlobalManagement(t, s) @@ -140,10 +217,10 @@ func TestStateStore_ACLBootstrap(t *testing.T) { require.Equal(t, uint64(3), index) // Make sure the ACLs are in an expected state. - _, tokens, err := s.ACLTokenList(nil, true, true, "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 1) - compareTokens(token1, tokens[0]) + compareTokens(t, token1, tokens[0]) // bootstrap reset err = s.ACLBootstrap(32, index-1, token2.Clone(), false) @@ -154,7 +231,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { err = s.ACLBootstrap(32, index, token2.Clone(), false) require.NoError(t, err) - _, tokens, err = s.ACLTokenList(nil, true, true, "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 2) } @@ -170,15 +247,15 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { SecretID: "6d48ce91-2558-4098-bdab-8737e4e57d5f", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) // legacy flag is set so it should disallow setting this token - err := s.ACLTokenSet(3, token, true) + err := s.ACLTokenSet(3, token.Clone(), true) require.Error(t, err) }) @@ -190,10 +267,10 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { SecretID: "c0056225-5785-43b3-9b77-3954f06d6aee", } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) // legacy flag is set so it should disallow setting this token - err := s.ACLTokenSet(3, token, true) + err := s.ACLTokenSet(3, token.Clone(), true) require.Error(t, err) }) @@ -206,7 +283,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: `service "" { policy = "read" }`, } - require.NoError(t, s.ACLTokenSet(2, token, true)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), true)) idx, rtoken, err := s.ACLTokenGetBySecret(nil, token.SecretID) require.NoError(t, err) @@ -230,7 +307,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: `service "" { policy = "read" }`, } - require.NoError(t, s.ACLTokenSet(2, original, true)) + require.NoError(t, s.ACLTokenSet(2, original.Clone(), true)) updatedRules := `service "" { policy = "read" } service "foo" { policy = "deny"}` update := &structs.ACLToken{ @@ -239,7 +316,7 @@ func TestStateStore_ACLToken_SetGet_Legacy(t *testing.T) { Rules: updatedRules, } - require.NoError(t, s.ACLTokenSet(3, update, true)) + require.NoError(t, s.ACLTokenSet(3, update.Clone(), true)) idx, rtoken, err := s.ACLTokenGetBySecret(nil, original.SecretID) require.NoError(t, err) @@ -265,7 +342,7 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { AccessorID: "39171632-6f34-4411-827f-9416403687f4", } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) require.Equal(t, ErrMissingACLTokenSecret, err) }) @@ -277,11 +354,43 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", } - err := s.ACLTokenSet(2, token, false) + err := s.ACLTokenSet(2, token.Clone(), false) require.Error(t, err) require.Equal(t, ErrMissingACLTokenAccessor, err) }) + t.Run("Missing Service Identity Fields", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{}, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + + t.Run("Missing Service Identity Name", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + Datacenters: []string{"dc1"}, + }, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("Missing Policy ID", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -295,6 +404,23 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } + err := s.ACLTokenSet(2, token.Clone(), false) + require.Error(t, err) + }) + + t.Run("Missing Role ID", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + Name: "no-id", + }, + }, + } + err := s.ACLTokenSet(2, token, false) require.Error(t, err) }) @@ -312,6 +438,36 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { }, } + err := s.ACLTokenSet(2, token.Clone(), false) + require.Error(t, err) + }) + + t.Run("Unresolvable Role ID", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "9b2349b6-55d3-4901-b287-347ae725af2f", + }, + }, + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + + t.Run("Unresolvable AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + } + err := s.ACLTokenSet(2, token, false) require.Error(t, err) }) @@ -324,22 +480,35 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, + }, + }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", }, }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") require.NoError(t, err) require.Equal(t, uint64(2), idx) - // pointer equality - require.True(t, rtoken == token) + compareTokens(t, token, rtoken) require.Equal(t, uint64(2), rtoken.CreateIndex) require.Equal(t, uint64(2), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) require.Equal(t, "node-read", rtoken.Policies[0].Name) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) + require.Len(t, rtoken.ServiceIdentities, 1) + require.Equal(t, "web", rtoken.ServiceIdentities[0].ServiceName) }) t.Run("Update", func(t *testing.T) { @@ -350,12 +519,17 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { SecretID: "39171632-6f34-4411-827f-9416403687f4", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, + }, + }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", }, }, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) updated := &structs.ACLToken{ AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", @@ -365,20 +539,65 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { ID: structs.ACLPolicyGlobalManagementID, }, }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "db", + }, + }, } - require.NoError(t, s.ACLTokenSet(3, updated, false)) + require.NoError(t, s.ACLTokenSet(3, updated.Clone(), false)) idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") require.NoError(t, err) require.Equal(t, uint64(3), idx) - // pointer equality - require.True(t, rtoken == updated) + compareTokens(t, updated, rtoken) require.Equal(t, uint64(2), rtoken.CreateIndex) require.Equal(t, uint64(3), rtoken.ModifyIndex) require.Len(t, rtoken.Policies, 1) require.Equal(t, structs.ACLPolicyGlobalManagementID, rtoken.Policies[0].ID) require.Equal(t, "global-management", rtoken.Policies[0].Name) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, testRoleID_A, rtoken.Roles[0].ID) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) + require.Len(t, rtoken.ServiceIdentities, 1) + require.Equal(t, "db", rtoken.ServiceIdentities[0].ServiceName) + }) + + t.Run("New with auth method", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) + + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) + + idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + compareTokens(t, token, rtoken) + require.Equal(t, uint64(2), rtoken.CreateIndex) + require.Equal(t, uint64(2), rtoken.ModifyIndex) + require.Equal(t, "test", rtoken.AuthMethod) + require.Len(t, rtoken.Policies, 0) + require.Len(t, rtoken.ServiceIdentities, 0) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) }) } @@ -524,7 +743,7 @@ func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { Description: "first token", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, }, @@ -563,7 +782,7 @@ func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { require.Equal(t, "00ff4564-dd96-4d1b-8ad6-578a08279f79", rtokens[1].SecretID) require.Equal(t, "first token", rtokens[1].Description) require.Len(t, rtokens[1].Policies, 1) - require.Equal(t, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", rtokens[1].Policies[0].ID) + require.Equal(t, testPolicyID_A, rtokens[1].Policies[0].ID) require.Equal(t, "node-read", rtokens[1].Policies[0].Name) require.Equal(t, uint64(2), rtokens[1].CreateIndex) require.Equal(t, uint64(3), rtokens[1].ModifyIndex) @@ -665,6 +884,7 @@ func TestStateStore_ACLTokens_ListUpgradeable(t *testing.T) { func TestStateStore_ACLToken_List(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) tokens := structs.ACLTokens{ // the local token @@ -694,7 +914,7 @@ func TestStateStore_ACLToken_List(t *testing.T) { SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, }, @@ -704,93 +924,217 @@ func TestStateStore_ACLToken_List(t *testing.T) { SecretID: "f6998577-fd9b-4e6c-b202-cc3820513d32", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, Local: true, }, + // the role specific token + &structs.ACLToken{ + AccessorID: "a7715fde-8954-4c92-afbc-d84c6ecdc582", + SecretID: "77a2da3a-b479-4025-a83e-bd6b859f0cfe", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + }, + // the role specific token and local + &structs.ACLToken{ + AccessorID: "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", + SecretID: "c432d12b-3c86-4628-b74f-94ddfc7fb3ba", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + Local: true, + }, + // the method specific token + &structs.ACLToken{ + AccessorID: "74277ae1-6a9b-4035-b444-2370fe6a2cb5", + SecretID: "ab8ac834-0d35-4cb7-83c3-168203f986cd", + AuthMethod: "test", + }, + // the method specific token and local + &structs.ACLToken{ + AccessorID: "211f0360-ef53-41d3-9d4d-db84396eb6c0", + SecretID: "087a0eb4-366f-4190-ab4c-a4aa3d2562aa", + AuthMethod: "test", + Local: true, + }, } require.NoError(t, s.ACLTokenBatchSet(2, tokens, false)) type testCase struct { - name string - local bool - global bool - policy string - accessors []string + name string + local bool + global bool + policy string + role string + methodName string + accessors []string } cases := []testCase{ { - name: "Global", - local: false, - global: true, - policy: "", + name: "Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "54866514-3cf2-4fec-8a8a-710583831834", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { - name: "Local", - local: true, - global: false, - policy: "", + name: "Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "", accessors: []string{ - "4915fc9d-3726-4171-b588-6c271f45eecd", - "f1093997-b6c7-496d-bfb8-6b1b1895641b", + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, { - name: "Policy", - local: true, - global: true, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + name: "Policy", + local: true, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "4915fc9d-3726-4171-b588-6c271f45eecd", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Local", - local: true, - global: false, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + name: "Policy - Local", + local: true, + global: false, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ - "4915fc9d-3726-4171-b588-6c271f45eecd", + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Global", - local: false, - global: true, - policy: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + name: "Policy - Global", + local: false, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ - "47eea4da-bda1-48a6-901c-3e36d2d9262f", + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global }, }, { - name: "All", - local: true, - global: true, - policy: "", + name: "Role", + local: true, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", + accessors: []string{ + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + }, + }, + { + name: "Role - Local", + local: true, + global: false, + policy: "", + role: testRoleID_A, + methodName: "", + accessors: []string{ + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + }, + }, + { + name: "Role - Global", + local: false, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", + accessors: []string{ + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global + }, + }, + { + name: "AuthMethod - Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local + }, + }, + { + name: "AuthMethod - Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global + }, + }, + { + name: "All", + local: true, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, - "47eea4da-bda1-48a6-901c-3e36d2d9262f", - "4915fc9d-3726-4171-b588-6c271f45eecd", - "54866514-3cf2-4fec-8a8a-710583831834", - "f1093997-b6c7-496d-bfb8-6b1b1895641b", + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local + "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global + "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local + "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global + "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global + "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local + "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, } + for _, tc := range []struct{ policy, role, methodName string }{ + {testPolicyID_A, testRoleID_A, "test"}, + {"", testRoleID_A, "test"}, + {testPolicyID_A, "", "test"}, + {testPolicyID_A, testRoleID_A, ""}, + } { + t.Run(fmt.Sprintf("can't filter on more than one: %s/%s/%s", tc.policy, tc.role, tc.methodName), func(t *testing.T) { + _, _, err := s.ACLTokenList(nil, false, false, tc.policy, tc.role, tc.methodName) + require.Error(t, err) + }) + } + for _, tc := range cases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy) + _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role, tc.methodName) require.NoError(t, err) require.Len(t, tokens, len(tc.accessors)) tokens.Sort() @@ -817,7 +1161,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", Policies: []structs.ACLTokenPolicyLink{ structs.ACLTokenPolicyLink{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, }, }, } @@ -833,7 +1177,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { // rename the policy renamed := &structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read-renamed", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, @@ -851,7 +1195,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found := false @@ -885,7 +1229,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.True(t, found) // delete the policy - require.NoError(t, s.ACLPolicyDeleteByID(4, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286")) + require.NoError(t, s.ACLPolicyDeleteByID(4, testPolicyID_A)) // retrieve the token again _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) @@ -895,7 +1239,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Len(t, retrieved.Policies, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found = false @@ -927,6 +1271,135 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.True(t, found) } +func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { + // This test wants to ensure a couple of things. + // + // 1. Doing a token list/get should never modify the data + // tracked by memdb + // 2. Token list/get operations should return an accurate set + // of role links + t.Parallel() + s := testACLTokensStateStore(t) + + // the role specific token + token := &structs.ACLToken{ + AccessorID: "47eea4da-bda1-48a6-901c-3e36d2d9262f", + SecretID: "548bdb8e-c0d6-477b-bcc4-67fb836e9e61", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token, false)) + + _, retrieved, err := s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be identical + require.True(t, token == retrieved) + require.Len(t, retrieved.Roles, 1) + require.Equal(t, "node-read-role", retrieved.Roles[0].Name) + + // rename the role + renamed := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role-renamed", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + renamed.SetHash(true) + require.NoError(t, s.ACLRoleSet(3, renamed)) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Roles, 1) + require.Equal(t, "node-read-role-renamed", retrieved.Roles[0].Name) + + // list tokens without stale links + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") + require.NoError(t, err) + + found := false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 1) + require.Equal(t, "node-read-role-renamed", tok.Roles[0].Name) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 1) + require.Equal(t, "node-read-role-renamed", tok.Roles[0].Name) + found = true + break + } + } + require.True(t, found) + + // delete the role + require.NoError(t, s.ACLRoleDeleteByID(4, testRoleID_A)) + + // retrieve the token again + _, retrieved, err = s.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, token != retrieved) + require.Len(t, retrieved.Roles, 0) + + // list tokens without stale links + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 0) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, tokens, err = s.ACLTokenBatchGet(nil, []string{token.AccessorID}) + require.NoError(t, err) + + found = false + for _, tok := range tokens { + if tok.AccessorID == token.AccessorID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, tok != token) + require.Len(t, tok.Roles, 0) + found = true + break + } + } + require.True(t, found) +} + func TestStateStore_ACLToken_Delete(t *testing.T) { t.Parallel() @@ -945,7 +1418,7 @@ func TestStateStore_ACLToken_Delete(t *testing.T) { Local: true, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) _, rtoken, err := s.ACLTokenGetByAccessor(nil, "f1093997-b6c7-496d-bfb8-6b1b1895641b") require.NoError(t, err) @@ -973,7 +1446,7 @@ func TestStateStore_ACLToken_Delete(t *testing.T) { Local: true, } - require.NoError(t, s.ACLTokenSet(2, token, false)) + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) _, rtoken, err := s.ACLTokenGetByAccessor(nil, "f1093997-b6c7-496d-bfb8-6b1b1895641b") require.NoError(t, err) @@ -1073,7 +1546,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { s := testACLStateStore(t) policy := structs.ACLPolicy{ - ID: "2c74a9b8-271c-4a21-b727-200db397c01c", + ID: testRoleID_A, Description: "test", Rules: `keyring = "write"`, } @@ -1141,7 +1614,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { s := testACLStateStore(t) policy := structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + ID: testPolicyID_A, Name: "node-read", Description: "Allows reading all node information", Rules: `node_prefix "" { policy = "read" }`, @@ -1151,7 +1624,7 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { require.NoError(t, s.ACLPolicySet(3, &policy)) - idx, rpolicy, err := s.ACLPolicyGetByID(nil, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286") + idx, rpolicy, err := s.ACLPolicyGetByID(nil, testPolicyID_A) require.Equal(t, uint64(3), idx) require.NoError(t, err) require.NotNil(t, rpolicy) @@ -1183,8 +1656,8 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { // this creates the node read policy which we can update s := testACLTokensStateStore(t) - update := structs.ACLPolicy{ - ID: "a0625e95-9b3e-42de-a8d6-ceef5b6f3286", + update := &structs.ACLPolicy{ + ID: testPolicyID_A, Name: "node-read-modified", Description: "Modified", Rules: `node_prefix "" { policy = "read" } node "secret" { policy = "deny" }`, @@ -1192,19 +1665,29 @@ func TestStateStore_ACLPolicy_SetGet(t *testing.T) { Datacenters: []string{"dc1", "dc2"}, } - require.NoError(t, s.ACLPolicySet(3, &update)) + require.NoError(t, s.ACLPolicySet(3, update.Clone())) - idx, rpolicy, err := s.ACLPolicyGetByID(nil, "a0625e95-9b3e-42de-a8d6-ceef5b6f3286") + expect := update.Clone() + expect.CreateIndex = 2 + expect.ModifyIndex = 3 + + // policy found via id + idx, rpolicy, err := s.ACLPolicyGetByID(nil, testPolicyID_A) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Equal(t, expect, rpolicy) + + // policy no longer found via old name + idx, rpolicy, err = s.ACLPolicyGetByName(nil, "node-read") require.Equal(t, uint64(3), idx) require.NoError(t, err) - require.NotNil(t, rpolicy) - require.Equal(t, "node-read-modified", rpolicy.Name) - require.Equal(t, "Modified", rpolicy.Description) - require.Equal(t, `node_prefix "" { policy = "read" } node "secret" { policy = "deny" }`, rpolicy.Rules) - require.Equal(t, acl.SyntaxCurrent, rpolicy.Syntax) - require.ElementsMatch(t, []string{"dc1", "dc2"}, rpolicy.Datacenters) - require.Equal(t, uint64(2), rpolicy.CreateIndex) - require.Equal(t, uint64(3), rpolicy.ModifyIndex) + require.Nil(t, rpolicy) + + // policy is found via new name + idx, rpolicy, err = s.ACLPolicyGetByName(nil, "node-read-modified") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Equal(t, expect, rpolicy) }) } @@ -1467,6 +1950,1397 @@ func TestStateStore_ACLPolicy_Delete(t *testing.T) { }) } +func TestStateStore_ACLRole_SetGet(t *testing.T) { + t.Parallel() + + t.Run("Missing ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + Name: "test-role", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Service Identity Fields", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{}, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Service Identity Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + Datacenters: []string{"dc1"}, + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Missing Policy ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + Name: "no-id", + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("Unresolvable Policy ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "4f20e379-b496-4b99-9599-19a197126490", + }, + }, + } + + require.Error(t, s.ACLRoleSet(3, &role)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := structs.ACLRole{ + ID: testRoleID_A, + Name: "my-new-role", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(3, &role)) + + verify := func(idx uint64, rrole *structs.ACLRole, err error) { + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.NotNil(t, rrole) + require.Equal(t, "my-new-role", rrole.Name) + require.Equal(t, "test", rrole.Description) + require.Equal(t, uint64(3), rrole.CreateIndex) + require.Equal(t, uint64(3), rrole.ModifyIndex) + require.Len(t, rrole.ServiceIdentities, 0) + // require.ElementsMatch(t, role.Policies, rrole.Policies) + require.Len(t, rrole.Policies, 1) + require.Equal(t, "node-read", rrole.Policies[0].Name) + } + + idx, rpolicy, err := s.ACLRoleGetByID(nil, testRoleID_A) + verify(idx, rpolicy, err) + + idx, rpolicy, err = s.ACLRoleGetByName(nil, "my-new-role") + verify(idx, rpolicy, err) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + // Create the initial role + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role", + Description: "Allows reading all node information", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + role.SetHash(true) + + require.NoError(t, s.ACLRoleSet(2, role)) + + // Now make sure we can update it + update := &structs.ACLRole{ + ID: testRoleID_A, + Name: "node-read-role-modified", + Description: "Modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + update.SetHash(true) + + require.NoError(t, s.ACLRoleSet(3, update)) + + verify := func(idx uint64, rrole *structs.ACLRole, err error) { + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.NotNil(t, rrole) + require.Equal(t, "node-read-role-modified", rrole.Name) + require.Equal(t, "Modified", rrole.Description) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(3), rrole.ModifyIndex) + require.Len(t, rrole.ServiceIdentities, 0) + require.Len(t, rrole.Policies, 1) + require.Equal(t, structs.ACLPolicyGlobalManagementID, rrole.Policies[0].ID) + require.Equal(t, "global-management", rrole.Policies[0].Name) + } + + // role found via id + idx, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + verify(idx, rrole, err) + + // role no longer found via old name + idx, rrole, err = s.ACLRoleGetByName(nil, "node-read-role") + require.Equal(t, uint64(3), idx) + require.NoError(t, err) + require.Nil(t, rrole) + + // role is found via new name + idx, rrole, err = s.ACLRoleGetByName(nil, "node-read-role-modified") + verify(idx, rrole, err) + }) +} + +func TestStateStore_ACLRoles_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + idx, rroles, err := s.ACLRoleBatchGet(nil, []string{testRoleID_A, testRoleID_B}) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rroles, 2) + rroles.Sort() + require.ElementsMatch(t, roles, rroles) + require.Equal(t, uint64(2), rroles[0].CreateIndex) + require.Equal(t, uint64(2), rroles[0].ModifyIndex) + require.Equal(t, uint64(2), rroles[1].CreateIndex) + require.Equal(t, uint64(2), rroles[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + // Seed initial data. + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + // Update two roles at the same time. + updates := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1-modified", + Description: "test-role1-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_C, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2-modified", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_D, + }, + structs.ACLRolePolicyLink{ + ID: testPolicyID_E, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(3, updates)) + + idx, rroles, err := s.ACLRoleBatchGet(nil, []string{testRoleID_A, testRoleID_B}) + + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rroles, 2) + rroles.Sort() + require.Equal(t, testRoleID_A, rroles[0].ID) + require.Equal(t, "role1-modified", rroles[0].Name) + require.Equal(t, "test-role1-modified", rroles[0].Description) + require.ElementsMatch(t, updates[0].Policies, rroles[0].Policies) + require.Equal(t, uint64(2), rroles[0].CreateIndex) + require.Equal(t, uint64(3), rroles[0].ModifyIndex) + + require.Equal(t, testRoleID_B, rroles[1].ID) + require.Equal(t, "role2-modified", rroles[1].Name) + require.Equal(t, "", rroles[1].Description) + require.ElementsMatch(t, updates[1].Policies, rroles[1].Policies) + require.Equal(t, uint64(2), rroles[1].CreateIndex) + require.Equal(t, uint64(3), rroles[1].ModifyIndex) + }) +} + +func TestStateStore_ACLRole_List(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_B, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + type testCase struct { + name string + policy string + ids []string + } + + cases := []testCase{ + { + name: "Any", + policy: "", + ids: []string{ + testRoleID_A, + testRoleID_B, + }, + }, + { + name: "Policy A", + policy: testPolicyID_A, + ids: []string{ + testRoleID_A, + }, + }, + { + name: "Policy B", + policy: testPolicyID_B, + ids: []string{ + testRoleID_B, + }, + }, + } + + for _, tc := range cases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + // t.Parallel() + _, rroles, err := s.ACLRoleList(nil, tc.policy) + require.NoError(t, err) + + require.Len(t, rroles, len(tc.ids)) + rroles.Sort() + for i, rrole := range rroles { + expectID := tc.ids[i] + require.Equal(t, expectID, rrole.ID) + switch expectID { + case testRoleID_A: + require.Equal(t, testRoleID_A, rrole.ID) + require.Equal(t, "role1", rrole.Name) + require.Equal(t, "test-role1", rrole.Description) + require.ElementsMatch(t, roles[0].Policies, rrole.Policies) + require.Nil(t, rrole.Hash) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(2), rrole.ModifyIndex) + case testRoleID_B: + require.Equal(t, testRoleID_B, rrole.ID) + require.Equal(t, "role2", rrole.Name) + require.Equal(t, "test-role2", rrole.Description) + require.ElementsMatch(t, roles[1].Policies, rrole.Policies) + require.Nil(t, rrole.Hash) + require.Equal(t, uint64(2), rrole.CreateIndex) + require.Equal(t, uint64(2), rrole.ModifyIndex) + } + } + }) + } +} + +func TestStateStore_ACLRole_FixupPolicyLinks(t *testing.T) { + // This test wants to ensure a couple of things. + // + // 1. Doing a role list/get should never modify the data + // tracked by memdb + // 2. Role list/get operations should return an accurate set + // of policy links + t.Parallel() + s := testACLRolesStateStore(t) + + // the policy specific role + role := &structs.ACLRole{ + ID: "672537b1-35cb-48fc-a2cd-a1863c301b70", + Name: "test-role", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: testPolicyID_A, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, retrieved, err := s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be identical + require.True(t, role == retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read", retrieved.Policies[0].Name) + + // rename the policy + renamed := &structs.ACLPolicy{ + ID: testPolicyID_A, + Name: "node-read-renamed", + Description: "Allows reading all node information", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + } + renamed.SetHash(true) + require.NoError(t, s.ACLPolicySet(3, renamed)) + + // retrieve the role again + _, retrieved, err = s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, role != retrieved) + require.Len(t, retrieved.Policies, 1) + require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) + + // list roles without stale links + _, roles, err := s.ACLRoleList(nil, "") + require.NoError(t, err) + + found := false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 1) + require.Equal(t, "node-read-renamed", r.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, roles, err = s.ACLRoleBatchGet(nil, []string{role.ID}) + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 1) + require.Equal(t, "node-read-renamed", r.Policies[0].Name) + found = true + break + } + } + require.True(t, found) + + // delete the policy + require.NoError(t, s.ACLPolicyDeleteByID(4, testPolicyID_A)) + + // retrieve the role again + _, retrieved, err = s.ACLRoleGetByID(nil, role.ID) + require.NoError(t, err) + // pointer equality check these should be different if we cloned things appropriately + require.True(t, role != retrieved) + require.Len(t, retrieved.Policies, 0) + + // list roles without stale links + _, roles, err = s.ACLRoleList(nil, "") + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 0) + found = true + break + } + } + require.True(t, found) + + // batch get without stale links + _, roles, err = s.ACLRoleBatchGet(nil, []string{role.ID}) + require.NoError(t, err) + + found = false + for _, r := range roles { + if r.ID == role.ID { + // these pointers shouldn't be equal because the link should have been fixed + require.True(t, r != role) + require.Len(t, r.Policies, 0) + found = true + break + } + } + require.True(t, found) +} + +func TestStateStore_ACLRole_Delete(t *testing.T) { + t.Parallel() + + t.Run("ID", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleDeleteByID(3, testRoleID_A)) + require.NoError(t, err) + + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + role := &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + } + + require.NoError(t, s.ACLRoleSet(2, role)) + + _, rrole, err := s.ACLRoleGetByName(nil, "role1") + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleDeleteByName(3, "role1")) + require.NoError(t, err) + + _, rrole, err = s.ACLRoleGetByName(nil, "role1") + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLRolesStateStore(t) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: testRoleID_A, + Name: "role1", + Description: "test-role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + }, + &structs.ACLRole{ + ID: testRoleID_B, + Name: "role2", + Description: "test-role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: structs.ACLPolicyGlobalManagementID, + }, + }, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + _, rrole, err := s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.NotNil(t, rrole) + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_B) + require.NoError(t, err) + require.NotNil(t, rrole) + + require.NoError(t, s.ACLRoleBatchDelete(3, []string{testRoleID_A, testRoleID_B})) + + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_A) + require.NoError(t, err) + require.Nil(t, rrole) + _, rrole, err = s.ACLRoleGetByID(nil, testRoleID_B) + require.NoError(t, err) + require.Nil(t, rrole) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant roles is not an error + require.NoError(t, s.ACLRoleDeleteByName(3, "not-found")) + require.NoError(t, s.ACLRoleDeleteByID(3, testRoleID_A)) + }) +} + +func TestStateStore_ACLAuthMethod_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in Name+Type. + + t.Run("Missing Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "", + Type: "testing", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("Missing Type", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(3, &method)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "test", rmethod.Description) + require.Equal(t, uint64(3), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Create the initial method + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + // Now make sure we can update it + update := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + } + + require.NoError(t, s.ACLAuthMethodSet(3, &update)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "modified", rmethod.Description) + require.Equal(t, update.Config, rmethod.Config) + require.Equal(t, uint64(2), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethods_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-1", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, methods, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Seed initial data. + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Update two methods at the same time. + updates := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8444", + }, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(3, updates)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, updates, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(3), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(3), rmethods[1].ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethod_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + + require.Len(t, rmethods, 2) + rmethods.Sort() + + require.Equal(t, "test-1", rmethods[0].Name) + require.Equal(t, "testing", rmethods[0].Type) + require.Equal(t, "test-1", rmethods[0].Description) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + + require.Equal(t, "test-2", rmethods[1].Name) + require.Equal(t, "testing", rmethods[1].Type) + require.Equal(t, "test-2", rmethods[1].Description) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) +} + +func TestStateStore_ACLAuthMethod_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test")) + require.NoError(t, err) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.NotNil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodBatchDelete(3, []string{"test-1", "test-2"})) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant methods is not an error + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "not-found")) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestStateStore_ACLAuthMethod_Delete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + const ( + method1_rule1 = "dff6f8a3-0115-4b22-8661-04a497ebb23c" + method1_rule2 = "69e2d304-703d-4889-bd94-4a720c061fc3" + method2_rule1 = "997ee45c-d6ba-4da1-a98e-aaa012e7d1e2" + method2_rule2 = "9ebae132-f1f1-4b72-b1d9-a4313ac22075" + ) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: method1_rule1, + AuthMethod: "test-1", + Description: "test-m1-r1", + }, + &structs.ACLBindingRule{ + ID: method1_rule2, + AuthMethod: "test-1", + Description: "test-m1-r2", + }, + &structs.ACLBindingRule{ + ID: method2_rule1, + AuthMethod: "test-2", + Description: "test-m2-r1", + }, + &structs.ACLBindingRule{ + ID: method2_rule2, + AuthMethod: "test-2", + Description: "test-m2-r2", + }, + } + require.NoError(t, s.ACLBindingRuleBatchSet(3, rules)) + + const ( // accessors + method1_tok1 = "6d020c5d-c4fd-4348-ba79-beac37ed0b9c" + method1_tok2 = "169160dc-34ab-45c6-aba7-ff65e9ace9cb" + method2_tok1 = "8e14628e-7dde-4573-aca1-6386c0f2095d" + method2_tok2 = "291e5af9-c68e-4dd3-8824-b2bdfdcc89e6" + ) + + tokens := structs.ACLTokens{ + &structs.ACLToken{ + AccessorID: method1_tok1, + SecretID: "7a1950c6-79dc-441c-acd2-e22cd3db0240", + Description: "test-m1-t1", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method1_tok2, + SecretID: "442cee4c-353f-4957-adbb-33db2f9e267f", + Description: "test-m1-t2", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method2_tok1, + SecretID: "d9399b7d-6c34-46bd-a675-c1352fadb6fd", + Description: "test-m2-t1", + AuthMethod: "test-2", + }, + &structs.ACLToken{ + AccessorID: method2_tok2, + SecretID: "3b72fc27-9230-42ab-a1e8-02cb489ab177", + Description: "test-m2-t2", + AuthMethod: "test-2", + }, + } + require.NoError(t, s.ACLTokenBatchSet(4, tokens, false)) + + // Delete one method. + require.NoError(t, s.ACLAuthMethodDeleteByName(4, "test-1")) + + // Make sure the method is gone. + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + + // Make sure the rules and tokens are gone. + for _, ruleID := range []string{method1_rule1, method1_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.Nil(t, rrule) + } + for _, tokID := range []string{method1_tok1, method1_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.Nil(t, tok) + } + + // Make sure the rules and tokens for the untouched method are still there. + for _, ruleID := range []string{method2_rule1, method2_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.NotNil(t, rrule) + } + for _, tokID := range []string{method2_tok1, method2_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.NotNil(t, tok) + } +} + +func TestStateStore_ACLBindingRule_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in ID+AuthMethod. + + t.Run("Missing ID", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "", + AuthMethod: "test", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Missing AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Unknown AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "unknown", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &rule)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "test", rrule.Description) + require.Equal(t, uint64(3), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Create the initial rule + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + // Now make sure we can update it + update := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &update)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "modified", rrule.Description) + require.Equal(t, structs.BindingRuleBindTypeService, rrule.BindType) + require.Equal(t, "web", rrule.BindName) + require.Equal(t, uint64(2), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRules_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, rules, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Seed initial data. + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Update two rules at the same time. + updates := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(3, updates)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, updates, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(3), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(3), rrules[1].ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRule_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrules, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + + require.Len(t, rrules, 2) + rrules.Sort() + + require.Equal(t, "3ebcc27b-f8ba-4611-b385-79a065dfb983", rrules[0].ID) + require.Equal(t, "test", rrules[0].AuthMethod) + require.Equal(t, "test-1", rrules[0].Description) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + + require.Equal(t, "9669b2d7-455c-4d70-b0ac-457fd7969a2e", rrules[1].ID) + require.Equal(t, "test", rrules[1].AuthMethod) + require.Equal(t, "test-2", rrules[1].Description) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) +} + +func TestStateStore_ACLBindingRule_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rule.ID)) + require.NoError(t, err) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleBatchDelete(3, []string{rules[0].ID, rules[1].ID})) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.Nil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant rules is not an error + require.NoError(t, s.ACLBindingRuleDeleteByID(3, "ed3ce1b8-3a16-4e2f-b82e-f92e3b92410d")) + }) +} + func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { s := testStateStore(t) @@ -1493,6 +3367,35 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + Description: "role1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + }, + }, + }, + &structs.ACLRole{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + Description: "role2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + }, + }, + }, + } + + for _, role := range roles { + role.SetHash(true) + } + + require.NoError(t, s.ACLRoleBatchSet(3, roles)) + tokens := structs.ACLTokens{ &structs.ACLToken{ AccessorID: "68016c3d-835b-450c-a6f9-75db9ba740be", @@ -1508,8 +3411,16 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { Name: "policy2", }, }, - Hash: []byte{1, 2, 3, 4}, - RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + }, + structs.ACLTokenRoleLink{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + }, + }, }, &structs.ACLToken{ AccessorID: "b2125a1b-2a52-41d4-88f3-c58761998a46", @@ -1525,12 +3436,22 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { Name: "policy2", }, }, + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: "1a3a9af9-9cdc-473a-8016-010067b7e424", + Name: "role1", + }, + structs.ACLTokenRoleLink{ + ID: "4dccc2c7-10f3-4eba-b367-9c09be9a9d67", + Name: "role2", + }, + }, Hash: []byte{1, 2, 3, 4}, RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, }, } - require.NoError(t, s.ACLTokenBatchSet(2, tokens, false)) + require.NoError(t, s.ACLTokenBatchSet(4, tokens, false)) // Snapshot the ACLs. snap := s.Snapshot() @@ -1540,7 +3461,7 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLTokenDeleteByAccessor(3, tokens[0].AccessorID)) // Verify the snapshot. - require.Equal(t, uint64(2), snap.LastIndex()) + require.Equal(t, uint64(4), snap.LastIndex()) iter, err := snap.ACLTokens() require.NoError(t, err) @@ -1563,12 +3484,15 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { // need to ensure we have the policies or else the links will be removed require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + // need to ensure we have the roles or else the links will be removed + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLTokenList(nil, true, true, "") + idx, res, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) - require.Equal(t, uint64(2), idx) + require.Equal(t, uint64(4), idx) require.ElementsMatch(t, tokens, res) - require.Equal(t, uint64(2), s.maxIndex("acl-tokens")) + require.Equal(t, uint64(4), s.maxIndex("acl-tokens")) }() } @@ -1632,3 +3556,392 @@ func TestStateStore_ACLPolicies_Snapshot_Restore(t *testing.T) { require.Equal(t, uint64(2), s.maxIndex("acl-policies")) }() } + +func TestTokenPoliciesIndex(t *testing.T) { + lib.SeedMathRand() + + idIndex := &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "AccessorID", Lowercase: false}, + } + globalIndex := &memdb.IndexSchema{ + Name: "global", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: false}, + } + localIndex := &memdb.IndexSchema{ + Name: "local", + AllowMissing: true, + Unique: false, + Indexer: &TokenExpirationIndex{LocalFilter: true}, + } + schema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + "test": &memdb.TableSchema{ + Name: "test", + Indexes: map[string]*memdb.IndexSchema{ + "id": idIndex, + "global": globalIndex, + "local": localIndex, + }, + }, + }, + } + + knownUUIDs := make(map[string]struct{}) + newUUID := func() string { + for { + ret, err := uuid.GenerateUUID() + require.NoError(t, err) + if _, ok := knownUUIDs[ret]; !ok { + knownUUIDs[ret] = struct{}{} + return ret + } + } + } + + baseTime := time.Date(2010, 12, 31, 11, 30, 7, 0, time.UTC) + + newToken := func(local bool, desc string, expTime time.Time) *structs.ACLToken { + return &structs.ACLToken{ + AccessorID: newUUID(), + SecretID: newUUID(), + Description: desc, + Local: local, + ExpirationTime: &expTime, + CreateTime: baseTime, + RaftIndex: structs.RaftIndex{ + CreateIndex: 9, + ModifyIndex: 10, + }, + } + } + + db, err := memdb.NewMemDB(schema) + require.NoError(t, err) + + dumpItems := func(index string) ([]string, error) { + tx := db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("test", index) + if err != nil { + return nil, err + } + + var out []string + for raw := iter.Next(); raw != nil; raw = iter.Next() { + tok := raw.(*structs.ACLToken) + out = append(out, tok.Description) + } + + return out, nil + } + + { // insert things with no expiration time + tx := db.Txn(true) + for i := 0; i < 10; i++ { + tok := newToken(i%2 != 1, "tok["+strconv.Itoa(i)+"]", time.Time{}) + + require.NoError(t, tx.Insert("test", tok)) + } + tx.Commit() + } + + t.Run("no expiration", func(t *testing.T) { + dump, err := dumpItems("local") + require.NoError(t, err) + require.Len(t, dump, 0) + + dump, err = dumpItems("global") + require.NoError(t, err) + require.Len(t, dump, 0) + }) + + { // insert things with laddered expiration time, inserted in random order + var tokens []*structs.ACLToken + for i := 0; i < 10; i++ { + expTime := baseTime.Add(time.Duration(i+1) * time.Minute) + tok := newToken(i%2 == 0, "exp-tok["+strconv.Itoa(i)+"]", expTime) + tokens = append(tokens, tok) + } + rand.Shuffle(len(tokens), func(i, j int) { + tokens[i], tokens[j] = tokens[j], tokens[i] + }) + + tx := db.Txn(true) + for _, tok := range tokens { + require.NoError(t, tx.Insert("test", tok)) + } + tx.Commit() + } + + t.Run("mixed expiration", func(t *testing.T) { + dump, err := dumpItems("local") + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "exp-tok[0]", + "exp-tok[2]", + "exp-tok[4]", + "exp-tok[6]", + "exp-tok[8]", + }, dump) + + dump, err = dumpItems("global") + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "exp-tok[1]", + "exp-tok[3]", + "exp-tok[5]", + "exp-tok[7]", + "exp-tok[9]", + }, dump) + }) +} + +func stripIrrelevantTokenFields(token *structs.ACLToken) *structs.ACLToken { + tokenCopy := token.Clone() + // When comparing the tokens disregard the policy link names. This + // data is not cleanly updated in a variety of scenarios and should not + // be relied upon. + for i, _ := range tokenCopy.Policies { + tokenCopy.Policies[i].Name = "" + } + // Also do the same for Role links. + for i, _ := range tokenCopy.Roles { + tokenCopy.Roles[i].Name = "" + } + // The raft indexes won't match either because the requester will not + // have access to that. + tokenCopy.RaftIndex = structs.RaftIndex{} + return tokenCopy +} + +func compareTokens(t *testing.T, expected, actual *structs.ACLToken) { + require.Equal(t, stripIrrelevantTokenFields(expected), stripIrrelevantTokenFields(actual)) +} + +func TestStateStore_ACLRoles_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + policies := structs.ACLPolicies{ + &structs.ACLPolicy{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + Description: "policy1", + Rules: `node_prefix "" { policy = "read" }`, + Syntax: acl.SyntaxCurrent, + }, + &structs.ACLPolicy{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + Description: "policy2", + Rules: `acl = "read"`, + Syntax: acl.SyntaxCurrent, + }, + } + + for _, policy := range policies { + policy.SetHash(true) + } + + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + + roles := structs.ACLRoles{ + &structs.ACLRole{ + ID: "68016c3d-835b-450c-a6f9-75db9ba740be", + Name: "838f72b5-5c15-4a9e-aa6d-31734c3a0286", + Description: "policy1", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + }, + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + }, + }, + Hash: []byte{1, 2, 3, 4}, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLRole{ + ID: "b2125a1b-2a52-41d4-88f3-c58761998a46", + Name: "ba5d9239-a4ab-49b9-ae09-1f19eed92204", + Description: "policy2", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: "ca1fc52c-3676-4050-82ed-ca223e38b2c9", + Name: "policy1", + }, + structs.ACLRolePolicyLink{ + ID: "7b70fa0f-58cd-412d-93c3-a0f17bb19a3e", + Name: "policy2", + }, + }, + Hash: []byte{1, 2, 3, 4}, + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLRoleBatchSet(2, roles)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLRoleDeleteByID(3, roles[0].ID)) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLRoles() + require.NoError(t, err) + + var dump structs.ACLRoles + for role := iter.Next(); role != nil; role = iter.Next() { + dump = append(dump, role.(*structs.ACLRole)) + } + require.ElementsMatch(t, dump, roles) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, role := range dump { + require.NoError(t, restore.ACLRole(role)) + } + restore.Commit() + + // need to ensure we have the policies or else the links will be removed + require.NoError(t, s.ACLPolicyBatchSet(2, policies)) + + // Read the restored ACLs back out and verify that they match. + idx, res, err := s.ACLRoleList(nil, "") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, roles, res) + require.Equal(t, uint64(2), s.maxIndex("acl-roles")) + }() +} + +func TestStateStore_ACLAuthMethods_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test-1")) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLAuthMethods() + require.NoError(t, err) + + var dump structs.ACLAuthMethods + for method := iter.Next(); method != nil; method = iter.Next() { + dump = append(dump, method.(*structs.ACLAuthMethod)) + } + require.ElementsMatch(t, dump, methods) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, method := range dump { + require.NoError(t, restore.ACLAuthMethod(method)) + } + restore.Commit() + + // Read the restored methods back out and verify that they match. + idx, res, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, methods, res) + require.Equal(t, uint64(2), s.maxIndex("acl-auth-methods")) + }() +} + +func TestStateStore_ACLBindingRules_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rules[0].ID)) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLBindingRules() + require.NoError(t, err) + + var dump structs.ACLBindingRules + for rule := iter.Next(); rule != nil; rule = iter.Next() { + dump = append(dump, rule.(*structs.ACLBindingRule)) + } + require.ElementsMatch(t, dump, rules) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + setupExtraAuthMethods(t, s) + + restore := s.Restore() + for _, rule := range dump { + require.NoError(t, restore.ACLBindingRule(rule)) + } + restore.Commit() + + // Read the restored rules back out and verify that they match. + idx, res, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, rules, res) + require.Equal(t, uint64(2), s.maxIndex("acl-binding-rules")) + }() +} diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index 7b42c7ad4f..f38bc42da8 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/hashicorp/consul/types" - "github.com/hashicorp/go-memdb" + memdb "github.com/hashicorp/go-memdb" ) var ( @@ -37,6 +37,30 @@ var ( // policy with an empty Name. ErrMissingACLPolicyName = errors.New("Missing ACL Policy Name") + // ErrMissingACLRoleID is returned when a role set is called on + // a role with an empty ID. + ErrMissingACLRoleID = errors.New("Missing ACL Role ID") + + // ErrMissingACLRoleName is returned when a role set is called on + // a role with an empty Name. + ErrMissingACLRoleName = errors.New("Missing ACL Role Name") + + // ErrMissingACLBindingRuleID is returned when a binding rule set + // is called on a binding rule with an empty ID. + ErrMissingACLBindingRuleID = errors.New("Missing ACL Binding Rule ID") + + // ErrMissingACLBindingRuleAuthMethod is returned when a binding rule set + // is called on a binding rule with an empty AuthMethod. + ErrMissingACLBindingRuleAuthMethod = errors.New("Missing ACL Binding Rule Auth Method") + + // ErrMissingACLAuthMethodName is returned when an auth method set is + // called on an auth method with an empty Name. + ErrMissingACLAuthMethodName = errors.New("Missing ACL Auth Method Name") + + // ErrMissingACLAuthMethodType is returned when an auth method set is + // called on an auth method with an empty Type. + ErrMissingACLAuthMethodType = errors.New("Missing ACL Auth Method Type") + // ErrMissingQueryID is returned when a Query set is called on // a Query with an empty ID. ErrMissingQueryID = errors.New("Missing Query ID") diff --git a/agent/consul/util.go b/agent/consul/util.go index cb0134240a..19ff2cc2b2 100644 --- a/agent/consul/util.go +++ b/agent/consul/util.go @@ -6,10 +6,13 @@ import ( "net" "runtime" "strconv" + "strings" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-version" + "github.com/hashicorp/hil" + "github.com/hashicorp/hil/ast" "github.com/hashicorp/serf/serf" ) @@ -322,3 +325,42 @@ func ServersGetACLMode(members []serf.Member, leader string, datacenter string) return } + +// InterpolateHIL processes the string as if it were HIL and interpolates only +// the provided string->string map as possible variables. +func InterpolateHIL(s string, vars map[string]string) (string, error) { + if strings.Index(s, "${") == -1 { + // Skip going to the trouble of parsing something that has no HIL. + return s, nil + } + + tree, err := hil.Parse(s) + if err != nil { + return "", err + } + + vm := make(map[string]ast.Variable) + for k, v := range vars { + vm[k] = ast.Variable{ + Type: ast.TypeString, + Value: v, + } + } + + config := &hil.EvalConfig{ + GlobalScope: &ast.BasicScope{ + VarMap: vm, + }, + } + + result, err := hil.Eval(tree, config) + if err != nil { + return "", err + } + + if result.Type != hil.TypeString { + return "", fmt.Errorf("generated unexpected hil type: %s", result.Type) + } + + return result.Value.(string), nil +} diff --git a/agent/consul/util_test.go b/agent/consul/util_test.go index b0c6e04d30..654673b628 100644 --- a/agent/consul/util_test.go +++ b/agent/consul/util_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/go-version" "github.com/hashicorp/serf/serf" + "github.com/stretchr/testify/require" ) func TestGetPrivateIP(t *testing.T) { @@ -403,3 +404,133 @@ func TestServersMeetMinimumVersion(t *testing.T) { } } } + +func TestInterpolateHIL(t *testing.T) { + for _, test := range []struct { + name string + in string + vars map[string]string + exp string + ok bool + }{ + // valid HIL + { + "empty", + "", + map[string]string{}, + "", + true, + }, + { + "no vars", + "nothing", + map[string]string{}, + "nothing", + true, + }, + { + "just var", + "${item}", + map[string]string{"item": "value"}, + "value", + true, + }, + { + "var in middle", + "before ${item}after", + map[string]string{"item": "value"}, + "before valueafter", + true, + }, + { + "two vars", + "before ${item}after ${more}", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter xyz", + true, + }, + { + "missing map val", + "${item}", + map[string]string{"item": ""}, + "", + true, + }, + // "weird" HIL, but not technically invalid + { + "just end", + "}", + map[string]string{}, + "}", + true, + }, + { + "var without start", + " item }", + map[string]string{"item": "value"}, + " item }", + true, + }, + { + "two vars missing second start", + "before ${ item }after more }", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter more }", + true, + }, + // invalid HIL + { + "just start", + "${", + map[string]string{}, + "", + false, + }, + { + "backwards", + "}${", + map[string]string{}, + "", + false, + }, + { + "no varname", + "${}", + map[string]string{}, + "", + false, + }, + { + "missing map key", + "${item}", + map[string]string{}, + "", + false, + }, + { + "var without end", + "${ item ", + map[string]string{"item": "value"}, + "", + false, + }, + { + "two vars missing first end", + "before ${ item after ${ more }", + map[string]string{"item": "value", "more": "xyz"}, + "", + false, + }, + } { + t.Run(test.name, func(t *testing.T) { + out, err := InterpolateHIL(test.in, test.vars) + if test.ok { + require.NoError(t, err) + require.Equal(t, test.exp, out) + } else { + require.NotNil(t, err) + require.Equal(t, out, "") + } + }) + } +} diff --git a/agent/http_oss.go b/agent/http_oss.go index e524e450a9..a4584a5a49 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -10,10 +10,22 @@ func init() { registerEndpoint("/v1/acl/info/", []string{"GET"}, (*HTTPServer).ACLGet) registerEndpoint("/v1/acl/clone/", []string{"PUT"}, (*HTTPServer).ACLClone) registerEndpoint("/v1/acl/list", []string{"GET"}, (*HTTPServer).ACLList) + registerEndpoint("/v1/acl/login", []string{"POST"}, (*HTTPServer).ACLLogin) + registerEndpoint("/v1/acl/logout", []string{"POST"}, (*HTTPServer).ACLLogout) registerEndpoint("/v1/acl/replication", []string{"GET"}, (*HTTPServer).ACLReplicationStatus) registerEndpoint("/v1/acl/policies", []string{"GET"}, (*HTTPServer).ACLPolicyList) registerEndpoint("/v1/acl/policy", []string{"PUT"}, (*HTTPServer).ACLPolicyCreate) registerEndpoint("/v1/acl/policy/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLPolicyCRUD) + registerEndpoint("/v1/acl/roles", []string{"GET"}, (*HTTPServer).ACLRoleList) + registerEndpoint("/v1/acl/role", []string{"PUT"}, (*HTTPServer).ACLRoleCreate) + registerEndpoint("/v1/acl/role/name/", []string{"GET"}, (*HTTPServer).ACLRoleReadByName) + registerEndpoint("/v1/acl/role/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLRoleCRUD) + registerEndpoint("/v1/acl/binding-rules", []string{"GET"}, (*HTTPServer).ACLBindingRuleList) + registerEndpoint("/v1/acl/binding-rule", []string{"PUT"}, (*HTTPServer).ACLBindingRuleCreate) + registerEndpoint("/v1/acl/binding-rule/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLBindingRuleCRUD) + registerEndpoint("/v1/acl/auth-methods", []string{"GET"}, (*HTTPServer).ACLAuthMethodList) + registerEndpoint("/v1/acl/auth-method", []string{"PUT"}, (*HTTPServer).ACLAuthMethodCreate) + registerEndpoint("/v1/acl/auth-method/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLAuthMethodCRUD) registerEndpoint("/v1/acl/rules/translate", []string{"POST"}, (*HTTPServer).ACLRulesTranslate) registerEndpoint("/v1/acl/rules/translate/", []string{"GET"}, (*HTTPServer).ACLRulesTranslateLegacyToken) registerEndpoint("/v1/acl/tokens", []string{"GET"}, (*HTTPServer).ACLTokenList) diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 5e052b414f..0fc63a12cd 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "hash" "hash/fnv" "sort" "strings" @@ -84,6 +85,22 @@ session_prefix "" { // This is the policy ID for anonymous access. This is configurable by the // user. ACLTokenAnonymousID = "00000000-0000-0000-0000-000000000002" + + // aclPolicyTemplateServiceIdentity is the template used for synthesizing + // policies for service identities. + aclPolicyTemplateServiceIdentity = ` +service "%s" { + policy = "write" +} +service "%s-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}` ) func ACLIDReserved(id string) bool { @@ -112,7 +129,10 @@ type ACLIdentity interface { ID() string SecretToken() string PolicyIDs() []string + RoleIDs() []string EmbeddedPolicy() *ACLPolicy + ServiceIdentityList() []*ACLServiceIdentity + IsExpired(asOf time.Time) bool } type ACLTokenPolicyLink struct { @@ -120,6 +140,65 @@ type ACLTokenPolicyLink struct { Name string `hash:"ignore"` } +type ACLTokenRoleLink struct { + ID string + Name string `hash:"ignore"` +} + +// ACLServiceIdentity represents a high-level grant of all necessary privileges +// to assume the identity of the named Service in the Catalog and within +// Connect. +type ACLServiceIdentity struct { + ServiceName string + + // Datacenters that the synthetic policy will be valid within. + // - No wildcards allowed + // - If empty then the synthetic policy is valid within all datacenters + // + // Only valid for global tokens. It is an error to specify this for local tokens. + Datacenters []string `json:",omitempty"` +} + +func (s *ACLServiceIdentity) Clone() *ACLServiceIdentity { + s2 := *s + s2.Datacenters = cloneStringSlice(s.Datacenters) + return &s2 +} + +func (s *ACLServiceIdentity) AddToHash(h hash.Hash) { + h.Write([]byte(s.ServiceName)) + for _, dc := range s.Datacenters { + h.Write([]byte(dc)) + } +} + +func (s *ACLServiceIdentity) EstimateSize() int { + size := len(s.ServiceName) + for _, dc := range s.Datacenters { + size += len(dc) + } + return size +} + +func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy { + // Given that we validate this string name before persisting, we do not + // have to escape it before doing the following interpolation. + rules := fmt.Sprintf(aclPolicyTemplateServiceIdentity, s.ServiceName, s.ServiceName) + + hasher := fnv.New128a() + hashID := fmt.Sprintf("%x", hasher.Sum([]byte(rules))) + + policy := &ACLPolicy{} + policy.ID = hashID + policy.Name = fmt.Sprintf("synthetic-policy-%s", hashID) + policy.Description = "synthetic policy" + policy.Rules = rules + policy.Syntax = acl.SyntaxCurrent + policy.Datacenters = s.Datacenters + policy.SetHash(true) + return policy +} + type ACLToken struct { // This is the UUID used for tracking and management purposes AccessorID string @@ -130,10 +209,18 @@ type ACLToken struct { // Human readable string to display for the token (Optional) Description string - // List of policy links - nil/empty for legacy tokens + // List of policy links - nil/empty for legacy tokens or if service identities are in use. // Note this is the list of IDs and not the names. Prior to token creation // the list of policy names gets validated and the policy IDs get stored herein - Policies []ACLTokenPolicyLink + Policies []ACLTokenPolicyLink `json:",omitempty"` + + // List of role links. Note this is the list of IDs and not the names. + // Prior to token creation the list of role names gets validated and the + // role IDs get stored herein + Roles []ACLTokenRoleLink `json:",omitempty"` + + // List of services to generate synthetic policies for. + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` // Type is the V1 Token Type // DEPRECATED (ACL-Legacy-Compat) - remove once we no longer support v1 ACL compat @@ -150,6 +237,26 @@ type ACLToken struct { // to the ACL datacenter and replicated to others. Local bool + // AuthMethod is the name of the auth method used to create this token. + AuthMethod string `json:",omitempty"` + + // ExpirationTime represents the point after which a token should be + // considered revoked and is eligible for destruction. The zero value + // represents NO expiration. + // + // This is a pointer value so that the zero value is omitted properly + // during json serialization. time.Time does not respect json omitempty + // directives unfortunately. + ExpirationTime *time.Time `json:",omitempty"` + + // ExpirationTTL is a convenience field for helping set ExpirationTime to a + // value of CreateTime+ExpirationTTL. This can only be set during + // TokenCreate and is cleared and used to initialize the ExpirationTime + // field before being persisted to the state store or raft log. + // + // This is a string version of a time.Duration like "2m". + ExpirationTTL time.Duration `json:",omitempty"` + // The time when this token was created CreateTime time.Time `json:",omitempty"` @@ -167,11 +274,23 @@ type ACLToken struct { func (t *ACLToken) Clone() *ACLToken { t2 := *t t2.Policies = nil + t2.Roles = nil + t2.ServiceIdentities = nil if len(t.Policies) > 0 { t2.Policies = make([]ACLTokenPolicyLink, len(t.Policies)) copy(t2.Policies, t.Policies) } + if len(t.Roles) > 0 { + t2.Roles = make([]ACLTokenRoleLink, len(t.Roles)) + copy(t2.Roles, t.Roles) + } + if len(t.ServiceIdentities) > 0 { + t2.ServiceIdentities = make([]*ACLServiceIdentity, len(t.ServiceIdentities)) + for i, s := range t.ServiceIdentities { + t2.ServiceIdentities[i] = s.Clone() + } + } return &t2 } @@ -184,13 +303,62 @@ func (t *ACLToken) SecretToken() string { } func (t *ACLToken) PolicyIDs() []string { - var ids []string + if len(t.Policies) == 0 { + return nil + } + + ids := make([]string, 0, len(t.Policies)) for _, link := range t.Policies { ids = append(ids, link.ID) } return ids } +func (t *ACLToken) RoleIDs() []string { + if len(t.Roles) == 0 { + return nil + } + + ids := make([]string, 0, len(t.Roles)) + for _, link := range t.Roles { + ids = append(ids, link.ID) + } + return ids +} + +func (t *ACLToken) ServiceIdentityList() []*ACLServiceIdentity { + if len(t.ServiceIdentities) == 0 { + return nil + } + + out := make([]*ACLServiceIdentity, 0, len(t.ServiceIdentities)) + for _, s := range t.ServiceIdentities { + out = append(out, s.Clone()) + } + return out +} + +func (t *ACLToken) IsExpired(asOf time.Time) bool { + if asOf.IsZero() || !t.HasExpirationTime() { + return false + } + return t.ExpirationTime.Before(asOf) +} + +func (t *ACLToken) HasExpirationTime() bool { + return t.ExpirationTime != nil && !t.ExpirationTime.IsZero() +} + +func (t *ACLToken) UsesNonLegacyFields() bool { + return len(t.Policies) > 0 || + len(t.ServiceIdentities) > 0 || + len(t.Roles) > 0 || + t.Type == "" || + t.HasExpirationTime() || + t.ExpirationTTL != 0 || + t.AuthMethod != "" +} + func (t *ACLToken) EmbeddedPolicy() *ACLPolicy { // DEPRECATED (ACL-Legacy-Compat) // @@ -229,6 +397,14 @@ func (t *ACLToken) SetHash(force bool) []byte { panic(err) } + // Any non-immutable "content" fields should be involved with the + // overall hash. The IDs are immutable which is why they aren't here. + // The raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a token + // has changed and should be updated locally. + // Write all the user set fields hash.Write([]byte(t.Description)) hash.Write([]byte(t.Type)) @@ -244,6 +420,14 @@ func (t *ACLToken) SetHash(force bool) []byte { hash.Write([]byte(link.ID)) } + for _, link := range t.Roles { + hash.Write([]byte(link.ID)) + } + + for _, srvid := range t.ServiceIdentities { + srvid.AddToHash(hash) + } + // Finalize the hash hashVal := hash.Sum(nil) @@ -254,11 +438,17 @@ func (t *ACLToken) SetHash(force bool) []byte { } func (t *ACLToken) EstimateSize() int { - // 33 = 16 (RaftIndex) + 8 (Hash) + 8 (CreateTime) + 1 (Local) - size := 33 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + // 41 = 16 (RaftIndex) + 8 (Hash) + 8 (ExpirationTime) + 8 (CreateTime) + 1 (Local) + size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + len(t.AuthMethod) for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } + for _, link := range t.Roles { + size += len(link.ID) + len(link.Name) + } + for _, srvid := range t.ServiceIdentities { + size += srvid.EstimateSize() + } return size } @@ -266,30 +456,38 @@ func (t *ACLToken) EstimateSize() int { type ACLTokens []*ACLToken type ACLTokenListStub struct { - AccessorID string - Description string - Policies []ACLTokenPolicyLink - Local bool - CreateTime time.Time `json:",omitempty"` - Hash []byte - CreateIndex uint64 - ModifyIndex uint64 - Legacy bool `json:",omitempty"` + AccessorID string + Description string + Policies []ACLTokenPolicyLink `json:",omitempty"` + Roles []ACLTokenRoleLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + AuthMethod string `json:",omitempty"` + ExpirationTime *time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte + CreateIndex uint64 + ModifyIndex uint64 + Legacy bool `json:",omitempty"` } type ACLTokenListStubs []*ACLTokenListStub func (token *ACLToken) Stub() *ACLTokenListStub { return &ACLTokenListStub{ - AccessorID: token.AccessorID, - Description: token.Description, - Policies: token.Policies, - Local: token.Local, - CreateTime: token.CreateTime, - Hash: token.Hash, - CreateIndex: token.CreateIndex, - ModifyIndex: token.ModifyIndex, - Legacy: token.Rules != "", + AccessorID: token.AccessorID, + Description: token.Description, + Policies: token.Policies, + Roles: token.Roles, + ServiceIdentities: token.ServiceIdentities, + Local: token.Local, + AuthMethod: token.AuthMethod, + ExpirationTime: token.ExpirationTime, + CreateTime: token.CreateTime, + Hash: token.Hash, + CreateIndex: token.CreateIndex, + ModifyIndex: token.ModifyIndex, + Legacy: token.Rules != "", } } @@ -343,11 +541,7 @@ type ACLPolicy struct { func (p *ACLPolicy) Clone() *ACLPolicy { p2 := *p - p2.Datacenters = nil - if len(p.Datacenters) > 0 { - p2.Datacenters = make([]string, len(p.Datacenters)) - copy(p2.Datacenters, p.Datacenters) - } + p2.Datacenters = cloneStringSlice(p.Datacenters) return &p2 } @@ -384,6 +578,14 @@ func (p *ACLPolicy) SetHash(force bool) []byte { panic(err) } + // Any non-immutable "content" fields should be involved with the + // overall hash. The ID is immutable which is why it isn't here. The + // raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a policy + // has changed and should be updated locally. + // Write all the user set fields hash.Write([]byte(p.Name)) hash.Write([]byte(p.Description)) @@ -414,7 +616,7 @@ func (p *ACLPolicy) EstimateSize() int { return size } -// ACLPolicyListHash returns a consistent hash for a set of policies. +// HashKey returns a consistent hash for a set of policies. func (policies ACLPolicies) HashKey() string { cacheKeyHash, err := blake2b.New256(nil) if err != nil { @@ -500,14 +702,288 @@ func (policies ACLPolicies) Merge(cache *ACLCaches, sentinel sentinel.Evaluator) return acl.MergePolicies(parsed), nil } +type ACLRoles []*ACLRole + +// HashKey returns a consistent hash for a set of roles. +func (roles ACLRoles) HashKey() string { + cacheKeyHash, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + for _, role := range roles { + cacheKeyHash.Write([]byte(role.ID)) + // including the modify index prevents a role set from being + // cached if one of the roles has changed + binary.Write(cacheKeyHash, binary.BigEndian, role.ModifyIndex) + } + return fmt.Sprintf("%x", cacheKeyHash.Sum(nil)) +} + +func (roles ACLRoles) Sort() { + sort.Slice(roles, func(i, j int) bool { + return roles[i].ID < roles[j].ID + }) +} + +type ACLRolePolicyLink struct { + ID string + Name string `hash:"ignore"` +} + +type ACLRole struct { + // ID is the internal UUID associated with the role + ID string + + // Name is the unique name to reference the role by. + Name string + + // Description is a human readable description (Optional) + Description string + + // List of policy links. + // Note this is the list of IDs and not the names. Prior to role creation + // the list of policy names gets validated and the policy IDs get stored herein + Policies []ACLRolePolicyLink `json:",omitempty"` + + // List of services to generate synthetic policies for. + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + + // Hash of the contents of the role + // This does not take into account the ID (which is immutable) + // nor the raft metadata. + // + // This is needed mainly for replication purposes. When replicating from + // one DC to another keeping the content Hash will allow us to avoid + // unnecessary calls to the authoritative DC + Hash []byte + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + +func (r *ACLRole) Clone() *ACLRole { + r2 := *r + r2.Policies = nil + r2.ServiceIdentities = nil + + if len(r.Policies) > 0 { + r2.Policies = make([]ACLRolePolicyLink, len(r.Policies)) + copy(r2.Policies, r.Policies) + } + if len(r.ServiceIdentities) > 0 { + r2.ServiceIdentities = make([]*ACLServiceIdentity, len(r.ServiceIdentities)) + for i, s := range r.ServiceIdentities { + r2.ServiceIdentities[i] = s.Clone() + } + } + return &r2 +} + +func (r *ACLRole) SetHash(force bool) []byte { + if force || r.Hash == nil { + // Initialize a 256bit Blake2 hash (32 bytes) + hash, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + + // Any non-immutable "content" fields should be involved with the + // overall hash. The ID is immutable which is why it isn't here. The + // raft indices are metadata similar to the hash which is why they + // aren't incorporated. CreateTime is similarly immutable + // + // The Hash is really only used for replication to determine if a role + // has changed and should be updated locally. + + // Write all the user set fields + hash.Write([]byte(r.Name)) + hash.Write([]byte(r.Description)) + for _, link := range r.Policies { + hash.Write([]byte(link.ID)) + } + for _, srvid := range r.ServiceIdentities { + srvid.AddToHash(hash) + } + + // Finalize the hash + hashVal := hash.Sum(nil) + + // Set and return the hash + r.Hash = hashVal + } + return r.Hash +} + +func (r *ACLRole) EstimateSize() int { + // This is just an estimate. There is other data structure overhead + // pointers etc that this does not account for. + + // 60 = 36 (uuid) + 16 (RaftIndex) + 8 (Hash) + size := 60 + len(r.Name) + len(r.Description) + for _, link := range r.Policies { + size += len(link.ID) + len(link.Name) + } + for _, srvid := range r.ServiceIdentities { + size += srvid.EstimateSize() + } + + return size +} + +const ( + // BindingRuleBindTypeService is the binding rule bind type that + // assigns a Service Identity to the token that is created using the value + // of the computed BindName as the ServiceName like: + // + // &ACLToken{ + // ...other fields... + // ServiceIdentities: []*ACLServiceIdentity{ + // &ACLServiceIdentity{ + // ServiceName: "", + // }, + // }, + // } + BindingRuleBindTypeService = "service" + + // BindingRuleBindTypeRole is the binding rule bind type that only allows + // the binding rule to function if a role with the given name (BindName) + // exists at login-time. If it does the token that is created is directly + // linked to that role like: + // + // &ACLToken{ + // ...other fields... + // Roles: []ACLTokenRoleLink{ + // { Name: "" } + // } + // } + // + // If it does not exist at login-time the rule is ignored. + BindingRuleBindTypeRole = "role" +) + +type ACLBindingRule struct { + // ID is the internal UUID associated with the binding rule + ID string + + // Description is a human readable description (Optional) + Description string + + // AuthMethod is the name of the auth method for which this rule applies. + AuthMethod string + + // Selector is an expression that matches against verified identity + // attributes returned from the auth method during login. + Selector string + + // BindType adjusts how this binding rule is applied at login time. The + // valid values are: + // + // - BindingRuleBindTypeService = "service" + // - BindingRuleBindTypeRole = "role" + BindType string + + // BindName is the target of the binding. Can be lightly templated using + // HIL ${foo} syntax from available field names. How it is used depends + // upon the BindType. + BindName string + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + +func (r *ACLBindingRule) Clone() *ACLBindingRule { + r2 := *r + return &r2 +} + +type ACLBindingRules []*ACLBindingRule + +func (rules ACLBindingRules) Sort() { + sort.Slice(rules, func(i, j int) bool { + return rules[i].ID < rules[j].ID + }) +} + +type ACLAuthMethodListStub struct { + Name string + Description string + Type string + CreateIndex uint64 + ModifyIndex uint64 +} + +func (p *ACLAuthMethod) Stub() *ACLAuthMethodListStub { + return &ACLAuthMethodListStub{ + Name: p.Name, + Description: p.Description, + Type: p.Type, + CreateIndex: p.CreateIndex, + ModifyIndex: p.ModifyIndex, + } +} + +type ACLAuthMethods []*ACLAuthMethod +type ACLAuthMethodListStubs []*ACLAuthMethodListStub + +func (methods ACLAuthMethods) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +func (methods ACLAuthMethodListStubs) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +type ACLAuthMethod struct { + // Name is a unique identifier for this specific auth method. + // + // Immutable once set and only settable during create. + Name string + + // Type is the type of the auth method this is. + // + // Immutable once set and only settable during create. + Type string + + // Description is just an optional bunch of explanatory text. + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + type ACLReplicationType string const ( ACLReplicateLegacy ACLReplicationType = "legacy" ACLReplicatePolicies ACLReplicationType = "policies" + ACLReplicateRoles ACLReplicationType = "roles" ACLReplicateTokens ACLReplicationType = "tokens" ) +func (t ACLReplicationType) SingularNoun() string { + switch t { + case ACLReplicateLegacy: + return "legacy" + case ACLReplicatePolicies: + return "policy" + case ACLReplicateRoles: + return "role" + case ACLReplicateTokens: + return "token" + default: + return "" + } +} + // ACLReplicationStatus provides information about the health of the ACL // replication system. type ACLReplicationStatus struct { @@ -516,6 +992,7 @@ type ACLReplicationStatus struct { SourceDatacenter string ReplicationType ACLReplicationType ReplicatedIndex uint64 + ReplicatedRoleIndex uint64 ReplicatedTokenIndex uint64 LastSuccess time.Time LastError time.Time @@ -561,6 +1038,8 @@ type ACLTokenListRequest struct { IncludeLocal bool // Whether local tokens should be included IncludeGlobal bool // Whether global tokens should be included Policy string // Policy filter + Role string // Role filter + AuthMethod string // Auth Method filter Datacenter string // The datacenter to perform the request within QueryOptions } @@ -719,3 +1198,259 @@ type ACLPolicyBatchSetRequest struct { type ACLPolicyBatchDeleteRequest struct { PolicyIDs []string } + +func cloneStringSlice(s []string) []string { + if len(s) == 0 { + return nil + } + out := make([]string, len(s)) + copy(out, s) + return out +} + +// ACLRoleSetRequest is used at the RPC layer for creation and update requests +type ACLRoleSetRequest struct { + Role ACLRole // The role to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLRoleSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleDeleteRequest is used at the RPC layer deletion requests +type ACLRoleDeleteRequest struct { + RoleID string // id of the role to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLRoleDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleGetRequest is used at the RPC layer to perform role read operations +type ACLRoleGetRequest struct { + RoleID string // id used for the role lookup (one of RoleID or RoleName is allowed) + RoleName string // name used for the role lookup (one of RoleID or RoleName is allowed) + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleListRequest is used at the RPC layer to request a listing of roles +type ACLRoleListRequest struct { + Policy string // Policy filter + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLRoleListResponse struct { + Roles ACLRoles + QueryMeta +} + +// ACLRoleBatchGetRequest is used at the RPC layer to request a subset of +// the roles associated with the token used for retrieval +type ACLRoleBatchGetRequest struct { + RoleIDs []string // List of role ids to fetch + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLRoleBatchGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLRoleResponse returns a single role + metadata +type ACLRoleResponse struct { + Role *ACLRole + QueryMeta +} + +type ACLRoleBatchResponse struct { + Roles []*ACLRole + QueryMeta +} + +// ACLRoleBatchSetRequest is used at the Raft layer for batching +// multiple role creations and updates +// +// This is particularly useful during replication +type ACLRoleBatchSetRequest struct { + Roles ACLRoles +} + +// ACLRoleBatchDeleteRequest is used at the Raft layer for batching +// multiple role deletions +// +// This is particularly useful during replication +type ACLRoleBatchDeleteRequest struct { + RoleIDs []string +} + +// ACLBindingRuleSetRequest is used at the RPC layer for creation and update requests +type ACLBindingRuleSetRequest struct { + BindingRule ACLBindingRule // The rule to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleDeleteRequest is used at the RPC layer deletion requests +type ACLBindingRuleDeleteRequest struct { + BindingRuleID string // id of the rule to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleGetRequest is used at the RPC layer to perform rule read operations +type ACLBindingRuleGetRequest struct { + BindingRuleID string // id used for the rule lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleListRequest is used at the RPC layer to request a listing of rules +type ACLBindingRuleListRequest struct { + AuthMethod string // optional filter + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLBindingRuleListResponse struct { + BindingRules ACLBindingRules + QueryMeta +} + +// ACLBindingRuleResponse returns a single binding + metadata +type ACLBindingRuleResponse struct { + BindingRule *ACLBindingRule + QueryMeta +} + +// ACLBindingRuleBatchSetRequest is used at the Raft layer for batching +// multiple rule creations and updates +type ACLBindingRuleBatchSetRequest struct { + BindingRules ACLBindingRules +} + +// ACLBindingRuleBatchDeleteRequest is used at the Raft layer for batching +// multiple rule deletions +type ACLBindingRuleBatchDeleteRequest struct { + BindingRuleIDs []string +} + +// ACLAuthMethodSetRequest is used at the RPC layer for creation and update requests +type ACLAuthMethodSetRequest struct { + AuthMethod ACLAuthMethod // The auth method to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodDeleteRequest is used at the RPC layer deletion requests +type ACLAuthMethodDeleteRequest struct { + AuthMethodName string // name of the auth method to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodGetRequest is used at the RPC layer to perform rule read operations +type ACLAuthMethodGetRequest struct { + AuthMethodName string // name used for the auth method lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodListRequest is used at the RPC layer to request a listing of auth methods +type ACLAuthMethodListRequest struct { + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLAuthMethodListResponse struct { + AuthMethods ACLAuthMethodListStubs + QueryMeta +} + +// ACLAuthMethodResponse returns a single auth method + metadata +type ACLAuthMethodResponse struct { + AuthMethod *ACLAuthMethod + QueryMeta +} + +// ACLAuthMethodBatchSetRequest is used at the Raft layer for batching +// multiple auth method creations and updates +type ACLAuthMethodBatchSetRequest struct { + AuthMethods ACLAuthMethods +} + +// ACLAuthMethodBatchDeleteRequest is used at the Raft layer for batching +// multiple auth method deletions +type ACLAuthMethodBatchDeleteRequest struct { + AuthMethodNames []string +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + +type ACLLoginRequest struct { + Auth *ACLLoginParams + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLoginRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLLogoutRequest struct { + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLogoutRequest) RequestDatacenter() string { + return r.Datacenter +} diff --git a/agent/structs/acl_cache.go b/agent/structs/acl_cache.go index 9e7df64053..1494727070 100644 --- a/agent/structs/acl_cache.go +++ b/agent/structs/acl_cache.go @@ -4,7 +4,7 @@ import ( "time" "github.com/hashicorp/consul/acl" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" ) type ACLCachesConfig struct { @@ -12,6 +12,7 @@ type ACLCachesConfig struct { Policies int ParsedPolicies int Authorizers int + Roles int } type ACLCaches struct { @@ -19,6 +20,7 @@ type ACLCaches struct { parsedPolicies *lru.TwoQueueCache // policy content hash -> acl.Policy policies *lru.TwoQueueCache // policy ID -> ACLPolicy authorizers *lru.TwoQueueCache // token secret -> acl.Authorizer + roles *lru.TwoQueueCache // role ID -> ACLRole } type IdentityCacheEntry struct { @@ -58,6 +60,15 @@ func (e *AuthorizerCacheEntry) Age() time.Duration { return time.Since(e.CacheTime) } +type RoleCacheEntry struct { + Role *ACLRole + CacheTime time.Time +} + +func (e *RoleCacheEntry) Age() time.Duration { + return time.Since(e.CacheTime) +} + func NewACLCaches(config *ACLCachesConfig) (*ACLCaches, error) { cache := &ACLCaches{} @@ -97,6 +108,15 @@ func NewACLCaches(config *ACLCachesConfig) (*ACLCaches, error) { cache.authorizers = authCache } + if config != nil && config.Roles > 0 { + roleCache, err := lru.New2Q(config.Roles) + if err != nil { + return nil, err + } + + cache.roles = roleCache + } + return cache, nil } @@ -152,6 +172,19 @@ func (c *ACLCaches) GetAuthorizer(id string) *AuthorizerCacheEntry { return nil } +// GetRole fetches a role from the cache by id and returns it +func (c *ACLCaches) GetRole(roleID string) *RoleCacheEntry { + if c == nil || c.roles == nil { + return nil + } + + if raw, ok := c.roles.Get(roleID); ok { + return raw.(*RoleCacheEntry) + } + + return nil +} + // PutIdentity adds a new identity to the cache func (c *ACLCaches) PutIdentity(id string, ident ACLIdentity) { if c == nil || c.identities == nil { @@ -193,6 +226,14 @@ func (c *ACLCaches) PutAuthorizerWithTTL(id string, authorizer acl.Authorizer, t c.authorizers.Add(id, &AuthorizerCacheEntry{Authorizer: authorizer, CacheTime: time.Now(), TTL: ttl}) } +func (c *ACLCaches) PutRole(roleID string, role *ACLRole) { + if c == nil || c.roles == nil { + return + } + + c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) +} + func (c *ACLCaches) RemoveIdentity(id string) { if c != nil && c.identities != nil { c.identities.Remove(id) @@ -205,6 +246,12 @@ func (c *ACLCaches) RemovePolicy(policyID string) { } } +func (c *ACLCaches) RemoveRole(roleID string) { + if c != nil && c.roles != nil { + c.roles.Remove(roleID) + } +} + func (c *ACLCaches) Purge() { if c != nil { if c.identities != nil { @@ -219,5 +266,8 @@ func (c *ACLCaches) Purge() { if c.authorizers != nil { c.authorizers.Purge() } + if c.roles != nil { + c.roles.Purge() + } } } diff --git a/agent/structs/acl_cache_test.go b/agent/structs/acl_cache_test.go index 471dc408a2..337d1860f3 100644 --- a/agent/structs/acl_cache_test.go +++ b/agent/structs/acl_cache_test.go @@ -16,7 +16,7 @@ func TestStructs_ACLCaches(t *testing.T) { t.Run("Valid Sizes", func(t *testing.T) { t.Parallel() // 1 isn't valid due to a bug in golang-lru library - config := ACLCachesConfig{2, 2, 2, 2} + config := ACLCachesConfig{2, 2, 2, 2, 2} cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -30,7 +30,7 @@ func TestStructs_ACLCaches(t *testing.T) { t.Run("Zero Sizes", func(t *testing.T) { t.Parallel() // 1 isn't valid due to a bug in golang-lru library - config := ACLCachesConfig{0, 0, 0, 0} + config := ACLCachesConfig{0, 0, 0, 0, 0} cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -102,4 +102,20 @@ func TestStructs_ACLCaches(t *testing.T) { require.NotNil(t, entry.Authorizer) require.True(t, entry.Authorizer == acl.DenyAll()) }) + + t.Run("Roles", func(t *testing.T) { + t.Parallel() + // 1 isn't valid due to a bug in golang-lru library + config := ACLCachesConfig{Roles: 4} + + cache, err := NewACLCaches(&config) + require.NoError(t, err) + require.NotNil(t, cache) + + cache.PutRole("foo", &ACLRole{}) + + entry := cache.GetRole("foo") + require.NotNil(t, entry) + require.NotNil(t, entry.Role) + }) } diff --git a/agent/structs/acl_legacy.go b/agent/structs/acl_legacy.go index ebd8ece82b..fa182d8884 100644 --- a/agent/structs/acl_legacy.go +++ b/agent/structs/acl_legacy.go @@ -73,14 +73,15 @@ func (a *ACL) Convert() *ACLToken { } return &ACLToken{ - AccessorID: "", - SecretID: a.ID, - Description: a.Name, - Policies: nil, - Type: a.Type, - Rules: a.Rules, - Local: false, - RaftIndex: a.RaftIndex, + AccessorID: "", + SecretID: a.ID, + Description: a.Name, + Policies: nil, + ServiceIdentities: nil, + Type: a.Type, + Rules: a.Rules, + Local: false, + RaftIndex: a.RaftIndex, } } diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index fba38545bf..a7860a49d6 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -140,6 +140,70 @@ func TestStructs_ACLToken_EmbeddedPolicy(t *testing.T) { }) } +func TestStructs_ACLServiceIdentity_SyntheticPolicy(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + serviceName string + datacenters []string + expectRules string + }{ + {"web", nil, ` +service "web" { + policy = "write" +} +service "web-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}`}, + {"companion-cube-99", []string{"dc1", "dc2"}, ` +service "companion-cube-99" { + policy = "write" +} +service "companion-cube-99-sidecar-proxy" { + policy = "write" +} +service_prefix "" { + policy = "read" +} +node_prefix "" { + policy = "read" +}`}, + } { + name := test.serviceName + if len(test.datacenters) > 0 { + name += " [" + strings.Join(test.datacenters, ", ") + "]" + } + t.Run(name, func(t *testing.T) { + svcid := &ACLServiceIdentity{ + ServiceName: test.serviceName, + Datacenters: test.datacenters, + } + + expect := &ACLPolicy{ + Syntax: acl.SyntaxCurrent, + Datacenters: test.datacenters, + Description: "synthetic policy", + Rules: test.expectRules, + } + + got := svcid.SyntheticPolicy() + require.NotEmpty(t, got.ID) + require.True(t, strings.HasPrefix(got.Name, "synthetic-policy-")) + // strip irrelevant fields before equality + got.ID = "" + got.Name = "" + got.Hash = nil + require.Equal(t, expect, got) + }) + } +} + func TestStructs_ACLToken_SetHash(t *testing.T) { t.Parallel() @@ -208,7 +272,7 @@ func TestStructs_ACLToken_EstimateSize(t *testing.T) { // this test is very contrived. Basically just tests that the // math is okay and returns the value. - require.Equal(t, 120, token.EstimateSize()) + require.Equal(t, 128, token.EstimateSize()) } func TestStructs_ACLToken_Stub(t *testing.T) { @@ -451,6 +515,7 @@ func TestStructs_ACLPolicies_resolveWithCache(t *testing.T) { Policies: 0, ParsedPolicies: 4, Authorizers: 0, + Roles: 0, } cache, err := NewACLCaches(&config) require.NoError(t, err) @@ -543,6 +608,7 @@ func TestStructs_ACLPolicies_Compile(t *testing.T) { Policies: 0, ParsedPolicies: 4, Authorizers: 2, + Roles: 0, } cache, err := NewACLCaches(&config) require.NoError(t, err) diff --git a/agent/structs/structs.go b/agent/structs/structs.go index c1acf7a5b4..56a19fc4e7 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -33,29 +33,35 @@ type RaftIndex struct { // These are serialized between Consul servers and stored in Consul snapshots, // so entries must only ever be added. const ( - RegisterRequestType MessageType = 0 - DeregisterRequestType = 1 - KVSRequestType = 2 - SessionRequestType = 3 - ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) - TombstoneRequestType = 5 - CoordinateBatchUpdateType = 6 - PreparedQueryRequestType = 7 - TxnRequestType = 8 - AutopilotRequestType = 9 - AreaRequestType = 10 - ACLBootstrapRequestType = 11 - IntentionRequestType = 12 - ConnectCARequestType = 13 - ConnectCAProviderStateType = 14 - ConnectCAConfigType = 15 // FSM snapshots only. - IndexRequestType = 16 // FSM snapshots only. - ACLTokenSetRequestType = 17 - ACLTokenDeleteRequestType = 18 - ACLPolicySetRequestType = 19 - ACLPolicyDeleteRequestType = 20 - ConnectCALeafRequestType = 21 - ConfigEntryRequestType = 22 + RegisterRequestType MessageType = 0 + DeregisterRequestType = 1 + KVSRequestType = 2 + SessionRequestType = 3 + ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) + TombstoneRequestType = 5 + CoordinateBatchUpdateType = 6 + PreparedQueryRequestType = 7 + TxnRequestType = 8 + AutopilotRequestType = 9 + AreaRequestType = 10 + ACLBootstrapRequestType = 11 + IntentionRequestType = 12 + ConnectCARequestType = 13 + ConnectCAProviderStateType = 14 + ConnectCAConfigType = 15 // FSM snapshots only. + IndexRequestType = 16 // FSM snapshots only. + ACLTokenSetRequestType = 17 + ACLTokenDeleteRequestType = 18 + ACLPolicySetRequestType = 19 + ACLPolicyDeleteRequestType = 20 + ConnectCALeafRequestType = 21 + ConfigEntryRequestType = 22 + ACLRoleSetRequestType = 23 + ACLRoleDeleteRequestType = 24 + ACLBindingRuleSetRequestType = 25 + ACLBindingRuleDeleteRequestType = 26 + ACLAuthMethodSetRequestType = 27 + ACLAuthMethodDeleteRequestType = 28 ) const ( diff --git a/api/acl.go b/api/acl.go index 53a052363e..3327f667c3 100644 --- a/api/acl.go +++ b/api/acl.go @@ -4,7 +4,10 @@ import ( "fmt" "io" "io/ioutil" + "net/url" "time" + + "github.com/mitchellh/mapstructure" ) const ( @@ -19,18 +22,26 @@ type ACLTokenPolicyLink struct { ID string Name string } +type ACLTokenRoleLink struct { + ID string + Name string +} // ACLToken represents an ACL Token type ACLToken struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - SecretID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - CreateTime time.Time `json:",omitempty"` - Hash []byte `json:",omitempty"` + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + SecretID string + Description string + Policies []*ACLTokenPolicyLink `json:",omitempty"` + Roles []*ACLTokenRoleLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + ExpirationTTL time.Duration `json:",omitempty"` + ExpirationTime *time.Time `json:",omitempty"` + CreateTime time.Time `json:",omitempty"` + Hash []byte `json:",omitempty"` // DEPRECATED (ACL-Legacy-Compat) // Rules will only be present for legacy tokens returned via the new APIs @@ -38,15 +49,18 @@ type ACLToken struct { } type ACLTokenListEntry struct { - CreateIndex uint64 - ModifyIndex uint64 - AccessorID string - Description string - Policies []*ACLTokenPolicyLink - Local bool - CreateTime time.Time - Hash []byte - Legacy bool + CreateIndex uint64 + ModifyIndex uint64 + AccessorID string + Description string + Policies []*ACLTokenPolicyLink `json:",omitempty"` + Roles []*ACLTokenRoleLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Local bool + ExpirationTime *time.Time `json:",omitempty"` + CreateTime time.Time + Hash []byte + Legacy bool } // ACLEntry is used to represent a legacy ACL token @@ -67,11 +81,20 @@ type ACLReplicationStatus struct { SourceDatacenter string ReplicationType string ReplicatedIndex uint64 + ReplicatedRoleIndex uint64 ReplicatedTokenIndex uint64 LastSuccess time.Time LastError time.Time } +// ACLServiceIdentity represents a high-level grant of all necessary privileges +// to assume the identity of the named Service in the Catalog and within +// Connect. +type ACLServiceIdentity struct { + ServiceName string + Datacenters []string `json:",omitempty"` +} + // ACLPolicy represents an ACL Policy. type ACLPolicy struct { ID string @@ -94,6 +117,113 @@ type ACLPolicyListEntry struct { ModifyIndex uint64 } +type ACLRolePolicyLink struct { + ID string + Name string +} + +// ACLRole represents an ACL Role. +type ACLRole struct { + ID string + Name string + Description string + Policies []*ACLRolePolicyLink `json:",omitempty"` + ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` + Hash []byte + CreateIndex uint64 + ModifyIndex uint64 +} + +// BindingRuleBindType is the type of binding rule mechanism used. +type BindingRuleBindType string + +const ( + // BindingRuleBindTypeService binds to a service identity with the given name. + BindingRuleBindTypeService BindingRuleBindType = "service" + + // BindingRuleBindTypeRole binds to pre-existing roles with the given name. + BindingRuleBindTypeRole BindingRuleBindType = "role" +) + +type ACLBindingRule struct { + ID string + Description string + AuthMethod string + Selector string + BindType BindingRuleBindType + BindName string + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethod struct { + Name string + Type string + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethodListEntry struct { + Name string + Type string + Description string + CreateIndex uint64 + ModifyIndex uint64 +} + +// ParseKubernetesAuthMethodConfig takes a raw config map and returns a parsed +// KubernetesAuthMethodConfig. +func ParseKubernetesAuthMethodConfig(raw map[string]interface{}) (*KubernetesAuthMethodConfig, error) { + var config KubernetesAuthMethodConfig + decodeConf := &mapstructure.DecoderConfig{ + Result: &config, + WeaklyTypedInput: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return nil, err + } + + if err := decoder.Decode(raw); err != nil { + return nil, fmt.Errorf("error decoding config: %s", err) + } + + return &config, nil +} + +// KubernetesAuthMethodConfig is the config for the built-in Consul auth method +// for Kubernetes. +type KubernetesAuthMethodConfig struct { + Host string `json:",omitempty"` + CACert string `json:",omitempty"` + ServiceAccountJWT string `json:",omitempty"` +} + +// RenderToConfig converts this into a map[string]interface{} suitable for use +// in the ACLAuthMethod.Config field. +func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} { + return map[string]interface{}{ + "Host": c.Host, + "CACert": c.CACert, + "ServiceAccountJWT": c.ServiceAccountJWT, + } +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + // ACL can be used to query the ACL endpoints type ACL struct { c *Client @@ -460,7 +590,7 @@ func (a *ACL) PolicyCreate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *Wri // existing policy ID func (a *ACL) PolicyUpdate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *WriteMeta, error) { if policy.ID == "" { - return nil, nil, fmt.Errorf("Must specify an ID in Policy Creation") + return nil, nil, fmt.Errorf("Must specify an ID in Policy Update") } r := a.c.newRequest("PUT", "/v1/acl/policy/"+policy.ID) @@ -586,3 +716,410 @@ func (a *ACL) RulesTranslateToken(tokenID string) (string, error) { return string(ruleBytes), nil } + +// RoleCreate will create a new role. It is not allowed for the role parameters +// ID field to be set as this will be generated by Consul while processing the request. +func (a *ACL) RoleCreate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { + if role.ID != "" { + return nil, nil, fmt.Errorf("Cannot specify an ID in Role Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/role") + r.setWriteOptions(q) + r.obj = role + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// RoleUpdate updates a role. The ID field of the role parameter must be set to an +// existing role ID +func (a *ACL) RoleUpdate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { + if role.ID == "" { + return nil, nil, fmt.Errorf("Must specify an ID in Role Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/role/"+role.ID) + r.setWriteOptions(q) + r.obj = role + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// RoleDelete deletes a role given its ID. +func (a *ACL) RoleDelete(roleID string, q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("DELETE", "/v1/acl/role/"+roleID) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// RoleRead retrieves the role details (by ID). Returns nil if not found. +func (a *ACL) RoleRead(roleID string, q *QueryOptions) (*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/role/"+roleID) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// RoleReadByName retrieves the role details (by name). Returns nil if not found. +func (a *ACL) RoleReadByName(roleName string, q *QueryOptions) (*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/role/name/"+url.QueryEscape(roleName)) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLRole + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// RoleList retrieves a listing of all roles. The listing does not include some +// metadata for the role as those should be retrieved by subsequent calls to +// RoleRead. +func (a *ACL) RoleList(q *QueryOptions) ([]*ACLRole, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/roles") + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLRole + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// AuthMethodCreate will create a new auth method. +func (a *ACL) AuthMethodCreate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method") + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodUpdate updates an auth method. +func (a *ACL) AuthMethodUpdate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method/"+url.QueryEscape(method.Name)) + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodDelete deletes an auth method given its Name. +func (a *ACL) AuthMethodDelete(methodName string, q *WriteOptions) (*WriteMeta, error) { + if methodName == "" { + return nil, fmt.Errorf("Must specify a Name in Auth Method Delete") + } + + r := a.c.newRequest("DELETE", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// AuthMethodRead retrieves the auth method. Returns nil if not found. +func (a *ACL) AuthMethodRead(methodName string, q *QueryOptions) (*ACLAuthMethod, *QueryMeta, error) { + if methodName == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Read") + } + + r := a.c.newRequest("GET", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// AuthMethodList retrieves a listing of all auth methods. The listing does not +// include some metadata for the auth method as those should be retrieved by +// subsequent calls to AuthMethodRead. +func (a *ACL) AuthMethodList(q *QueryOptions) ([]*ACLAuthMethodListEntry, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/auth-methods") + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLAuthMethodListEntry + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// BindingRuleCreate will create a new binding rule. It is not allowed for the +// binding rule parameter's ID field to be set as this will be generated by +// Consul while processing the request. +func (a *ACL) BindingRuleCreate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID != "" { + return nil, nil, fmt.Errorf("Cannot specify an ID in Binding Rule Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule") + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleUpdate updates a binding rule. The ID field of the role binding +// rule parameter must be set to an existing binding rule ID. +func (a *ACL) BindingRuleUpdate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID == "" { + return nil, nil, fmt.Errorf("Must specify an ID in Binding Rule Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule/"+rule.ID) + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleDelete deletes a binding rule given its ID. +func (a *ACL) BindingRuleDelete(bindingRuleID string, q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("DELETE", "/v1/acl/binding-rule/"+bindingRuleID) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// BindingRuleRead retrieves the binding rule details. Returns nil if not found. +func (a *ACL) BindingRuleRead(bindingRuleID string, q *QueryOptions) (*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rule/"+bindingRuleID) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// BindingRuleList retrieves a listing of all binding rules. +func (a *ACL) BindingRuleList(methodName string, q *QueryOptions) ([]*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rules") + if methodName != "" { + r.params.Set("authmethod", methodName) + } + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLBindingRule + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// Login is used to exchange auth method credentials for a newly-minted Consul Token. +func (a *ACL) Login(auth *ACLLoginParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/login") + r.setWriteOptions(q) + r.obj = auth + + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLToken + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + return &out, wm, nil +} + +// Logout is used to destroy a Consul Token created via Login(). +func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/logout") + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} diff --git a/api/api.go b/api/api.go index ffa2ce24df..4b17ff6cda 100644 --- a/api/api.go +++ b/api/api.go @@ -30,6 +30,10 @@ const ( // the HTTP token. HTTPTokenEnvName = "CONSUL_HTTP_TOKEN" + // HTTPTokenFileEnvName defines an environment variable name which sets + // the HTTP token file. + HTTPTokenFileEnvName = "CONSUL_HTTP_TOKEN_FILE" + // HTTPAuthEnvName defines an environment variable name which sets // the HTTP authentication header. HTTPAuthEnvName = "CONSUL_HTTP_AUTH" @@ -280,6 +284,10 @@ type Config struct { // which overrides the agent's default token. Token string + // TokenFile is a file containing the current token to use for this client. + // If provided it is read once at startup and never again. + TokenFile string + TLSConfig TLSConfig } @@ -343,6 +351,10 @@ func defaultConfig(transportFn func() *http.Transport) *Config { config.Address = addr } + if tokenFile := os.Getenv(HTTPTokenFileEnvName); tokenFile != "" { + config.TokenFile = tokenFile + } + if token := os.Getenv(HTTPTokenEnvName); token != "" { config.Token = token } @@ -449,6 +461,7 @@ func (c *Config) GenerateEnv() []string { env = append(env, fmt.Sprintf("%s=%s", HTTPAddrEnvName, c.Address), fmt.Sprintf("%s=%s", HTTPTokenEnvName, c.Token), + fmt.Sprintf("%s=%s", HTTPTokenFileEnvName, c.TokenFile), fmt.Sprintf("%s=%t", HTTPSSLEnvName, c.Scheme == "https"), fmt.Sprintf("%s=%s", HTTPCAFile, c.TLSConfig.CAFile), fmt.Sprintf("%s=%s", HTTPCAPath, c.TLSConfig.CAPath), @@ -541,6 +554,19 @@ func NewClient(config *Config) (*Client, error) { config.Address = parts[1] } + // If the TokenFile is set, always use that, even if a Token is configured. + // This is because when TokenFile is set it is read into the Token field. + // We want any derived clients to have to re-read the token file. + if config.TokenFile != "" { + data, err := ioutil.ReadFile(config.TokenFile) + if err != nil { + return nil, fmt.Errorf("Error loading token file: %s", err) + } + + if token := strings.TrimSpace(string(data)); token != "" { + config.Token = token + } + } if config.Token == "" { config.Token = defConfig.Token } @@ -820,6 +846,8 @@ func (c *Client) write(endpoint string, in, out interface{}, q *WriteOptions) (* } // parseQueryMeta is used to help parse query meta-data +// +// TODO(rb): bug? the error from this function is never handled func parseQueryMeta(resp *http.Response, q *QueryMeta) error { header := resp.Header @@ -897,10 +925,7 @@ func requireOK(d time.Duration, resp *http.Response, e error) (time.Duration, *h return d, nil, e } if resp.StatusCode != 200 { - var buf bytes.Buffer - io.Copy(&buf, resp.Body) - resp.Body.Close() - return d, nil, fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes()) + return d, nil, generateUnexpectedResponseCodeError(resp) } return d, resp, nil } @@ -912,3 +937,30 @@ func (req *request) filterQuery(filter string) { req.params.Set("filter", filter) } + +// generateUnexpectedResponseCodeError consumes the rest of the body, closes +// the body stream and generates an error indicating the status code was +// unexpected. +func generateUnexpectedResponseCodeError(resp *http.Response) error { + var buf bytes.Buffer + io.Copy(&buf, resp.Body) + resp.Body.Close() + return fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes()) +} + +func requireNotFoundOrOK(d time.Duration, resp *http.Response, e error) (bool, time.Duration, *http.Response, error) { + if e != nil { + if resp != nil { + resp.Body.Close() + } + return false, d, nil, e + } + switch resp.StatusCode { + case 200: + return true, d, resp, nil + case 404: + return false, d, resp, nil + default: + return false, d, nil, generateUnexpectedResponseCodeError(resp) + } +} diff --git a/api/api_test.go b/api/api_test.go index eca799e022..7934ed87b1 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -875,9 +875,10 @@ func TestAPI_GenerateEnv(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "http", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "http", TLSConfig: TLSConfig{ CAFile: "", CAPath: "", @@ -891,6 +892,7 @@ func TestAPI_GenerateEnv(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=false", "CONSUL_CACERT=", "CONSUL_CAPATH=", @@ -908,9 +910,10 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "https", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "https", TLSConfig: TLSConfig{ CAFile: "/var/consul/ca.crt", CAPath: "/var/consul/ca.dir", @@ -928,6 +931,7 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=true", "CONSUL_CACERT=/var/consul/ca.crt", "CONSUL_CAPATH=/var/consul/ca.dir", diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 96d8ec57c9..57c51ce14b 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -1,6 +1,7 @@ package acl import ( + "encoding/json" "fmt" "strings" @@ -10,19 +11,40 @@ import ( ) func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { - ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) - ui.Info(fmt.Sprintf("SecretID: %s", token.SecretID)) - ui.Info(fmt.Sprintf("Description: %s", token.Description)) - ui.Info(fmt.Sprintf("Local: %t", token.Local)) - ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) - if showMeta { - ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) - ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) - ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) + ui.Info(fmt.Sprintf("SecretID: %s", token.SecretID)) + ui.Info(fmt.Sprintf("Description: %s", token.Description)) + ui.Info(fmt.Sprintf("Local: %t", token.Local)) + ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) + if token.ExpirationTime != nil && !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", *token.ExpirationTime)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if showMeta { + ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + } + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + } + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } + } + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } } if token.Rules != "" { ui.Info(fmt.Sprintf("Rules:")) @@ -31,19 +53,40 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { } func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) { - ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) - ui.Info(fmt.Sprintf("Description: %s", token.Description)) - ui.Info(fmt.Sprintf("Local: %t", token.Local)) - ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) - ui.Info(fmt.Sprintf("Legacy: %t", token.Legacy)) - if showMeta { - ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) - ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) - ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + ui.Info(fmt.Sprintf("AccessorID: %s", token.AccessorID)) + ui.Info(fmt.Sprintf("Description: %s", token.Description)) + ui.Info(fmt.Sprintf("Local: %t", token.Local)) + ui.Info(fmt.Sprintf("Create Time: %v", token.CreateTime)) + if token.ExpirationTime != nil && !token.ExpirationTime.IsZero() { + ui.Info(fmt.Sprintf("Expiration Time: %v", *token.ExpirationTime)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + ui.Info(fmt.Sprintf("Legacy: %t", token.Legacy)) + if showMeta { + ui.Info(fmt.Sprintf("Hash: %x", token.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) + } + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + } + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } + } + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } } } @@ -73,6 +116,112 @@ func PrintPolicyListEntry(policy *api.ACLPolicyListEntry, ui cli.Ui, showMeta bo } } +func PrintRole(role *api.ACLRole, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("ID: %s", role.ID)) + ui.Info(fmt.Sprintf("Name: %s", role.Name)) + ui.Info(fmt.Sprintf("Description: %s", role.Description)) + if showMeta { + ui.Info(fmt.Sprintf("Hash: %x", role.Hash)) + ui.Info(fmt.Sprintf("Create Index: %d", role.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", role.ModifyIndex)) + } + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + } + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } + } +} + +func PrintRoleListEntry(role *api.ACLRole, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", role.Name)) + ui.Info(fmt.Sprintf(" ID: %s", role.ID)) + ui.Info(fmt.Sprintf(" Description: %s", role.Description)) + if showMeta { + ui.Info(fmt.Sprintf(" Hash: %x", role.Hash)) + ui.Info(fmt.Sprintf(" Create Index: %d", role.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", role.ModifyIndex)) + } + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf(" Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } + } + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf(" Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } + } +} + +func PrintAuthMethod(method *api.ACLAuthMethod, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("Name: %s", method.Name)) + ui.Info(fmt.Sprintf("Type: %s", method.Type)) + ui.Info(fmt.Sprintf("Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", method.ModifyIndex)) + } + ui.Info(fmt.Sprintf("Config:")) + output, err := json.MarshalIndent(method.Config, "", " ") + if err != nil { + ui.Error(fmt.Sprintf("Error formatting auth method configuration: %s", err)) + } + ui.Output(string(output)) +} + +func PrintAuthMethodListEntry(method *api.ACLAuthMethodListEntry, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", method.Name)) + ui.Info(fmt.Sprintf(" Type: %s", method.Type)) + ui.Info(fmt.Sprintf(" Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", method.ModifyIndex)) + } +} + +func PrintBindingRule(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("ID: %s", rule.ID)) + ui.Info(fmt.Sprintf("AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf("Description: %s", rule.Description)) + ui.Info(fmt.Sprintf("BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf("BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf("Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", rule.ModifyIndex)) + } +} + +func PrintBindingRuleListEntry(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", rule.ID)) + ui.Info(fmt.Sprintf(" AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf(" Description: %s", rule.Description)) + ui.Info(fmt.Sprintf(" BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf(" BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf(" Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", rule.ModifyIndex)) + } +} + func GetTokenIDFromPartial(client *api.Client, partialID string) (string, error) { if partialID == "anonymous" { return structs.ACLTokenAnonymousID, nil @@ -185,3 +334,123 @@ func GetRulesFromLegacyToken(client *api.Client, tokenID string, isSecret bool) return token.Rules, nil } + +func GetRoleIDFromPartial(client *api.Client, partialID string) (string, error) { + // the full UUID string was given + if len(partialID) == 36 { + return partialID, nil + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + return "", err + } + + roleID := "" + for _, role := range roles { + if strings.HasPrefix(role.ID, partialID) { + if roleID != "" { + return "", fmt.Errorf("Partial role ID is not unique") + } + roleID = role.ID + } + } + + if roleID == "" { + return "", fmt.Errorf("No such role ID with prefix: %s", partialID) + } + + return roleID, nil +} + +func GetRoleIDByName(client *api.Client, name string) (string, error) { + if name == "" { + return "", fmt.Errorf("No name specified") + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + return "", err + } + + for _, role := range roles { + if role.Name == name { + return role.ID, nil + } + } + + return "", fmt.Errorf("No such role with name %s", name) +} + +func GetBindingRuleIDFromPartial(client *api.Client, partialID string) (string, error) { + // the full UUID string was given + if len(partialID) == 36 { + return partialID, nil + } + + rules, _, err := client.ACL().BindingRuleList("", nil) + if err != nil { + return "", err + } + + ruleID := "" + for _, rule := range rules { + if strings.HasPrefix(rule.ID, partialID) { + if ruleID != "" { + return "", fmt.Errorf("Partial rule ID is not unique") + } + ruleID = rule.ID + } + } + + if ruleID == "" { + return "", fmt.Errorf("No such rule ID with prefix: %s", partialID) + } + + return ruleID, nil +} + +func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) { + var out []*api.ACLServiceIdentity + for _, svcidRaw := range serviceIdents { + parts := strings.Split(svcidRaw, ":") + switch len(parts) { + case 2: + out = append(out, &api.ACLServiceIdentity{ + ServiceName: parts[0], + Datacenters: strings.Split(parts[1], ","), + }) + case 1: + out = append(out, &api.ACLServiceIdentity{ + ServiceName: parts[0], + }) + default: + return nil, fmt.Errorf("Malformed -service-identity argument: %q", svcidRaw) + } + } + return out, nil +} + +// TestKubernetesJWT_A is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "admin-token-qlz42", +// "kubernetes.io/serviceaccount/service-account.name": "admin", +// "kubernetes.io/serviceaccount/service-account.uid": "738bc251-6532-11e9-b67f-48e6c8b8ecb5", +// "sub": "system:serviceaccount:default:admin" +// } +const TestKubernetesJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// TestKubernetesJWT_B is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "demo-token-kmb9n", +// "kubernetes.io/serviceaccount/service-account.name": "demo", +// "kubernetes.io/serviceaccount/service-account.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", +// "sub": "system:serviceaccount:default:demo" +// } +const TestKubernetesJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/command/acl/authmethod/authmethod.go b/command/acl/authmethod/authmethod.go new file mode 100644 index 0000000000..d8be7857ab --- /dev/null +++ b/command/acl/authmethod/authmethod.go @@ -0,0 +1,64 @@ +package authmethod + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Auth Methods" +const help = ` +Usage: consul acl auth-method [options] [args] + + This command has subcommands for managing Consul's ACL Auth Methods. + Here are some simple examples, and more detailed examples are available in + the subcommands or the documentation. + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube auth method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" + + List all auth methods: + + $ consul acl auth-method list + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" + + Read an auth method: + + $ consul acl auth-method read -name my-k8s + + Delete an auth method: + + $ consul acl auth-method delete -name my-k8s + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/authmethod/create/authmethod_create.go b/command/acl/authmethod/create/authmethod_create.go new file mode 100644 index 0000000000..46a55882b6 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create.go @@ -0,0 +1,186 @@ +package authmethodcreate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodType string + name string + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodType, + "type", + "", + "The new auth method's type. This flag is required.", + ) + c.flags.StringVar( + &c.name, + "name", + "", + "The new auth method's name. This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.name == "" { + c.UI.Error(fmt.Sprintf("Missing required '-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + newAuthMethod := &api.ACLAuthMethod{ + Type: c.authMethodType, + Name: c.name, + Description: c.description, + } + + if c.authMethodType == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + + newAuthMethod.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + + method, _, err := client.ACL().AuthMethodCreate(newAuthMethod, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new auth method: %v", err)) + return 1 + } + + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Auth Method" + +const help = ` +Usage: consul acl auth-method create -name NAME -type TYPE [options] + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" +` diff --git a/command/acl/authmethod/create/authmethod_create_test.go b/command/acl/authmethod/create/authmethod_create_test.go new file mode 100644 index 0000000000..a5bb222dd1 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create_test.go @@ -0,0 +1,226 @@ +package authmethodcreate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-type' flag") + }) + + t.Run("name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-name' flag") + }) + + t.Run("invalid type", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=invalid", + "-name=my-method", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Invalid Auth Method: Type should be one of") + }) + + t.Run("create testing", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestAuthMethodCreateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("k8s host required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("k8s ca cert required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + ca := connect.TestCA(t, nil) + + t.Run("k8s jwt required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("create k8s", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + caFile := filepath.Join(testDir, "ca.crt") + require.NoError(t, ioutil.WriteFile(caFile, []byte(ca.RootCert), 0600)) + + t.Run("create k8s with cert file", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", "@" + caFile, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/authmethod/delete/authmethod_delete.go b/command/acl/authmethod/delete/authmethod_delete.go new file mode 100644 index 0000000000..d8c341c989 --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete.go @@ -0,0 +1,82 @@ +package authmethoddelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to delete.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().AuthMethodDelete(c.name, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method %q deleted successfully", c.name)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Auth Method" +const help = ` +Usage: consul acl auth-method delete -name NAME [options] + + Delete an auth method: + + $ consul acl auth-method delete -name "my-auth-method" +` diff --git a/command/acl/authmethod/delete/authmethod_delete_test.go b/command/acl/authmethod/delete/authmethod_delete_test.go new file mode 100644 index 0000000000..5d0638727c --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete_test.go @@ -0,0 +1,131 @@ +package authmethoddelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("delete notfound", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, "notfound") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("delete works", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, name) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, method) + }) +} diff --git a/command/acl/authmethod/list/authmethod_list.go b/command/acl/authmethod/list/authmethod_list.go new file mode 100644 index 0000000000..837d5f9ce8 --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list.go @@ -0,0 +1,83 @@ +package authmethodlist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + methods, _, err := client.ACL().AuthMethodList(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the auth method list: %v", err)) + return 1 + } + + for _, method := range methods { + acl.PrintAuthMethodListEntry(method, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Auth Methods" +const help = ` +Usage: consul acl auth-method list [options] + + List all auth methods: + + $ consul acl auth-method list +` diff --git a/command/acl/authmethod/list/authmethod_list_test.go b/command/acl/authmethod/list/authmethod_list_test.go new file mode 100644 index 0000000000..a8a650393e --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list_test.go @@ -0,0 +1,109 @@ +package authmethodlist + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("found none", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + }) + + client := a.Client() + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + var methodNames []string + for i := 0; i < 5; i++ { + methodName := createAuthMethod(t) + methodNames = append(methodNames, methodName) + } + + t.Run("found some", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for _, methodName := range methodNames { + require.Contains(t, output, methodName) + } + }) +} diff --git a/command/acl/authmethod/read/authmethod_read.go b/command/acl/authmethod/read/authmethod_read.go new file mode 100644 index 0000000000..1a98bbf64d --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read.go @@ -0,0 +1,96 @@ +package authmethodread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to read.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + method, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading auth method %q: %v", c.name, err)) + return 1 + } else if method == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Auth Method" +const help = ` +Usage: consul acl auth-method read -name NAME [options] + + Read an auth method: + + $ consul acl auth-method read -name my-auth-method +` diff --git a/command/acl/authmethod/read/authmethod_read_test.go b/command/acl/authmethod/read/authmethod_read_test.go new file mode 100644 index 0000000000..72b78e8005 --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read_test.go @@ -0,0 +1,118 @@ +package authmethodread + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("not found", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("read by name", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, name) + }) +} diff --git a/command/acl/authmethod/update/authmethod_update.go b/command/acl/authmethod/update/authmethod_update.go new file mode 100644 index 0000000000..6f77235f51 --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update.go @@ -0,0 +1,220 @@ +package authmethodupdate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + noMerge bool + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The auth method name.", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current auth method "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the name which is immutable.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Cannot update an auth method without specifying the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Regardless of merge, we need to fetch the prior immutable fields first. + currentAuthMethod, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current auth method: %v", err)) + return 1 + } else if currentAuthMethod == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + + if c.k8sCACert != "" { + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + } + + var method *api.ACLAuthMethod + if c.noMerge { + method = &api.ACLAuthMethod{ + Name: currentAuthMethod.Name, + Type: currentAuthMethod.Type, + Description: c.description, + } + + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + method.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + } else { + methodCopy := *currentAuthMethod + method = &methodCopy + + if c.description != "" { + method.Description = c.description + } + if method.Config == nil { + method.Config = make(map[string]interface{}) + } + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost != "" { + method.Config["Host"] = c.k8sHost + } + if c.k8sCACert != "" { + method.Config["CACert"] = c.k8sCACert + } + if c.k8sServiceAccountJWT != "" { + method.Config["ServiceAccountJWT"] = c.k8sServiceAccountJWT + } + } + } + + method, _, err = client.ACL().AuthMethodUpdate(method, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method updated successfully")) + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Update an ACL Auth Method" +const help = ` +Usage: consul acl auth-method update -name NAME [options] + + Updates an auth method. By default it will merge the auth method + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" +` diff --git a/command/acl/authmethod/update/authmethod_update_test.go b/command/acl/authmethod/update/authmethod_update_test.go new file mode 100644 index 0000000000..ba5d92758e --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update_test.go @@ -0,0 +1,647 @@ +package authmethodupdate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_A, config.ServiceAccountJWT) + }) +} + +func TestAuthMethodUpdateCommand_k8s_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update missing k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("update missing k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + t.Run("update missing k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) +} diff --git a/command/acl/bindingrule/bindingrule.go b/command/acl/bindingrule/bindingrule.go new file mode 100644 index 0000000000..2b94139463 --- /dev/null +++ b/command/acl/bindingrule/bindingrule.go @@ -0,0 +1,60 @@ +package bindingrule + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule [options] [args] + + This command has subcommands for managing Consul's ACL Binding Rules. Here + are some simple examples, and more detailed examples are available in the + subcommands or the documentation. + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' + + List all binding rules: + + $ consul acl binding-rule list + + Update a binding rule: + + $ consul acl binding-rule update -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -bind-name='k8s-${serviceaccount.name}' + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba + + Delete a binding rule: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/bindingrule/create/bindingrule_create.go b/command/acl/bindingrule/create/bindingrule_create.go new file mode 100644 index 0000000000..01bcfcbe72 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create.go @@ -0,0 +1,148 @@ +package bindingrulecreate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + description string + selector string + bindType string + bindName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "The auth method's name for which this binding rule applies. "+ + "This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + newRule := &api.ACLBindingRule{ + Description: c.description, + AuthMethod: c.authMethodName, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleCreate(newRule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new binding rule: %v", err)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Binding Rule" + +const help = ` +Usage: consul acl binding-rule create [options] + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/create/bindingrule_create_test.go b/command/acl/bindingrule/create/bindingrule_create_test.go new file mode 100644 index 0000000000..0e8b510963 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create_test.go @@ -0,0 +1,178 @@ +package bindingrulecreate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("method is required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + t.Run("bind type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-type' flag") + }) + + t.Run("bind name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("create it with no selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with a match selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with type role", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=role", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/bindingrule/delete/bindingrule_delete.go b/command/acl/bindingrule/delete/bindingrule_delete.go new file mode 100644 index 0000000000..7956e1e3aa --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete.go @@ -0,0 +1,97 @@ +package bindingruledelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to delete. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + if _, err := client.ACL().BindingRuleDelete(ruleID, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule %q deleted successfully", ruleID)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule delete -id ID [options] + + Deletes an ACL binding rule by providing the ID or a unique ID prefix. + + Delete by prefix: + + $ consul acl binding-rule delete -id b6b85 + + Delete by full ID: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba +` diff --git a/command/acl/bindingrule/delete/bindingrule_delete_test.go b/command/acl/bindingrule/delete/bindingrule_delete_test.go new file mode 100644 index 0000000000..497f26b21c --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete_test.go @@ -0,0 +1,187 @@ +package bindingruledelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("delete works", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete works via prefixes", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete fails when prefix matches more than one rule", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + prefix, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) +} diff --git a/command/acl/bindingrule/list/bindingrule_list.go b/command/acl/bindingrule/list/bindingrule_list.go new file mode 100644 index 0000000000..1150ac42c2 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list.go @@ -0,0 +1,98 @@ +package bindingrulelist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "Only show rules linked to the auth method with the given name.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rules, _, err := client.ACL().BindingRuleList(c.authMethodName, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the binding rule list: %v", err)) + return 1 + } + + for _, rule := range rules { + acl.PrintBindingRuleListEntry(rule, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule list [options] + + Lists all the ACL binding rules. + + Show all: + + $ consul acl binding-rule list + + Show all for a specific auth method: + + $ consul acl binding-rule list -method="my-method" +` diff --git a/command/acl/bindingrule/list/bindingrule_list_test.go b/command/acl/bindingrule/list/bindingrule_list_test.go new file mode 100644 index 0000000000..2d935857e3 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list_test.go @@ -0,0 +1,167 @@ +package bindingrulelist + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T, methodName, description string) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: methodName, + Description: description, + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + var ruleIDs []string + for i := 0; i < 10; i++ { + name := fmt.Sprintf("test-rule-%d", i) + + var methodName string + if i%2 == 0 { + methodName = "test-1" + } else { + methodName = "test-2" + } + + id := createRule(t, methodName, name) + + ruleIDs = append(ruleIDs, id) + } + + t.Run("normal", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + }) + + t.Run("filter by method 1", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-1", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 0 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) + + t.Run("filter by method 2", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-2", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 1 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) +} diff --git a/command/acl/bindingrule/read/bindingrule_read.go b/command/acl/bindingrule/read/bindingrule_read.go new file mode 100644 index 0000000000..677a950cf2 --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read.go @@ -0,0 +1,108 @@ +package bindingruleread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to read. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading binding rule %q: %v", ruleID, err)) + return 1 + } else if rule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule read -id ID [options] + + This command will retrieve and print out the details of a single binding + rule. + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba +` diff --git a/command/acl/bindingrule/read/bindingrule_read_test.go b/command/acl/bindingrule/read/bindingrule_read_test.go new file mode 100644 index 0000000000..205e29e2fa --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read_test.go @@ -0,0 +1,152 @@ +package bindingruleread + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("read by id not found", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + t.Run("read by id", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) + + t.Run("read by id prefix", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) +} diff --git a/command/acl/bindingrule/update/bindingrule_update.go b/command/acl/bindingrule/update/bindingrule_update.go new file mode 100644 index 0000000000..0f6d23f659 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update.go @@ -0,0 +1,212 @@ +package bindingruleupdate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + description string + selector string + bindType string + bindName string + + noMerge bool + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to update. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.flags.BoolVar( + &c.noMerge, + "no-merge", + false, + "Do not merge the current binding rule "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the binding rule ID which is immutable.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Cannot update a binding rule without specifying the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + // Read the current binding rule in both cases so we can fail better if not found. + currentRule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current binding rule: %v", err)) + return 1 + } else if currentRule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + var rule *api.ACLBindingRule + if c.noMerge { + if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + rule = &api.ACLBindingRule{ + ID: ruleID, + AuthMethod: currentRule.AuthMethod, // immutable + Description: c.description, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + } else { + rule = currentRule + + if c.description != "" { + rule.Description = c.description + } + if c.bindType != "" { + rule.BindType = api.BindingRuleBindType(c.bindType) + } + if c.bindName != "" { + rule.BindName = c.bindName + } + if isFlagSet(c.flags, "selector") { + rule.Selector = c.selector // empty is valid + } + } + + rule, _, err = client.ACL().BindingRuleUpdate(rule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule updated successfully")) + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +func isFlagSet(flags *flag.FlagSet, name string) bool { + found := false + flags.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +const synopsis = "Update an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule update -id ID [options] + + Updates a binding rule. By default it will merge the binding rule + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the binding rule: + + $ consul acl binding-rule update \ + -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -description="new description" \ + -bind-type=role \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/update/bindingrule_update_test.go b/command/acl/bindingrule/update/bindingrule_update_test.go new file mode 100644 index 0000000000..82a6e1fbb4 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update_test.go @@ -0,0 +1,768 @@ +package bindingruleupdate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + uuid "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "test-${serviceaccount.name}", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but must exist", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==default", rule.Selector) + }) + + t.Run("update all fields clear selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} + +func TestBindingRuleUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeRole, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector", "foo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Empty(t, rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("missing bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id=" + id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} diff --git a/command/acl/role/create/role_create.go b/command/acl/role/create/role_create.go new file mode 100644 index 0000000000..8ab869c00c --- /dev/null +++ b/command/acl/role/create/role_create.go @@ -0,0 +1,134 @@ +package rolecreate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + aclhelpers "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + description string + policyIDs []string + policyNames []string + serviceIdents []string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.name, "name", "", "The new role's name. This flag is required.") + c.flags.StringVar(&c.description, "description", "", "A description of the role") + c.flags.Var((*flags.AppendSliceValue)(&c.policyIDs), "policy-id", "ID of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this role. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Missing require '-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + if len(c.policyNames) == 0 && len(c.policyIDs) == 0 && len(c.serviceIdents) == 0 { + c.UI.Error(fmt.Sprintf("Cannot create a role without specifying -policy-name, -policy-id, or -service-identity at least once")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + newRole := &api.ACLRole{ + Name: c.name, + Description: c.description, + } + + for _, policyName := range c.policyNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + newRole.Policies = append(newRole.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + newRole.Policies = append(newRole.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + newRole.ServiceIdentities = parsedServiceIdents + + role, _, err := client.ACL().RoleCreate(newRole, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new role: %v", err)) + return 1 + } + + aclhelpers.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Role" + +const help = ` +Usage: consul acl role create -name NAME [options] + + Create a new role: + + $ consul acl role create -name "new-role" \ + -description "This is an example role" \ + -policy-id b52fc3de-5 \ + -policy-name "acl-replication" \ + -service-identity "web" \ + -service-identity "db:east,west" +` diff --git a/command/acl/role/create/role_create_test.go b/command/acl/role/create/role_create_test.go new file mode 100644 index 0000000000..d592aba99c --- /dev/null +++ b/command/acl/role/create/role_create_test.go @@ -0,0 +1,116 @@ +package rolecreate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + ui := cli.NewMockUi() + cmd := New(ui) + + // Create a policy + client := a.Client() + + policy, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create with policy by name + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-policy-by-name", + "-description=test-role", + "-policy-name=" + policy.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with policy by id + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-policy-by-id", + "-description=test-role", + "-policy-id=" + policy.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with service identity + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-service-identity", + "-description=test-role", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } + + // create with service identity scoped to 2 DCs + { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=role-with-service-identity-in-2-dcs", + "-description=test-role", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + } +} diff --git a/command/acl/role/delete/role_delete.go b/command/acl/role/delete/role_delete.go new file mode 100644 index 0000000000..5e1b17ad4b --- /dev/null +++ b/command/acl/role/delete/role_delete.go @@ -0,0 +1,98 @@ +package roledelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string + roleName string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to delete. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple role IDs") + c.flags.StringVar(&c.roleName, "name", "", "The name of the role to delete.") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" && c.roleName == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id or -name parameters")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + var roleID string + if c.roleID != "" { + roleID, err = acl.GetRoleIDFromPartial(client, c.roleID) + } else { + roleID, err = acl.GetRoleIDByName(client, c.roleName) + } + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + + if _, err := client.ACL().RoleDelete(roleID, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting role %q: %v", roleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Role %q deleted successfully", roleID)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Role" +const help = ` +Usage: consul acl role delete [options] -id ROLE + + Deletes an ACL role by providing the ID or a unique ID prefix. + + Delete by prefix: + + $ consul acl role delete -id b6b85 + + Delete by full ID: + + $ consul acl role delete -id b6b856da-5193-4e78-845a-7d61ca8371ba + +` diff --git a/command/acl/role/delete/role_delete_test.go b/command/acl/role/delete/role_delete_test.go new file mode 100644 index 0000000000..25f2faf0af --- /dev/null +++ b/command/acl/role/delete/role_delete_test.go @@ -0,0 +1,141 @@ +package roledelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("id or name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id or -name parameters") + }) + + t.Run("delete works", func(t *testing.T) { + // Create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-for-id-delete", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, role.ID) + + role, _, err = client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, role) + }) + + t.Run("delete works via prefixes", func(t *testing.T) { + // Create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-for-id-prefix-delete", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, role.ID) + + role, _, err = client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, role) + }) +} diff --git a/command/acl/role/list/role_list.go b/command/acl/role/list/role_list.go new file mode 100644 index 0000000000..95a3741890 --- /dev/null +++ b/command/acl/role/list/role_list.go @@ -0,0 +1,79 @@ +package rolelist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that policy metadata such "+ + "as the content hash and raft indices should be shown for each entry") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + roles, _, err := client.ACL().RoleList(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the role list: %v", err)) + return 1 + } + + for _, role := range roles { + acl.PrintRoleListEntry(role, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Roles" +const help = ` +Usage: consul acl role list [options] + + Lists all the ACL roles. + + Example: + + $ consul acl role list +` diff --git a/command/acl/role/list/role_list_test.go b/command/acl/role/list/role_list_test.go new file mode 100644 index 0000000000..5da280f3d3 --- /dev/null +++ b/command/acl/role/list/role_list_test.go @@ -0,0 +1,83 @@ +package rolelist + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleListCommand(t *testing.T) { + t.Parallel() + require := require.New(t) + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + ui := cli.NewMockUi() + cmd := New(ui) + + var roleIDs []string + + // Create a couple roles to list + client := a.Client() + svcids := []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ServiceName: "fake"}, + } + for i := 0; i < 5; i++ { + name := fmt.Sprintf("test-role-%d", i) + + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{Name: name, ServiceIdentities: svcids}, + &api.WriteOptions{Token: "root"}, + ) + roleIDs = append(roleIDs, role.ID) + + require.NoError(err) + } + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(code, 0) + require.Empty(ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range roleIDs { + require.Contains(output, fmt.Sprintf("test-role-%d", i)) + require.Contains(output, v) + } +} diff --git a/command/acl/role/read/role_read.go b/command/acl/role/read/role_read.go new file mode 100644 index 0000000000..fb51b8d099 --- /dev/null +++ b/command/acl/role/read/role_read.go @@ -0,0 +1,115 @@ +package roleread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string + roleName string + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to read. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple policy IDs") + c.flags.StringVar(&c.roleName, "name", "", "The name of the role to read.") + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" && c.roleName == "" { + c.UI.Error(fmt.Sprintf("Must specify either the -id or -name parameters")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + var role *api.ACLRole + + if c.roleID != "" { + roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + role, _, err = client.ACL().RoleRead(roleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading role %q: %v", roleID, err)) + return 1 + } else if role == nil { + c.UI.Error(fmt.Sprintf("Role not found with ID %q", roleID)) + return 1 + } + + } else { + role, _, err = client.ACL().RoleReadByName(c.roleName, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading role %q: %v", c.roleName, err)) + return 1 + } else if role == nil { + c.UI.Error(fmt.Sprintf("Role not found with name %q", c.roleName)) + return 1 + } + } + + acl.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Role" +const help = ` +Usage: consul acl role read [options] ROLE + + This command will retrieve and print out the details + of a single role. + + Read: + + $ consul acl role read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba + + Read by name: + + $ consul acl role read -name my-policy + +` diff --git a/command/acl/role/read/role_read_test.go b/command/acl/role/read/role_read_test.go new file mode 100644 index 0000000000..f0f7b45dd7 --- /dev/null +++ b/command/acl/role/read/role_read_test.go @@ -0,0 +1,194 @@ +package roleread + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestRoleReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("id or name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify either the -id or -name parameters") + }) + + t.Run("read by id not found", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("read by name not found", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=blah", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with name") + }) + + t.Run("read by id", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-id", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) + + t.Run("read by id prefix", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-id-prefix", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + role.ID[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) + + t.Run("read by name", func(t *testing.T) { + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-by-name", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + role.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test-role")) + require.Contains(t, output, role.ID) + }) +} diff --git a/command/acl/role/role.go b/command/acl/role/role.go new file mode 100644 index 0000000000..87bf01b3d7 --- /dev/null +++ b/command/acl/role/role.go @@ -0,0 +1,56 @@ +package role + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Roles" +const help = ` +Usage: consul acl role [options] [args] + + This command has subcommands for managing Consul's ACL Roles. + Here are some simple examples, and more detailed examples are available + in the subcommands or the documentation. + + Create a new ACL Role: + + $ consul acl role create -name "new-role" \ + -description "This is an example role" \ + -policy-id 06acc965 + List all roles: + + $ consul acl role list + + Update a role: + + $ consul acl role update -name "other-role" -datacenter "dc1" + + Read a role: + + $ consul acl role read -id 0479e93e-091c-4475-9b06-79a004765c24 + + Delete a role + + $ consul acl role delete -name "my-role" + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/role/update/role_update.go b/command/acl/role/update/role_update.go new file mode 100644 index 0000000000..6327755ccb --- /dev/null +++ b/command/acl/role/update/role_update.go @@ -0,0 +1,225 @@ +package roleupdate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + roleID string + name string + description string + policyIDs []string + policyNames []string + serviceIdents []string + + noMerge bool + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar(&c.showMeta, "meta", false, "Indicates that role metadata such "+ + "as the content hash and raft indices should be shown for each entry") + c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to update. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple role IDs") + c.flags.StringVar(&c.name, "name", "", "The role name.") + c.flags.StringVar(&c.description, "description", "", "A description of the role") + c.flags.Var((*flags.AppendSliceValue)(&c.policyIDs), "policy-id", "ID of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ + "policy to use for this role. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this role. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") + c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current role "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the role ID which is immutable.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.roleID == "" { + c.UI.Error(fmt.Sprintf("Cannot update a role without specifying the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) + return 1 + } + + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Read the current role in both cases so we can fail better if not found. + currentRole, _, err := client.ACL().RoleRead(roleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current role: %v", err)) + return 1 + } else if currentRole == nil { + c.UI.Error(fmt.Sprintf("Role not found with ID %q", roleID)) + return 1 + } + + var role *api.ACLRole + if c.noMerge { + role = &api.ACLRole{ + ID: c.roleID, + Name: c.name, + Description: c.description, + ServiceIdentities: parsedServiceIdents, + } + + for _, policyName := range c.policyNames { + // We could resolve names to IDs here but there isn't any reason + // why its would be better than allowing the agent to do it. + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + } else { + role = currentRole + + if c.name != "" { + role.Name = c.name + } + if c.description != "" { + role.Description = c.description + } + + for _, policyName := range c.policyNames { + found := false + for _, link := range role.Policies { + if link.Name == policyName { + found = true + break + } + } + + if !found { + // We could resolve names to IDs here but there isn't any + // reason why its would be better than allowing the agent to do + // it. + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{Name: policyName}) + } + } + + for _, policyID := range c.policyIDs { + policyID, err := acl.GetPolicyIDFromPartial(client, policyID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving policy ID %s: %v", policyID, err)) + return 1 + } + found := false + + for _, link := range role.Policies { + if link.ID == policyID { + found = true + break + } + } + + if !found { + role.Policies = append(role.Policies, &api.ACLRolePolicyLink{ID: policyID}) + } + } + + for _, svcid := range parsedServiceIdents { + found := -1 + for i, link := range role.ServiceIdentities { + if link.ServiceName == svcid.ServiceName { + found = i + break + } + } + + if found != -1 { + role.ServiceIdentities[found] = svcid + } else { + role.ServiceIdentities = append(role.ServiceIdentities, svcid) + } + } + } + + role, _, err = client.ACL().RoleUpdate(role, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating role %q: %v", roleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Role updated successfully")) + acl.PrintRole(role, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Update an ACL Role" +const help = ` +Usage: consul acl role update [options] + + Updates a role. By default it will merge the role information with its + current state so that you do not have to provide all parameters. This + behavior can be disabled by passing -no-merge. + + Rename the Role: + + $ consul acl role update -id abcd -name "better-name" + + Update all editable fields of the role: + + $ consul acl role update -id abcd \ + -name "better-name" \ + -description "replication" \ + -policy-name "token-replication" \ + -service-identity "web" +` diff --git a/command/acl/role/update/role_update_test.go b/command/acl/role/update/role_update_test.go new file mode 100644 index 0000000000..c9094e7286 --- /dev/null +++ b/command/acl/role/update/role_update_test.go @@ -0,0 +1,398 @@ +package roleupdate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + uuid "github.com/hashicorp/go-uuid" +) + +func TestRoleUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestRoleUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // Create 2 policies + policy1, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy1"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy2, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy2"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create a role + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("update a role that does not exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-policy-name=" + policy1.Name, + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("update with policy by name", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-policy-name=" + policy1.Name, + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with policy by id", func(t *testing.T) { + // also update with no description shouldn't delete the current + // description + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-policy-id=" + policy2.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with service identity", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 2) + }) + + t.Run("update with service identity scoped to 2 DCs", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-token=root", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "test role edited", role.Description) + require.Len(t, role.Policies, 2) + require.Len(t, role.ServiceIdentities, 3) + }) +} + +func TestRoleUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // Create 3 policies + policy1, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy1"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy2, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy2"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + policy3, _, err := client.ACL().PolicyCreate( + &api.ACLPolicy{Name: "test-policy3"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + // create a role + createRole := func(t *testing.T) *api.ACLRole { + roleUnq, err := uuid.GenerateUUID() + require.NoError(t, err) + + role, _, err := client.ACL().RoleCreate( + &api.ACLRole{ + Name: "test-role-" + roleUnq, + Description: "original description", + ServiceIdentities: []*api.ACLServiceIdentity{ + &api.ACLServiceIdentity{ + ServiceName: "fake", + }, + }, + Policies: []*api.ACLRolePolicyLink{ + &api.ACLRolePolicyLink{ + ID: policy3.ID, + }, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return role + } + + t.Run("update a role that does not exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-policy-name=" + policy1.Name, + "-no-merge", + "-description=test role edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Role not found with ID") + }) + + t.Run("update with policy by name", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-policy-name=" + policy1.Name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 0) + }) + + t.Run("update with policy by id", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-policy-id=" + policy2.ID, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 1) + require.Len(t, role.ServiceIdentities, 0) + }) + + t.Run("update with service identity", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-service-identity=web", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 0) + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("update with service identity scoped to 2 DCs", func(t *testing.T) { + role := createRole(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + role.ID, + "-name=" + role.Name, + "-token=root", + "-no-merge", + "-service-identity=db:abc,xyz", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + role, _, err := client.ACL().RoleRead( + role.ID, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, role) + require.Equal(t, "", role.Description) + require.Len(t, role.Policies, 0) + require.Len(t, role.ServiceIdentities, 1) + }) +} diff --git a/command/acl/token/clone/token_clone_test.go b/command/acl/token/clone/token_clone_test.go index 0baa0c4f2e..55a81a682d 100644 --- a/command/acl/token/clone/token_clone_test.go +++ b/command/acl/token/clone/token_clone_test.go @@ -19,11 +19,11 @@ import ( func parseCloneOutput(t *testing.T, output string) *api.ACLToken { // This will only work for non-legacy tokens re := regexp.MustCompile("Token cloned successfully.\n" + - "AccessorID: ([a-zA-Z0-9\\-]{36})\n" + - "SecretID: ([a-zA-Z0-9\\-]{36})\n" + - "Description: ([^\n]*)\n" + - "Local: (true|false)\n" + - "Create Time: ([^\n]+)\n" + + "AccessorID: ([a-zA-Z0-9\\-]{36})\n" + + "SecretID: ([a-zA-Z0-9\\-]{36})\n" + + "Description: ([^\n]*)\n" + + "Local: (true|false)\n" + + "Create Time: ([^\n]+)\n" + "Policies:\n" + "( [a-zA-Z0-9\\-]{36} - [^\n]+\n)*") diff --git a/command/acl/token/create/token_create.go b/command/acl/token/create/token_create.go index 83a17cd612..5d55563919 100644 --- a/command/acl/token/create/token_create.go +++ b/command/acl/token/create/token_create.go @@ -3,6 +3,7 @@ package tokencreate import ( "flag" "fmt" + "time" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/acl" @@ -22,11 +23,15 @@ type cmd struct { http *flags.HTTPFlags help string - policyIDs []string - policyNames []string - description string - local bool - showMeta bool + policyIDs []string + policyNames []string + roleIDs []string + roleNames []string + serviceIdents []string + expirationTTL time.Duration + description string + local bool + showMeta bool } func (c *cmd) init() { @@ -39,6 +44,15 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleIDs), "role-id", "ID of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleNames), "role-name", "Name of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this token. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") + c.flags.DurationVar(&c.expirationTTL, "expires-ttl", 0, "Duration of time this "+ + "token should be valid for") c.http = &flags.HTTPFlags{} flags.Merge(c.flags, c.http.ClientFlags()) flags.Merge(c.flags, c.http.ServerFlags()) @@ -50,8 +64,10 @@ func (c *cmd) Run(args []string) int { return 1 } - if len(c.policyNames) == 0 && len(c.policyIDs) == 0 { - c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name or -policy-id at least once")) + if len(c.policyNames) == 0 && len(c.policyIDs) == 0 && + len(c.roleNames) == 0 && len(c.roleIDs) == 0 && + len(c.serviceIdents) == 0 { + c.UI.Error(fmt.Sprintf("Cannot create a token without specifying -policy-name, -policy-id, -role-name, -role-id, or -service-identity at least once")) return 1 } @@ -65,6 +81,16 @@ func (c *cmd) Run(args []string) int { Description: c.description, Local: c.local, } + if c.expirationTTL > 0 { + newToken.ExpirationTTL = c.expirationTTL + } + + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + newToken.ServiceIdentities = parsedServiceIdents for _, policyName := range c.policyNames { // We could resolve names to IDs here but there isn't any reason why its would be better @@ -81,6 +107,21 @@ func (c *cmd) Run(args []string) int { newToken.Policies = append(newToken.Policies, &api.ACLTokenPolicyLink{ID: policyID}) } + for _, roleName := range c.roleNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + newToken.Roles = append(newToken.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + newToken.Roles = append(newToken.Roles, &api.ACLTokenRoleLink{ID: roleID}) + } + token, _, err := client.ACL().TokenCreate(newToken, nil) if err != nil { c.UI.Error(fmt.Sprintf("Failed to create new token: %v", err)) @@ -109,7 +150,11 @@ Usage: consul acl token create [options] Create a new token: - $ consul acl token create -description "Replication token" - -policy-id b52fc3de-5 - -policy-name "acl-replication" + $ consul acl token create -description "Replication token" \ + -policy-id b52fc3de-5 \ + -policy-name "acl-replication" \ + -role-id c630d4ef-6 \ + -role-name "db-updater" \ + -service-identity "web" \ + -service-identity "db:east,west" ` diff --git a/command/acl/token/update/token_update.go b/command/acl/token/update/token_update.go index 09df170b5a..72663f8193 100644 --- a/command/acl/token/update/token_update.go +++ b/command/acl/token/update/token_update.go @@ -22,13 +22,18 @@ type cmd struct { http *flags.HTTPFlags help string - tokenID string - policyIDs []string - policyNames []string - description string - mergePolicies bool - showMeta bool - upgradeLegacy bool + tokenID string + policyIDs []string + policyNames []string + roleIDs []string + roleNames []string + serviceIdents []string + description string + mergePolicies bool + mergeRoles bool + mergeServiceIdents bool + showMeta bool + upgradeLegacy bool } func (c *cmd) init() { @@ -37,6 +42,10 @@ func (c *cmd) init() { "as the content hash and raft indices should be shown for each entry") c.flags.BoolVar(&c.mergePolicies, "merge-policies", false, "Merge the new policies "+ "with the existing policies") + c.flags.BoolVar(&c.mergeRoles, "merge-roles", false, "Merge the new roles "+ + "with the existing roles") + c.flags.BoolVar(&c.mergeServiceIdents, "merge-service-identities", false, "Merge the new service identities "+ + "with the existing service identities") c.flags.StringVar(&c.tokenID, "id", "", "The Accessor ID of the token to read. "+ "It may be specified as a unique ID prefix but will error if the prefix "+ "matches multiple token Accessor IDs") @@ -45,6 +54,13 @@ func (c *cmd) init() { "policy to use for this token. May be specified multiple times") c.flags.Var((*flags.AppendSliceValue)(&c.policyNames), "policy-name", "Name of a "+ "policy to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleIDs), "role-id", "ID of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.roleNames), "role-name", "Name of a "+ + "role to use for this token. May be specified multiple times") + c.flags.Var((*flags.AppendSliceValue)(&c.serviceIdents), "service-identity", "Name of a "+ + "service identity to use for this token. May be specified multiple times. Format is "+ + "the SERVICENAME or SERVICENAME:DATACENTER1,DATACENTER2,...") c.flags.BoolVar(&c.upgradeLegacy, "upgrade-legacy", false, "Add new polices "+ "to a legacy token replacing all existing rules. This will cause the legacy "+ "token to behave exactly like a new token but keep the same Secret.\n"+ @@ -107,6 +123,12 @@ func (c *cmd) Run(args []string) int { token.Description = c.description } + parsedServiceIdents, err := acl.ExtractServiceIdentities(c.serviceIdents) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + if c.mergePolicies { for _, policyName := range c.policyNames { found := false @@ -162,6 +184,81 @@ func (c *cmd) Run(args []string) int { } } + if c.mergeRoles { + for _, roleName := range c.roleNames { + found := false + for _, link := range token.Roles { + if link.Name == roleName { + found = true + break + } + } + + if !found { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + found := false + + for _, link := range token.Roles { + if link.ID == roleID { + found = true + break + } + } + + if !found { + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleID}) + } + } + } else { + token.Roles = nil + + for _, roleName := range c.roleNames { + // We could resolve names to IDs here but there isn't any reason why its would be better + // than allowing the agent to do it. + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{Name: roleName}) + } + + for _, roleID := range c.roleIDs { + roleID, err := acl.GetRoleIDFromPartial(client, roleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving role ID %s: %v", roleID, err)) + return 1 + } + token.Roles = append(token.Roles, &api.ACLTokenRoleLink{ID: roleID}) + } + } + + if c.mergeServiceIdents { + for _, svcid := range parsedServiceIdents { + found := -1 + for i, link := range token.ServiceIdentities { + if link.ServiceName == svcid.ServiceName { + found = i + break + } + } + + if found != -1 { + token.ServiceIdentities[found] = svcid + } else { + token.ServiceIdentities = append(token.ServiceIdentities, svcid) + } + } + } else { + token.ServiceIdentities = parsedServiceIdents + } + token, _, err = client.ACL().TokenUpdate(token, nil) if err != nil { c.UI.Error(fmt.Sprintf("Failed to update token %s: %v", tokenID, err)) @@ -192,7 +289,10 @@ Usage: consul acl token update [options] $ consul acl token update -id abcd -description "replication" -merge-policies - Update all editable fields of the token: + Update all editable fields of the token: - $ consul acl token update -id abcd -description "replication" -policy-name "token-replication" + $ consul acl token update -id abcd \ + -description "replication" \ + -policy-name "token-replication" \ + -role-name "db-updater" ` diff --git a/command/commands_oss.go b/command/commands_oss.go index 91fee14397..89c3a2d408 100644 --- a/command/commands_oss.go +++ b/command/commands_oss.go @@ -3,6 +3,18 @@ package command import ( "github.com/hashicorp/consul/command/acl" aclagent "github.com/hashicorp/consul/command/acl/agenttokens" + aclam "github.com/hashicorp/consul/command/acl/authmethod" + aclamcreate "github.com/hashicorp/consul/command/acl/authmethod/create" + aclamdelete "github.com/hashicorp/consul/command/acl/authmethod/delete" + aclamlist "github.com/hashicorp/consul/command/acl/authmethod/list" + aclamread "github.com/hashicorp/consul/command/acl/authmethod/read" + aclamupdate "github.com/hashicorp/consul/command/acl/authmethod/update" + aclbr "github.com/hashicorp/consul/command/acl/bindingrule" + aclbrcreate "github.com/hashicorp/consul/command/acl/bindingrule/create" + aclbrdelete "github.com/hashicorp/consul/command/acl/bindingrule/delete" + aclbrlist "github.com/hashicorp/consul/command/acl/bindingrule/list" + aclbrread "github.com/hashicorp/consul/command/acl/bindingrule/read" + aclbrupdate "github.com/hashicorp/consul/command/acl/bindingrule/update" aclbootstrap "github.com/hashicorp/consul/command/acl/bootstrap" aclpolicy "github.com/hashicorp/consul/command/acl/policy" aclpcreate "github.com/hashicorp/consul/command/acl/policy/create" @@ -10,6 +22,12 @@ import ( aclplist "github.com/hashicorp/consul/command/acl/policy/list" aclpread "github.com/hashicorp/consul/command/acl/policy/read" aclpupdate "github.com/hashicorp/consul/command/acl/policy/update" + aclrole "github.com/hashicorp/consul/command/acl/role" + aclrcreate "github.com/hashicorp/consul/command/acl/role/create" + aclrdelete "github.com/hashicorp/consul/command/acl/role/delete" + aclrlist "github.com/hashicorp/consul/command/acl/role/list" + aclrread "github.com/hashicorp/consul/command/acl/role/read" + aclrupdate "github.com/hashicorp/consul/command/acl/role/update" aclrules "github.com/hashicorp/consul/command/acl/rules" acltoken "github.com/hashicorp/consul/command/acl/token" acltclone "github.com/hashicorp/consul/command/acl/token/clone" @@ -51,6 +69,8 @@ import ( kvput "github.com/hashicorp/consul/command/kv/put" "github.com/hashicorp/consul/command/leave" "github.com/hashicorp/consul/command/lock" + login "github.com/hashicorp/consul/command/login" + logout "github.com/hashicorp/consul/command/logout" "github.com/hashicorp/consul/command/maint" "github.com/hashicorp/consul/command/members" "github.com/hashicorp/consul/command/monitor" @@ -106,6 +126,24 @@ func init() { Register("acl token read", func(ui cli.Ui) (cli.Command, error) { return acltread.New(ui), nil }) Register("acl token update", func(ui cli.Ui) (cli.Command, error) { return acltupdate.New(ui), nil }) Register("acl token delete", func(ui cli.Ui) (cli.Command, error) { return acltdelete.New(ui), nil }) + Register("acl role", func(cli.Ui) (cli.Command, error) { return aclrole.New(), nil }) + Register("acl role create", func(ui cli.Ui) (cli.Command, error) { return aclrcreate.New(ui), nil }) + Register("acl role list", func(ui cli.Ui) (cli.Command, error) { return aclrlist.New(ui), nil }) + Register("acl role read", func(ui cli.Ui) (cli.Command, error) { return aclrread.New(ui), nil }) + Register("acl role update", func(ui cli.Ui) (cli.Command, error) { return aclrupdate.New(ui), nil }) + Register("acl role delete", func(ui cli.Ui) (cli.Command, error) { return aclrdelete.New(ui), nil }) + Register("acl auth-method", func(cli.Ui) (cli.Command, error) { return aclam.New(), nil }) + Register("acl auth-method create", func(ui cli.Ui) (cli.Command, error) { return aclamcreate.New(ui), nil }) + Register("acl auth-method list", func(ui cli.Ui) (cli.Command, error) { return aclamlist.New(ui), nil }) + Register("acl auth-method read", func(ui cli.Ui) (cli.Command, error) { return aclamread.New(ui), nil }) + Register("acl auth-method update", func(ui cli.Ui) (cli.Command, error) { return aclamupdate.New(ui), nil }) + Register("acl auth-method delete", func(ui cli.Ui) (cli.Command, error) { return aclamdelete.New(ui), nil }) + Register("acl binding-rule", func(cli.Ui) (cli.Command, error) { return aclbr.New(), nil }) + Register("acl binding-rule create", func(ui cli.Ui) (cli.Command, error) { return aclbrcreate.New(ui), nil }) + Register("acl binding-rule list", func(ui cli.Ui) (cli.Command, error) { return aclbrlist.New(ui), nil }) + Register("acl binding-rule read", func(ui cli.Ui) (cli.Command, error) { return aclbrread.New(ui), nil }) + Register("acl binding-rule update", func(ui cli.Ui) (cli.Command, error) { return aclbrupdate.New(ui), nil }) + Register("acl binding-rule delete", func(ui cli.Ui) (cli.Command, error) { return aclbrdelete.New(ui), nil }) Register("agent", func(ui cli.Ui) (cli.Command, error) { return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil }) @@ -141,6 +179,8 @@ func init() { Register("kv put", func(ui cli.Ui) (cli.Command, error) { return kvput.New(ui), nil }) Register("leave", func(ui cli.Ui) (cli.Command, error) { return leave.New(ui), nil }) Register("lock", func(ui cli.Ui) (cli.Command, error) { return lock.New(ui), nil }) + Register("login", func(ui cli.Ui) (cli.Command, error) { return login.New(ui), nil }) + Register("logout", func(ui cli.Ui) (cli.Command, error) { return logout.New(ui), nil }) Register("maint", func(ui cli.Ui) (cli.Command, error) { return maint.New(ui), nil }) Register("members", func(ui cli.Ui) (cli.Command, error) { return members.New(ui), nil }) Register("monitor", func(ui cli.Ui) (cli.Command, error) { return monitor.New(ui, MakeShutdownCh()), nil }) diff --git a/command/connect/envoy/envoy.go b/command/connect/envoy/envoy.go index 3500b69476..5aa9bea182 100644 --- a/command/connect/envoy/envoy.go +++ b/command/connect/envoy/envoy.go @@ -104,7 +104,7 @@ func (c *cmd) Run(args []string) int { // enabled. c.grpcAddr = "localhost:8502" } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { // Extra check needed since CONSUL_HTTP_TOKEN has not been consulted yet but // calling SetToken with empty will force that to override the if proxyToken := os.Getenv(proxyAgent.EnvProxyToken); proxyToken != "" { diff --git a/command/connect/proxy/proxy.go b/command/connect/proxy/proxy.go index a60f2217e7..4c99ab5730 100644 --- a/command/connect/proxy/proxy.go +++ b/command/connect/proxy/proxy.go @@ -129,7 +129,7 @@ func (c *cmd) Run(args []string) int { if c.sidecarFor == "" { c.sidecarFor = os.Getenv(proxyAgent.EnvSidecarFor) } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { c.http.SetToken(os.Getenv(proxyAgent.EnvProxyToken)) } diff --git a/command/flags/http.go b/command/flags/http.go index 7d02f6ab3b..e2688fab8c 100644 --- a/command/flags/http.go +++ b/command/flags/http.go @@ -2,6 +2,8 @@ package flags import ( "flag" + "io/ioutil" + "strings" "github.com/hashicorp/consul/api" ) @@ -10,6 +12,7 @@ type HTTPFlags struct { // client api flags address StringValue token StringValue + tokenFile StringValue caFile StringValue caPath StringValue certFile StringValue @@ -33,6 +36,10 @@ func (f *HTTPFlags) ClientFlags() *flag.FlagSet { "ACL token to use in the request. This can also be specified via the "+ "CONSUL_HTTP_TOKEN environment variable. If unspecified, the query will "+ "default to the token of the Consul agent at the HTTP address.") + fs.Var(&f.tokenFile, "token-file", + "File containing the ACL token to use in the request instead of one specified "+ + "via the -token argument or CONSUL_HTTP_TOKEN environment variable. "+ + "This can also be specified via the CONSUL_HTTP_TOKEN_FILE environment variable.") fs.Var(&f.caFile, "ca-file", "Path to a CA file to use for TLS when communicating with Consul. This "+ "can also be specified via the CONSUL_CACERT environment variable.") @@ -88,6 +95,28 @@ func (f *HTTPFlags) SetToken(v string) error { return f.token.Set(v) } +func (f *HTTPFlags) TokenFile() string { + return f.tokenFile.String() +} + +func (f *HTTPFlags) SetTokenFile(v string) error { + return f.tokenFile.Set(v) +} + +func (f *HTTPFlags) ReadTokenFile() (string, error) { + tokenFile := f.tokenFile.String() + if tokenFile == "" { + return "", nil + } + + data, err := ioutil.ReadFile(tokenFile) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(data)), nil +} + func (f *HTTPFlags) APIClient() (*api.Client, error) { c := api.DefaultConfig() @@ -99,6 +128,7 @@ func (f *HTTPFlags) APIClient() (*api.Client, error) { func (f *HTTPFlags) MergeOntoConfig(c *api.Config) { f.address.Merge(&c.Address) f.token.Merge(&c.Token) + f.tokenFile.Merge(&c.TokenFile) f.caFile.Merge(&c.TLSConfig.CAFile) f.caPath.Merge(&c.TLSConfig.CAPath) f.certFile.Merge(&c.TLSConfig.CertFile) diff --git a/command/login/login.go b/command/login/login.go new file mode 100644 index 0000000000..ada268cac8 --- /dev/null +++ b/command/login/login.go @@ -0,0 +1,148 @@ +package login + +import ( + "flag" + "fmt" + "io/ioutil" + "strings" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/lib/file" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + shutdownCh <-chan struct{} + + bearerToken string + + // flags + authMethodName string + bearerTokenFile string + tokenSinkFile string + meta map[string]string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar(&c.authMethodName, "method", "", + "Name of the auth method to login to.") + + c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "", + "Path to a file containing a secret bearer token to use with this auth method.") + + c.flags.StringVar(&c.tokenSinkFile, "token-sink-file", "", + "The most recent token's SecretID is kept up to date in this file.") + + c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta", + "Metadata to set on the token, formatted as key=value. This flag "+ + "may be specified multiple times to set multiple meta fields.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + return 1 + } + if c.tokenSinkFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-token-sink-file' flag")) + return 1 + } + + if c.bearerTokenFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag")) + return 1 + } + + data, err := ioutil.ReadFile(c.bearerTokenFile) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + c.bearerToken = strings.TrimSpace(string(data)) + + if c.bearerToken == "" { + c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile)) + return 1 + } + + // Ensure that we don't try to use a token when performing a login + // operation. + c.http.SetToken("") + c.http.SetTokenFile("") + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Do the login. + req := &api.ACLLoginParams{ + AuthMethod: c.authMethodName, + BearerToken: c.bearerToken, + Meta: c.meta, + } + tok, _, err := client.ACL().Login(req, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error logging in: %s", err)) + return 1 + } + + if err := c.writeToSink(tok); err != nil { + c.UI.Error(fmt.Sprintf("Error writing token to file sink: %s", err)) + return 1 + } + + return 0 +} + +func (c *cmd) writeToSink(tok *api.ACLToken) error { + payload := []byte(tok.SecretID) + return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0600) +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Login to Consul using an Auth Method" + +const help = ` +Usage: consul login [options] + + The login command will exchange the provided third party credentials with the + requested auth method for a newly minted Consul ACL Token. The companion + command 'consul logout' should be used to destroy any tokens created this way + to avoid a resource leak. +` diff --git a/command/login/login_test.go b/command/login/login_test.go new file mode 100644 index 0000000000..c2988d8626 --- /dev/null +++ b/command/login/login_test.go @@ -0,0 +1,321 @@ +package login + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLoginCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLoginCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("method is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + tokenSinkFile := filepath.Join(testDir, "test.token") + + t.Run("token-sink-file is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-token-sink-file' flag") + }) + + t.Run("bearer-token-file is required", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag") + }) + + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + t.Run("bearer-token-file is empty", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(""), 0600)) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "No bearer token found in") + }) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte("demo-token"), 0600)) + + t.Run("try login with no method configured", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured but no binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 1, code, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} + +func TestLoginCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + tokenSinkFile := filepath.Join(testDir, "test.token") + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + // the "B" jwt will be the one being reviewed + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=k8s", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} diff --git a/command/logout/logout.go b/command/logout/logout.go new file mode 100644 index 0000000000..eca9c416be --- /dev/null +++ b/command/logout/logout.go @@ -0,0 +1,70 @@ +package logout + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().Logout(nil); err != nil { + c.UI.Error(fmt.Sprintf("Error destroying token: %v", err)) + return 1 + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Destroy a Consul Token created with Login" + +const help = ` +Usage: consul logout [options] + + The logout command will destroy the provided token if it was created from + 'consul login'. +` diff --git a/command/logout/logout_test.go b/command/logout/logout_test.go new file mode 100644 index 0000000000..5596297b9c --- /dev/null +++ b/command/logout/logout_test.go @@ -0,0 +1,299 @@ +package logout + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLogout_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLogoutCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "demo-token", + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestLogoutCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + // go to the trouble of creating a login token + // require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "k8s", + BearerToken: acl.TestKubernetesJWT_B, + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/watch/watch.go b/command/watch/watch.go index fd44e81fbb..92178a1f06 100644 --- a/command/watch/watch.go +++ b/command/watch/watch.go @@ -87,6 +87,14 @@ func (c *cmd) Run(args []string) int { return 1 } + token := c.http.Token() + if tokenFromFile, err := c.http.ReadTokenFile(); err != nil { + c.UI.Error(fmt.Sprintf("Error loading token file: %s", err)) + return 1 + } else if tokenFromFile != "" { + token = tokenFromFile + } + // Compile the watch parameters params := make(map[string]interface{}) if c.watchType != "" { @@ -95,8 +103,8 @@ func (c *cmd) Run(args []string) int { if c.http.Datacenter() != "" { params["datacenter"] = c.http.Datacenter() } - if c.http.Token() != "" { - params["token"] = c.http.Token() + if token != "" { + params["token"] = token } if c.key != "" { params["key"] = c.key diff --git a/go.mod b/go.mod index 4d87ae9f00..642f2d47d7 100644 --- a/go.mod +++ b/go.mod @@ -123,7 +123,9 @@ require ( gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 // indirect gopkg.in/ory-am/dockertest.v3 v3.3.4 // indirect + gopkg.in/square/go-jose.v2 v2.3.1 gotest.tools v2.2.0+incompatible // indirect - k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 // indirect - k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 // indirect + k8s.io/api v0.0.0-20190325185214-7544f9db76f6 + k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 + k8s.io/client-go v8.0.0+incompatible ) diff --git a/go.sum b/go.sum index 8bba6865b0..649673f1f2 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,8 @@ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 h1:/saqWwm73dLmuzbNhe92F0QsZ/ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw= gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= @@ -391,10 +393,10 @@ gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.0.0-20180806132203-61b11ee65332/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 h1:lV0+KGoNkvZOt4zGT4H83hQrzWMt/US/LSz4z4+BQS4= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6 h1:9MWtbqhwTyDvF4cS1qAhxDb9Mi8taXiAu+5nEacl7gY= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= k8s.io/apimachinery v0.0.0-20180821005732-488889b0007f/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 h1:/Z1m/6oEN6hE2SzWP4BHW2yATeUrBRr+1GxNf1Ny58Y= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 h1:Q4RZrHNtlC/mSdC1sTrcZ5RchC/9vxLVj57pWiCBKv4= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= k8s.io/client-go v8.0.0+incompatible h1:tTI4hRmb1DRMl4fG6Vclfdi6nTM82oIrTT7HfitmxC4= k8s.io/client-go v8.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s= diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go new file mode 100644 index 0000000000..593f653008 --- /dev/null +++ b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC +2898 / PKCS #5 v2.0. + +A key derivation function is useful when encrypting data based on a password +or any other not-fully-random data. It uses a pseudorandom function to derive +a secure encryption key based on the password. + +While v2.0 of the standard defines only one pseudorandom function to use, +HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved +Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To +choose, you can pass the `New` functions from the different SHA packages to +pbkdf2.Key. +*/ +package pbkdf2 // import "golang.org/x/crypto/pbkdf2" + +import ( + "crypto/hmac" + "hash" +) + +// Key derives a key from the password, salt and iteration count, returning a +// []byte of length keylen that can be used as cryptographic key. The key is +// derived based on the method described as PBKDF2 with the HMAC variant using +// the supplied hash function. +// +// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you +// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by +// doing: +// +// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) +// +// Remember to get a good random salt. At least 8 bytes is recommended by the +// RFC. +// +// Using a higher iteration count will increase the cost of an exhaustive +// search but will also make derivation proportionally slower. +func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc new file mode 100644 index 0000000000..730e569b06 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc @@ -0,0 +1 @@ +'|Ê&{tÄU|gGê(ìCy=+¨œòcû:u:/pœ#~žü["±4¤!­nÙAªDK<ŠufÿhÅa¿Â:ºü¸¡´B/£Ø¤¹¤ò_hÎÛSãT*wÌx¼¯¹-ç|àÀÓƒÑÄäóÌ㣗A$$â6£ÁâG)8nÏpûÆË¡3ÌšœoïÏvŽB–3¿­]xÝ“Ó2l§G•|qRÞ¯ ö2 5R–Ó×Ç$´ñ½Yè¡ÞÝ™l‘Ë«yAI"ÛŒ˜®íû¹¼kÄ|Kåþ[9ÆâÒå=°úÿŸñ|@S•3 ó#æx?¾V„,¾‚SÆÝõœwPíogÒ6&V6 ©D.dBŠ 7 \ No newline at end of file diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitignore b/vendor/gopkg.in/square/go-jose.v2/.gitignore new file mode 100644 index 0000000000..5b4d73b681 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitignore @@ -0,0 +1,7 @@ +*~ +.*.swp +*.out +*.test +*.pem +*.cov +jose-util/jose-util diff --git a/vendor/gopkg.in/square/go-jose.v2/.travis.yml b/vendor/gopkg.in/square/go-jose.v2/.travis.yml new file mode 100644 index 0000000000..fc501ca9b7 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.travis.yml @@ -0,0 +1,46 @@ +language: go + +sudo: false + +matrix: + fast_finish: true + allow_failures: + - go: tip + +go: +- '1.7.x' +- '1.8.x' +- '1.9.x' +- '1.10.x' +- '1.11.x' + +go_import_path: gopkg.in/square/go-jose.v2 + +before_script: +- export PATH=$HOME/.local/bin:$PATH + +before_install: +# Install encrypted gitcookies to get around bandwidth-limits +# that is causing Travis-CI builds to fail. For more info, see +# https://github.com/golang/go/issues/12933 +- openssl aes-256-cbc -K $encrypted_1528c3c2cafd_key -iv $encrypted_1528c3c2cafd_iv -in .gitcookies.sh.enc -out .gitcookies.sh -d || true +- bash .gitcookies.sh || true +- go get github.com/wadey/gocovmerge +- go get github.com/mattn/goveralls +- go get github.com/stretchr/testify/assert +- go get golang.org/x/tools/cmd/cover || true +- go get code.google.com/p/go.tools/cmd/cover || true +- pip install cram --user + +script: +- go test . -v -covermode=count -coverprofile=profile.cov +- go test ./cipher -v -covermode=count -coverprofile=cipher/profile.cov +- go test ./jwt -v -covermode=count -coverprofile=jwt/profile.cov +- go test ./json -v # no coverage for forked encoding/json package +- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t +- cd .. + +after_success: +- gocovmerge *.cov */*.cov > merged.coverprofile +- $HOME/gopath/bin/goveralls -coverprofile merged.coverprofile -service=travis-ci + diff --git a/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md new file mode 100644 index 0000000000..3305db0f65 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md @@ -0,0 +1,10 @@ +Serious about security +====================== + +Square recognizes the important contributions the security research community +can make. We therefore encourage reporting security issues with the code +contained in this repository. + +If you believe you have discovered a security vulnerability, please follow the +guidelines at . + diff --git a/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md new file mode 100644 index 0000000000..61b183651c --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +If you would like to contribute code to go-jose you can do so through GitHub by +forking the repository and sending a pull request. + +When submitting code, please make every effort to follow existing conventions +and style in order to keep the code as readable as possible. Please also make +sure all tests pass by running `go test`, and format your code with `go fmt`. +We also recommend using `golint` and `errcheck`. + +Before your code can be accepted into the project you must also sign the +[Individual Contributor License Agreement][1]. + + [1]: https://spreadsheets.google.com/spreadsheet/viewform?formkey=dDViT2xzUHAwRkI3X3k5Z0lQM091OGc6MQ&ndplr=1 diff --git a/vendor/gopkg.in/square/go-jose.v2/LICENSE b/vendor/gopkg.in/square/go-jose.v2/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/gopkg.in/square/go-jose.v2/README.md b/vendor/gopkg.in/square/go-jose.v2/README.md new file mode 100644 index 0000000000..1791bfa8f6 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/README.md @@ -0,0 +1,118 @@ +# Go JOSE + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) +[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/square/go-jose/master/LICENSE) +[![build](https://travis-ci.org/square/go-jose.svg?branch=v2)](https://travis-ci.org/square/go-jose) +[![coverage](https://coveralls.io/repos/github/square/go-jose/badge.svg?branch=v2)](https://coveralls.io/r/square/go-jose) + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. This includes support for JSON Web Encryption, +JSON Web Signature, and JSON Web Token standards. + +**Disclaimer**: This library contains encryption software that is subject to +the U.S. Export Administration Regulations. You may not export, re-export, +transfer or download this code or any part of it in violation of any United +States law, directive or regulation. In particular this software may not be +exported or re-exported in any form or on any media to Iran, North Sudan, +Syria, Cuba, or North Korea, or to denied persons or entities mentioned on any +US maintained blocked list. + +## Overview + +The implementation follows the +[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516), +[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and +[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519). +Tables of supported algorithms are shown below. The library supports both +the compact and full serialization formats, and has optional support for +multiple recipients. It also comes with a small command-line utility +([`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util)) +for dealing with JOSE messages in a shell. + +**Note**: We use a forked version of the `encoding/json` package from the Go +standard library which uses case-sensitive matching for member names (instead +of [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html)). +This is to avoid differences in interpretation of messages between go-jose and +libraries in other languages. + +### Versions + +We use [gopkg.in](https://gopkg.in) for versioning. + +[Version 2](https://gopkg.in/square/go-jose.v2) +([branch](https://github.com/square/go-jose/tree/v2), +[doc](https://godoc.org/gopkg.in/square/go-jose.v2)) is the current version: + + import "gopkg.in/square/go-jose.v2" + +The old `v1` branch ([go-jose.v1](https://gopkg.in/square/go-jose.v1)) will +still receive backported bug fixes and security fixes, but otherwise +development is frozen. All new feature development takes place on the `v2` +branch. Version 2 also contains additional sub-packages such as the +[jwt](https://godoc.org/gopkg.in/square/go-jose.v2/jwt) implementation +contributed by [@shaxbee](https://github.com/shaxbee). + +### Supported algorithms + +See below for a table of supported algorithms. Algorithm identifiers match +the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518) +standard where possible. The Godoc reference has a list of constants. + + Key encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSA-PKCS#1v1.5 | RSA1_5 + RSA-OAEP | RSA-OAEP, RSA-OAEP-256 + AES key wrap | A128KW, A192KW, A256KW + AES-GCM key wrap | A128GCMKW, A192GCMKW, A256GCMKW + ECDH-ES + AES key wrap | ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW + ECDH-ES (direct) | ECDH-ES1 + Direct encryption | dir1 + +1. Not supported in multi-recipient mode + + Signing / MAC | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSASSA-PKCS#1v1.5 | RS256, RS384, RS512 + RSASSA-PSS | PS256, PS384, PS512 + HMAC | HS256, HS384, HS512 + ECDSA | ES256, ES384, ES512 + Ed25519 | EdDSA2 + +2. Only available in version 2 of the package + + Content encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + AES-CBC+HMAC | A128CBC-HS256, A192CBC-HS384, A256CBC-HS512 + AES-GCM | A128GCM, A192GCM, A256GCM + + Compression | Algorithm identifiers(s) + :------------------------- | ------------------------------- + DEFLATE (RFC 1951) | DEF + +### Supported key types + +See below for a table of supported key types. These are understood by the +library, and can be passed to corresponding functions such as `NewEncrypter` or +`NewSigner`. Each of these keys can also be wrapped in a JWK if desired, which +allows attaching a key id. + + Algorithm(s) | Corresponding types + :------------------------- | ------------------------------- + RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey) + ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey) + EdDSA1 | [ed25519.PublicKey](https://godoc.org/golang.org/x/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/golang.org/x/crypto/ed25519#PrivateKey) + AES, HMAC | []byte + +1. Only available in version 2 of the package + +## Examples + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) + +Examples can be found in the Godoc +reference for this package. The +[`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util) +subdirectory also contains a small command-line utility which might be useful +as an example. diff --git a/vendor/gopkg.in/square/go-jose.v2/asymmetric.go b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go new file mode 100644 index 0000000000..67935561bc --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go @@ -0,0 +1,592 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto" + "crypto/aes" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "math/big" + + "golang.org/x/crypto/ed25519" + "gopkg.in/square/go-jose.v2/cipher" + "gopkg.in/square/go-jose.v2/json" +) + +// A generic RSA-based encrypter/verifier +type rsaEncrypterVerifier struct { + publicKey *rsa.PublicKey +} + +// A generic RSA-based decrypter/signer +type rsaDecrypterSigner struct { + privateKey *rsa.PrivateKey +} + +// A generic EC-based encrypter/verifier +type ecEncrypterVerifier struct { + publicKey *ecdsa.PublicKey +} + +type edEncrypterVerifier struct { + publicKey ed25519.PublicKey +} + +// A key generator for ECDH-ES +type ecKeyGenerator struct { + size int + algID string + publicKey *ecdsa.PublicKey +} + +// A generic EC-based decrypter/signer +type ecDecrypterSigner struct { + privateKey *ecdsa.PrivateKey +} + +type edDecrypterSigner struct { + privateKey ed25519.PrivateKey +} + +// newRSARecipient creates recipientKeyInfo based on the given key. +func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case RSA1_5, RSA_OAEP, RSA_OAEP_256: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &rsaEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newRSASigner creates a recipientSigInfo based on the given key. +func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case RS256, RS384, RS512, PS256, PS384, PS512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &rsaDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +func newEd25519Signer(sigAlg SignatureAlgorithm, privateKey ed25519.PrivateKey) (recipientSigInfo, error) { + if sigAlg != EdDSA { + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &edDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// newECDHRecipient creates recipientKeyInfo based on the given key. +func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &ecEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newECDSASigner creates a recipientSigInfo based on the given key. +func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case ES256, ES384, ES512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &ecDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// Encrypt the given payload and update the object. +func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + encryptedKey, err := ctx.encrypt(cek, alg) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: encryptedKey, + header: &rawHeader{}, + }, nil +} + +// Encrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { + switch alg { + case RSA1_5: + return rsa.EncryptPKCS1v15(RandReader, ctx.publicKey, cek) + case RSA_OAEP: + return rsa.EncryptOAEP(sha1.New(), RandReader, ctx.publicKey, cek, []byte{}) + case RSA_OAEP_256: + return rsa.EncryptOAEP(sha256.New(), RandReader, ctx.publicKey, cek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Decrypt the given payload and return the content encryption key. +func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + return ctx.decrypt(recipient.encryptedKey, headers.getAlgorithm(), generator) +} + +// Decrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) { + // Note: The random reader on decrypt operations is only used for blinding, + // so stubbing is meanlingless (hence the direct use of rand.Reader). + switch alg { + case RSA1_5: + defer func() { + // DecryptPKCS1v15SessionKey sometimes panics on an invalid payload + // because of an index out of bounds error, which we want to ignore. + // This has been fixed in Go 1.3.1 (released 2014/08/13), the recover() + // only exists for preventing crashes with unpatched versions. + // See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k + // See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33 + _ = recover() + }() + + // Perform some input validation. + keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8 + if keyBytes != len(jek) { + // Input size is incorrect, the encrypted payload should always match + // the size of the public modulus (e.g. using a 2048 bit key will + // produce 256 bytes of output). Reject this since it's invalid input. + return nil, ErrCryptoFailure + } + + cek, _, err := generator.genKey() + if err != nil { + return nil, ErrCryptoFailure + } + + // When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to + // prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing + // the Million Message Attack on Cryptographic Message Syntax". We are + // therefore deliberately ignoring errors here. + _ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek) + + return cek, nil + case RSA_OAEP: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + case RSA_OAEP_256: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Sign the given payload +func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return Signature{}, ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + var out []byte + var err error + + switch alg { + case RS256, RS384, RS512: + out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed) + case PS256, PS384, PS512: + out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + }) + } + + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + switch alg { + case RS256, RS384, RS512: + return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature) + case PS256, PS384, PS512: + return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil) + } + + return ErrUnsupportedAlgorithm +} + +// Encrypt the given payload and update the object. +func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + switch alg { + case ECDH_ES: + // ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key. + return recipientInfo{ + header: &rawHeader{}, + }, nil + case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientInfo{}, ErrUnsupportedAlgorithm + } + + generator := ecKeyGenerator{ + algID: string(alg), + publicKey: ctx.publicKey, + } + + switch alg { + case ECDH_ES_A128KW: + generator.size = 16 + case ECDH_ES_A192KW: + generator.size = 24 + case ECDH_ES_A256KW: + generator.size = 32 + } + + kek, header, err := generator.genKey() + if err != nil { + return recipientInfo{}, err + } + + block, err := aes.NewCipher(kek) + if err != nil { + return recipientInfo{}, err + } + + jek, err := josecipher.KeyWrap(block, cek) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: jek, + header: &header, + }, nil +} + +// Get key size for EC key generator +func (ctx ecKeyGenerator) keySize() int { + return ctx.size +} + +// Get a content encryption key for ECDH-ES +func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) { + priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, RandReader) + if err != nil { + return nil, rawHeader{}, err + } + + out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size) + + b, err := json.Marshal(&JSONWebKey{ + Key: &priv.PublicKey, + }) + if err != nil { + return nil, nil, err + } + + headers := rawHeader{ + headerEPK: makeRawMessage(b), + } + + return out, headers, nil +} + +// Decrypt the given payload and return the content encryption key. +func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + epk, err := headers.getEPK() + if err != nil { + return nil, errors.New("square/go-jose: invalid epk header") + } + if epk == nil { + return nil, errors.New("square/go-jose: missing epk header") + } + + publicKey, ok := epk.Key.(*ecdsa.PublicKey) + if publicKey == nil || !ok { + return nil, errors.New("square/go-jose: invalid epk header") + } + + if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return nil, errors.New("square/go-jose: invalid public key in epk header") + } + + apuData, err := headers.getAPU() + if err != nil { + return nil, errors.New("square/go-jose: invalid apu header") + } + apvData, err := headers.getAPV() + if err != nil { + return nil, errors.New("square/go-jose: invalid apv header") + } + + deriveKey := func(algID string, size int) []byte { + return josecipher.DeriveECDHES(algID, apuData.bytes(), apvData.bytes(), ctx.privateKey, publicKey, size) + } + + var keySize int + + algorithm := headers.getAlgorithm() + switch algorithm { + case ECDH_ES: + // ECDH-ES uses direct key agreement, no key unwrapping necessary. + return deriveKey(string(headers.getEncryption()), generator.keySize()), nil + case ECDH_ES_A128KW: + keySize = 16 + case ECDH_ES_A192KW: + keySize = 24 + case ECDH_ES_A256KW: + keySize = 32 + default: + return nil, ErrUnsupportedAlgorithm + } + + key := deriveKey(string(algorithm), keySize) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + return josecipher.KeyUnwrap(block, recipient.encryptedKey) +} + +func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + if alg != EdDSA { + return Signature{}, ErrUnsupportedAlgorithm + } + + sig, err := ctx.privateKey.Sign(RandReader, payload, crypto.Hash(0)) + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: sig, + protected: &rawHeader{}, + }, nil +} + +func (ctx edEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + if alg != EdDSA { + return ErrUnsupportedAlgorithm + } + ok := ed25519.Verify(ctx.publicKey, payload, signature) + if !ok { + return errors.New("square/go-jose: ed25519 signature failed to verify") + } + return nil +} + +// Sign the given payload +func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var expectedBitSize int + var hash crypto.Hash + + switch alg { + case ES256: + expectedBitSize = 256 + hash = crypto.SHA256 + case ES384: + expectedBitSize = 384 + hash = crypto.SHA384 + case ES512: + expectedBitSize = 521 + hash = crypto.SHA512 + } + + curveBits := ctx.privateKey.Curve.Params().BitSize + if expectedBitSize != curveBits { + return Signature{}, fmt.Errorf("square/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r, s, err := ecdsa.Sign(RandReader, ctx.privateKey, hashed) + if err != nil { + return Signature{}, err + } + + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + + // We serialize the outputs (r and s) into big-endian byte arrays and pad + // them with zeros on the left to make sure the sizes work out. Both arrays + // must be keyBytes long, and the output must be 2*keyBytes long. + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + + out := append(rBytesPadded, sBytesPadded...) + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var keySize int + var hash crypto.Hash + + switch alg { + case ES256: + keySize = 32 + hash = crypto.SHA256 + case ES384: + keySize = 48 + hash = crypto.SHA384 + case ES512: + keySize = 66 + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + if len(signature) != 2*keySize { + return fmt.Errorf("square/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r := big.NewInt(0).SetBytes(signature[:keySize]) + s := big.NewInt(0).SetBytes(signature[keySize:]) + + match := ecdsa.Verify(ctx.publicKey, hashed, r, s) + if !match { + return errors.New("square/go-jose: ecdsa signature failed to verify") + } + + return nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go new file mode 100644 index 0000000000..126b85ce25 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go @@ -0,0 +1,196 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "encoding/binary" + "errors" + "hash" +) + +const ( + nonceBytes = 16 +) + +// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC. +func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) { + keySize := len(key) / 2 + integrityKey := key[:keySize] + encryptionKey := key[keySize:] + + blockCipher, err := newBlockCipher(encryptionKey) + if err != nil { + return nil, err + } + + var hash func() hash.Hash + switch keySize { + case 16: + hash = sha256.New + case 24: + hash = sha512.New384 + case 32: + hash = sha512.New + } + + return &cbcAEAD{ + hash: hash, + blockCipher: blockCipher, + authtagBytes: keySize, + integrityKey: integrityKey, + }, nil +} + +// An AEAD based on CBC+HMAC +type cbcAEAD struct { + hash func() hash.Hash + authtagBytes int + integrityKey []byte + blockCipher cipher.Block +} + +func (ctx *cbcAEAD) NonceSize() int { + return nonceBytes +} + +func (ctx *cbcAEAD) Overhead() int { + // Maximum overhead is block size (for padding) plus auth tag length, where + // the length of the auth tag is equivalent to the key size. + return ctx.blockCipher.BlockSize() + ctx.authtagBytes +} + +// Seal encrypts and authenticates the plaintext. +func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte { + // Output buffer -- must take care not to mangle plaintext input. + ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)] + copy(ciphertext, plaintext) + ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize()) + + cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce) + + cbc.CryptBlocks(ciphertext, ciphertext) + authtag := ctx.computeAuthTag(data, nonce, ciphertext) + + ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag))) + copy(out, ciphertext) + copy(out[len(ciphertext):], authtag) + + return ret +} + +// Open decrypts and authenticates the ciphertext. +func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { + if len(ciphertext) < ctx.authtagBytes { + return nil, errors.New("square/go-jose: invalid ciphertext (too short)") + } + + offset := len(ciphertext) - ctx.authtagBytes + expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset]) + match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:]) + if match != 1 { + return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)") + } + + cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce) + + // Make copy of ciphertext buffer, don't want to modify in place + buffer := append([]byte{}, []byte(ciphertext[:offset])...) + + if len(buffer)%ctx.blockCipher.BlockSize() > 0 { + return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)") + } + + cbc.CryptBlocks(buffer, buffer) + + // Remove padding + plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize()) + if err != nil { + return nil, err + } + + ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext))) + copy(out, plaintext) + + return ret, nil +} + +// Compute an authentication tag +func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte { + buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8) + n := 0 + n += copy(buffer, aad) + n += copy(buffer[n:], nonce) + n += copy(buffer[n:], ciphertext) + binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8) + + // According to documentation, Write() on hash.Hash never fails. + hmac := hmac.New(ctx.hash, ctx.integrityKey) + _, _ = hmac.Write(buffer) + + return hmac.Sum(nil)[:ctx.authtagBytes] +} + +// resize ensures the the given slice has a capacity of at least n bytes. +// If the capacity of the slice is less than n, a new slice is allocated +// and the existing data will be copied. +func resize(in []byte, n uint64) (head, tail []byte) { + if uint64(cap(in)) >= n { + head = in[:n] + } else { + head = make([]byte, n) + copy(head, in) + } + + tail = head[len(in):] + return +} + +// Apply padding +func padBuffer(buffer []byte, blockSize int) []byte { + missing := blockSize - (len(buffer) % blockSize) + ret, out := resize(buffer, uint64(len(buffer))+uint64(missing)) + padding := bytes.Repeat([]byte{byte(missing)}, missing) + copy(out, padding) + return ret +} + +// Remove padding +func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) { + if len(buffer)%blockSize != 0 { + return nil, errors.New("square/go-jose: invalid padding") + } + + last := buffer[len(buffer)-1] + count := int(last) + + if count == 0 || count > blockSize || count > len(buffer) { + return nil, errors.New("square/go-jose: invalid padding") + } + + padding := bytes.Repeat([]byte{last}, count) + if !bytes.HasSuffix(buffer, padding) { + return nil, errors.New("square/go-jose: invalid padding") + } + + return buffer[:len(buffer)-count], nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go new file mode 100644 index 0000000000..f62c3bdba5 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go @@ -0,0 +1,75 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "encoding/binary" + "hash" + "io" +) + +type concatKDF struct { + z, info []byte + i uint32 + cache []byte + hasher hash.Hash +} + +// NewConcatKDF builds a KDF reader based on the given inputs. +func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader { + buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo))) + n := 0 + n += copy(buffer, algID) + n += copy(buffer[n:], ptyUInfo) + n += copy(buffer[n:], ptyVInfo) + n += copy(buffer[n:], supPubInfo) + copy(buffer[n:], supPrivInfo) + + hasher := hash.New() + + return &concatKDF{ + z: z, + info: buffer, + hasher: hasher, + cache: []byte{}, + i: 1, + } +} + +func (ctx *concatKDF) Read(out []byte) (int, error) { + copied := copy(out, ctx.cache) + ctx.cache = ctx.cache[copied:] + + for copied < len(out) { + ctx.hasher.Reset() + + // Write on a hash.Hash never fails + _ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i) + _, _ = ctx.hasher.Write(ctx.z) + _, _ = ctx.hasher.Write(ctx.info) + + hash := ctx.hasher.Sum(nil) + chunkCopied := copy(out[copied:], hash) + copied += chunkCopied + ctx.cache = hash[chunkCopied:] + + ctx.i++ + } + + return copied, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go new file mode 100644 index 0000000000..c128e327f3 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go @@ -0,0 +1,62 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "crypto/ecdsa" + "encoding/binary" +) + +// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA. +// It is an error to call this function with a private/public key that are not on the same +// curve. Callers must ensure that the keys are valid before calling this function. Output +// size may be at most 1<<16 bytes (64 KiB). +func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte { + if size > 1<<16 { + panic("ECDH-ES output size too large, must be less than or equal to 1<<16") + } + + // algId, partyUInfo, partyVInfo inputs must be prefixed with the length + algID := lengthPrefixed([]byte(alg)) + ptyUInfo := lengthPrefixed(apuData) + ptyVInfo := lengthPrefixed(apvData) + + // suppPubInfo is the encoded length of the output size in bits + supPubInfo := make([]byte, 4) + binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8) + + if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) { + panic("public key not on same curve as private key") + } + + z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes()) + reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{}) + + key := make([]byte, size) + + // Read on the KDF will never fail + _, _ = reader.Read(key) + return key +} + +func lengthPrefixed(data []byte) []byte { + out := make([]byte, len(data)+4) + binary.BigEndian.PutUint32(out, uint32(len(data))) + copy(out[4:], data) + return out +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go new file mode 100644 index 0000000000..1d36d50151 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go @@ -0,0 +1,109 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto/cipher" + "crypto/subtle" + "encoding/binary" + "errors" +) + +var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6} + +// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher. +func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) { + if len(cek)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := len(cek) / 8 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], cek[i*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer, defaultIV) + + for t := 0; t < 6*n; t++ { + copy(buffer[8:], r[t%n]) + + block.Encrypt(buffer, buffer) + + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(r[t%n], buffer[8:]) + } + + out := make([]byte, (n+1)*8) + copy(out, buffer[:8]) + for i := range r { + copy(out[(i+1)*8:], r[i]) + } + + return out, nil +} + +// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher. +func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) { + if len(ciphertext)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := (len(ciphertext) / 8) - 1 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], ciphertext[(i+1)*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer[:8], ciphertext[:8]) + + for t := 6*n - 1; t >= 0; t-- { + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(buffer[8:], r[t%n]) + + block.Decrypt(buffer, buffer) + + copy(r[t%n], buffer[8:]) + } + + if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 { + return nil, errors.New("square/go-jose: failed to unwrap key") + } + + out := make([]byte, n*8) + for i := range r { + copy(out[i*8:], r[i]) + } + + return out, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/crypter.go b/vendor/gopkg.in/square/go-jose.v2/crypter.go new file mode 100644 index 0000000000..c45c71206b --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/crypter.go @@ -0,0 +1,535 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto/ecdsa" + "crypto/rsa" + "errors" + "fmt" + "reflect" + + "gopkg.in/square/go-jose.v2/json" +) + +// Encrypter represents an encrypter which produces an encrypted JWE object. +type Encrypter interface { + Encrypt(plaintext []byte) (*JSONWebEncryption, error) + EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error) + Options() EncrypterOptions +} + +// A generic content cipher +type contentCipher interface { + keySize() int + encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error) + decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error) +} + +// A key generator (for generating/getting a CEK) +type keyGenerator interface { + keySize() int + genKey() ([]byte, rawHeader, error) +} + +// A generic key encrypter +type keyEncrypter interface { + encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key +} + +// A generic key decrypter +type keyDecrypter interface { + decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key +} + +// A generic encrypter based on the given key encrypter and content cipher. +type genericEncrypter struct { + contentAlg ContentEncryption + compressionAlg CompressionAlgorithm + cipher contentCipher + recipients []recipientKeyInfo + keyGenerator keyGenerator + extraHeaders map[HeaderKey]interface{} +} + +type recipientKeyInfo struct { + keyID string + keyAlg KeyAlgorithm + keyEncrypter keyEncrypter +} + +// EncrypterOptions represents options that can be set on new encrypters. +type EncrypterOptions struct { + Compression CompressionAlgorithm + + // Optional map of additional keys to be inserted into the protected header + // of a JWS object. Some specifications which make use of JWS like to insert + // additional values here. All values must be JSON-serializable. + ExtraHeaders map[HeaderKey]interface{} +} + +// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it +// if necessary. It returns itself and so can be used in a fluent style. +func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions { + if eo.ExtraHeaders == nil { + eo.ExtraHeaders = map[HeaderKey]interface{}{} + } + eo.ExtraHeaders[k] = v + return eo +} + +// WithContentType adds a content type ("cty") header and returns the updated +// EncrypterOptions. +func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderContentType, contentType) +} + +// WithType adds a type ("typ") header and returns the updated EncrypterOptions. +func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderType, typ) +} + +// Recipient represents an algorithm/key to encrypt messages to. +// +// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used +// on the password-based encryption algorithms PBES2-HS256+A128KW, +// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe +// default of 100000 will be used for the count and a 128-bit random salt will +// be generated. +type Recipient struct { + Algorithm KeyAlgorithm + Key interface{} + KeyID string + PBES2Count int + PBES2Salt []byte +} + +// NewEncrypter creates an appropriate encrypter based on the key type +func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) { + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: getContentCipher(enc), + } + if opts != nil { + encrypter.compressionAlg = opts.Compression + encrypter.extraHeaders = opts.ExtraHeaders + } + + if encrypter.cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + + var keyID string + var rawKey interface{} + switch encryptionKey := rcpt.Key.(type) { + case JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + case *JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + default: + rawKey = encryptionKey + } + + switch rcpt.Algorithm { + case DIRECT: + // Direct encryption mode must be treated differently + if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) { + return nil, ErrUnsupportedKeyType + } + if encrypter.cipher.keySize() != len(rawKey.([]byte)) { + return nil, ErrInvalidKeySize + } + encrypter.keyGenerator = staticKeyGenerator{ + key: rawKey.([]byte), + } + recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + case ECDH_ES: + // ECDH-ES (w/o key wrapping) is similar to DIRECT mode + typeOf := reflect.TypeOf(rawKey) + if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) { + return nil, ErrUnsupportedKeyType + } + encrypter.keyGenerator = ecKeyGenerator{ + size: encrypter.cipher.keySize(), + algID: string(enc), + publicKey: rawKey.(*ecdsa.PublicKey), + } + recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + default: + // Can just add a standard recipient + encrypter.keyGenerator = randomKeyGenerator{ + size: encrypter.cipher.keySize(), + } + err := encrypter.addRecipient(rcpt) + return encrypter, err + } +} + +// NewMultiEncrypter creates a multi-encrypter based on the given parameters +func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) { + cipher := getContentCipher(enc) + + if cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + if rcpts == nil || len(rcpts) == 0 { + return nil, fmt.Errorf("square/go-jose: recipients is nil or empty") + } + + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: cipher, + keyGenerator: randomKeyGenerator{ + size: cipher.keySize(), + }, + } + + if opts != nil { + encrypter.compressionAlg = opts.Compression + } + + for _, recipient := range rcpts { + err := encrypter.addRecipient(recipient) + if err != nil { + return nil, err + } + } + + return encrypter, nil +} + +func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) { + var recipientInfo recipientKeyInfo + + switch recipient.Algorithm { + case DIRECT, ECDH_ES: + return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm) + } + + recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key) + if recipient.KeyID != "" { + recipientInfo.keyID = recipient.KeyID + } + + switch recipient.Algorithm { + case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW: + if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok { + sr.p2c = recipient.PBES2Count + sr.p2s = recipient.PBES2Salt + } + } + + if err == nil { + ctx.recipients = append(ctx.recipients, recipientInfo) + } + return err +} + +func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) { + switch encryptionKey := encryptionKey.(type) { + case *rsa.PublicKey: + return newRSARecipient(alg, encryptionKey) + case *ecdsa.PublicKey: + return newECDHRecipient(alg, encryptionKey) + case []byte: + return newSymmetricRecipient(alg, encryptionKey) + case string: + return newSymmetricRecipient(alg, []byte(encryptionKey)) + case *JSONWebKey: + recipient, err := makeJWERecipient(alg, encryptionKey.Key) + recipient.keyID = encryptionKey.KeyID + return recipient, err + default: + return recipientKeyInfo{}, ErrUnsupportedKeyType + } +} + +// newDecrypter creates an appropriate decrypter based on the key type +func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) { + switch decryptionKey := decryptionKey.(type) { + case *rsa.PrivateKey: + return &rsaDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case *ecdsa.PrivateKey: + return &ecDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case []byte: + return &symmetricKeyCipher{ + key: decryptionKey, + }, nil + case string: + return &symmetricKeyCipher{ + key: []byte(decryptionKey), + }, nil + case JSONWebKey: + return newDecrypter(decryptionKey.Key) + case *JSONWebKey: + return newDecrypter(decryptionKey.Key) + default: + return nil, ErrUnsupportedKeyType + } +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) { + return ctx.EncryptWithAuthData(plaintext, nil) +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) { + obj := &JSONWebEncryption{} + obj.aad = aad + + obj.protected = &rawHeader{} + err := obj.protected.set(headerEncryption, ctx.contentAlg) + if err != nil { + return nil, err + } + + obj.recipients = make([]recipientInfo, len(ctx.recipients)) + + if len(ctx.recipients) == 0 { + return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to") + } + + cek, headers, err := ctx.keyGenerator.genKey() + if err != nil { + return nil, err + } + + obj.protected.merge(&headers) + + for i, info := range ctx.recipients { + recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg) + if err != nil { + return nil, err + } + + err = recipient.header.set(headerAlgorithm, info.keyAlg) + if err != nil { + return nil, err + } + + if info.keyID != "" { + err = recipient.header.set(headerKeyID, info.keyID) + if err != nil { + return nil, err + } + } + obj.recipients[i] = recipient + } + + if len(ctx.recipients) == 1 { + // Move per-recipient headers into main protected header if there's + // only a single recipient. + obj.protected.merge(obj.recipients[0].header) + obj.recipients[0].header = nil + } + + if ctx.compressionAlg != NONE { + plaintext, err = compress(ctx.compressionAlg, plaintext) + if err != nil { + return nil, err + } + + err = obj.protected.set(headerCompression, ctx.compressionAlg) + if err != nil { + return nil, err + } + } + + for k, v := range ctx.extraHeaders { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + (*obj.protected)[k] = makeRawMessage(b) + } + + authData := obj.computeAuthData() + parts, err := ctx.cipher.encrypt(cek, authData, plaintext) + if err != nil { + return nil, err + } + + obj.iv = parts.iv + obj.ciphertext = parts.ciphertext + obj.tag = parts.tag + + return obj, nil +} + +func (ctx *genericEncrypter) Options() EncrypterOptions { + return EncrypterOptions{ + Compression: ctx.compressionAlg, + ExtraHeaders: ctx.extraHeaders, + } +} + +// Decrypt and validate the object and return the plaintext. Note that this +// function does not support multi-recipient, if you desire multi-recipient +// decryption use DecryptMulti instead. +func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) { + headers := obj.mergedHeaders(nil) + + if len(obj.recipients) > 1 { + return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one") + } + + critical, err := headers.getCritical() + if err != nil { + return nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return nil, err + } + + cipher := getContentCipher(headers.getEncryption()) + if cipher == nil { + return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption())) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + var plaintext []byte + recipient := obj.recipients[0] + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + } + + if plaintext == nil { + return nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + return plaintext, err +} + +// DecryptMulti decrypts and validates the object and returns the plaintexts, +// with support for multiple recipients. It returns the index of the recipient +// for which the decryption was successful, the merged headers for that recipient, +// and the plaintext. +func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) { + globalHeaders := obj.mergedHeaders(nil) + + critical, err := globalHeaders.getCritical() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return -1, Header{}, nil, err + } + + encryption := globalHeaders.getEncryption() + cipher := getContentCipher(encryption) + if cipher == nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption)) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + index := -1 + var plaintext []byte + var headers rawHeader + + for i, recipient := range obj.recipients { + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + if err == nil { + index = i + headers = recipientHeaders + break + } + } + } + + if plaintext == nil || err != nil { + return -1, Header{}, nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + sanitized, err := headers.sanitized() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err) + } + + return index, sanitized, plaintext, err +} diff --git a/vendor/gopkg.in/square/go-jose.v2/doc.go b/vendor/gopkg.in/square/go-jose.v2/doc.go new file mode 100644 index 0000000000..dd1387f3f0 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/doc.go @@ -0,0 +1,27 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. It implements encryption and signing based on +the JSON Web Encryption and JSON Web Signature standards, with optional JSON +Web Token support available in a sub-package. The library supports both the +compact and full serialization formats, and has optional support for multiple +recipients. + +*/ +package jose diff --git a/vendor/gopkg.in/square/go-jose.v2/encoding.go b/vendor/gopkg.in/square/go-jose.v2/encoding.go new file mode 100644 index 0000000000..b9687c647d --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/encoding.go @@ -0,0 +1,179 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "bytes" + "compress/flate" + "encoding/base64" + "encoding/binary" + "io" + "math/big" + "regexp" + + "gopkg.in/square/go-jose.v2/json" +) + +var stripWhitespaceRegex = regexp.MustCompile("\\s") + +// Helper function to serialize known-good objects. +// Precondition: value is not a nil pointer. +func mustSerializeJSON(value interface{}) []byte { + out, err := json.Marshal(value) + if err != nil { + panic(err) + } + // We never want to serialize the top-level value "null," since it's not a + // valid JOSE message. But if a caller passes in a nil pointer to this method, + // MarshalJSON will happily serialize it as the top-level value "null". If + // that value is then embedded in another operation, for instance by being + // base64-encoded and fed as input to a signing algorithm + // (https://github.com/square/go-jose/issues/22), the result will be + // incorrect. Because this method is intended for known-good objects, and a nil + // pointer is not a known-good object, we are free to panic in this case. + // Note: It's not possible to directly check whether the data pointed at by an + // interface is a nil pointer, so we do this hacky workaround. + // https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I + if string(out) == "null" { + panic("Tried to serialize a nil pointer.") + } + return out +} + +// Strip all newlines and whitespace +func stripWhitespace(data string) string { + return stripWhitespaceRegex.ReplaceAllString(data, "") +} + +// Perform compression based on algorithm +func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return deflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Perform decompression based on algorithm +func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return inflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Compress with DEFLATE +func deflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + + // Writing to byte buffer, err is always nil + writer, _ := flate.NewWriter(output, 1) + _, _ = io.Copy(writer, bytes.NewBuffer(input)) + + err := writer.Close() + return output.Bytes(), err +} + +// Decompress with DEFLATE +func inflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + reader := flate.NewReader(bytes.NewBuffer(input)) + + _, err := io.Copy(output, reader) + if err != nil { + return nil, err + } + + err = reader.Close() + return output.Bytes(), err +} + +// byteBuffer represents a slice of bytes that can be serialized to url-safe base64. +type byteBuffer struct { + data []byte +} + +func newBuffer(data []byte) *byteBuffer { + if data == nil { + return nil + } + return &byteBuffer{ + data: data, + } +} + +func newFixedSizeBuffer(data []byte, length int) *byteBuffer { + if len(data) > length { + panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)") + } + pad := make([]byte, length-len(data)) + return newBuffer(append(pad, data...)) +} + +func newBufferFromInt(num uint64) *byteBuffer { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, num) + return newBuffer(bytes.TrimLeft(data, "\x00")) +} + +func (b *byteBuffer) MarshalJSON() ([]byte, error) { + return json.Marshal(b.base64()) +} + +func (b *byteBuffer) UnmarshalJSON(data []byte) error { + var encoded string + err := json.Unmarshal(data, &encoded) + if err != nil { + return err + } + + if encoded == "" { + return nil + } + + decoded, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + return err + } + + *b = *newBuffer(decoded) + + return nil +} + +func (b *byteBuffer) base64() string { + return base64.RawURLEncoding.EncodeToString(b.data) +} + +func (b *byteBuffer) bytes() []byte { + // Handling nil here allows us to transparently handle nil slices when serializing. + if b == nil { + return nil + } + return b.data +} + +func (b byteBuffer) bigInt() *big.Int { + return new(big.Int).SetBytes(b.data) +} + +func (b byteBuffer) toInt() int { + return int(b.bigInt().Int64()) +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/LICENSE b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE new file mode 100644 index 0000000000..7448756763 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/README.md b/vendor/gopkg.in/square/go-jose.v2/json/README.md new file mode 100644 index 0000000000..86de5e5581 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/README.md @@ -0,0 +1,13 @@ +# Safe JSON + +This repository contains a fork of the `encoding/json` package from Go 1.6. + +The following changes were made: + +* Object deserialization uses case-sensitive member name matching instead of + [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html). + This is to avoid differences in the interpretation of JOSE messages between + go-jose and libraries written in other languages. +* When deserializing a JSON object, we check for duplicate keys and reject the + input whenever we detect a duplicate. Rather than trying to work with malformed + data, we prefer to reject it right away. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/decode.go b/vendor/gopkg.in/square/go-jose.v2/json/decode.go new file mode 100644 index 0000000000..37457e5a83 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/decode.go @@ -0,0 +1,1183 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "errors" + "fmt" + "reflect" + "runtime" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// Unmarshal will only set exported fields of the struct. +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a string-keyed map, Unmarshal first +// establishes a map to use, If the map is nil, Unmarshal allocates a new map. +// Otherwise Unmarshal reuses the existing map, keeping existing entries. +// Unmarshal then stores key-value pairs from the JSON object into the map. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// ``not present,'' unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// +func Unmarshal(data []byte, v interface{}) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes +} + +func (e *UnmarshalTypeError) Error() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// (No longer used; kept for compatibility.) +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// isValidNumber reports whether s is a valid JSON number literal. +func isValidNumber(s string) bool { + // This function implements the JSON numbers grammar. + // See https://tools.ietf.org/html/rfc7159#section-6 + // and http://json.org/number.gif + + if s == "" { + return false + } + + // Optional - + if s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + + // Digits + switch { + default: + return false + + case s[0] == '0': + s = s[1:] + + case '1' <= s[0] && s[0] <= '9': + s = s[1:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // . followed by 1 or more digits. + if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' { + s = s[2:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // e or E followed by an optional - or + and + // 1 or more digits. + if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') { + s = s[1:] + if s[0] == '+' || s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // Make sure we are at the end. + return s == "" +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError error + useNumber bool +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = errors.New("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. +func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := d.data[d.off] + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if !v.IsValid() { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.redo { + // rewind. + d.scan.redo = false + d.scan.step = stateBeginValue + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + + n := len(d.scan.parseState) + if n > 0 && d.scan.parseState[n-1] == parseObjectKey { + // d.scan thinks we just read an object key; finish the object + d.scan.step(&d.scan, ':') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '}') + } + + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() interface{} { + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(reflect.Value{}) + + case scanBeginObject: + d.object(reflect.Value{}) + + case scanBeginLiteral: + switch v := d.literalInterface().(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. +func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + v = v.Elem() + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + } + + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + v.Set(reflect.ValueOf(d.arrayInterface())) + return + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + case reflect.Array: + case reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if v.Kind() == reflect.Slice { + // Grow slice if necessary + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + d.value(v.Index(i)) + } else { + // Ran out of fixed array: skip. + d.value(reflect.Value{}) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Array. Zero the rest. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } +} + +var nullLiteral = []byte("null") + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + switch v.Kind() { + case reflect.Map: + // map must have string kind + t := v.Type() + if t.Key().Kind() != reflect.String { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + + default: + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + + var mapElem reflect.Value + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, []byte(key)) { + f = ff + break + } + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + d.literalStore(nullLiteral, subv, false) + case string: + d.literalStore([]byte(qv), subv, true) + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + d.value(subv) + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kv := reflect.ValueOf(key).Convert(v.Type().Key()) + v.SetMapIndex(kv, subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + + d.literalStore(d.data[start:d.off], v, false) +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (interface{}, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{"number " + s, reflect.TypeOf(0.0), int64(d.off)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) { + // Check for unmarshaler. + if len(item) == 0 { + //Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return + } + wantptr := item[0] == 'n' // null + u, ut, pv := d.indirect(v, wantptr) + if u != nil { + err := u.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + return + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + err := ut.UnmarshalText(s) + if err != nil { + d.error(err) + } + return + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + v.Set(reflect.Zero(v.Type())) + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := c == 't' + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + v.SetString(s) + if !isValidNumber(s) { + d.error(fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)) + } + break + } + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + } + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetFloat(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + panic("unreachable") + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v = make([]interface{}, 0) + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + r, err := strconv.ParseUint(string(s[2:6]), 16, 64) + if err != nil { + return -1 + } + return rune(r) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/encode.go b/vendor/gopkg.in/square/go-jose.v2/json/encode.go new file mode 100644 index 0000000000..1dae8bb7cd --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/encode.go @@ -0,0 +1,1197 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. The mapping between JSON objects and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "fmt" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. If no MarshalJSON method is present but the +// value implements encoding.TextMarshaler instead, Marshal calls +// its MarshalText method. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" +// to keep some browsers from misinterpreting JSON output as HTML. +// Ampersand "&" is also escaped to "\u0026" for the same reason. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string, and a nil slice +// encodes as the null JSON object. +// +// Struct values encode as JSON objects. Each exported struct field +// becomes a member of the object unless +// - the field's tag is "-", or +// - the field is empty and its tag specifies the "omitempty" option. +// The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or string of +// length zero. The object's default key string is the struct field name +// but can be specified in the struct field's tag value. The "json" key in +// the struct field's tag value is the key name, followed by an optional comma +// and options. Examples: +// +// // Field is ignored by this package. +// Field int `json:"-"` +// +// // Field appears in JSON as key "myName". +// Field int `json:"myName"` +// +// // Field appears in JSON as key "myName" and +// // the field is omitted from the object if its value is empty, +// // as defined above. +// Field int `json:"myName,omitempty"` +// +// // Field appears in JSON as key "Field" (the default), but +// // the field is skipped if empty. +// // Note the leading comma. +// Field int `json:",omitempty"` +// +// The "string" option signals that a field is stored as JSON inside a +// JSON-encoded string. It applies only to fields of string, floating point, +// integer, or boolean types. This extra level of encoding is sometimes used +// when communicating with JavaScript programs: +// +// Int64String int64 `json:",string"` +// +// The key name will be used if it's a non-empty string consisting of +// only Unicode letters, digits, dollar signs, percent signs, hyphens, +// underscores and slashes. +// +// Anonymous struct fields are usually marshaled as if their inner exported fields +// were fields in the outer struct, subject to the usual Go visibility rules amended +// as described in the next paragraph. +// An anonymous struct field with a name given in its JSON tag is treated as +// having that name, rather than being anonymous. +// An anonymous struct field of interface type is treated the same as having +// that type as its name, rather than being anonymous. +// +// The Go visibility rules for struct fields are amended for JSON when +// deciding which field to marshal or unmarshal. If there are +// multiple fields at the same level, and that level is the least +// nested (and would therefore be the nesting level selected by the +// usual Go rules), the following extra rules apply: +// +// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered, +// even if there are multiple untagged fields that would otherwise conflict. +// 2) If there is exactly one field (tagged or not according to the first rule), that is selected. +// 3) Otherwise there are multiple fields, and all are ignored; no error occurs. +// +// Handling of anonymous struct fields is new in Go 1.1. +// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of +// an anonymous struct field in both current and earlier versions, give the field +// a JSON tag of "-". +// +// Map values encode as JSON objects. +// The map's key type must be string; the map keys are used as JSON object +// keys, subject to the UTF-8 coercion described for string values above. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an UnsupportedTypeError. +// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. +// +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029 +// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029 +// so that the JSON will be safe to embed inside HTML