diff --git a/.changelog/12583.txt b/.changelog/12583.txt new file mode 100644 index 0000000000..4b5dad9c0c --- /dev/null +++ b/.changelog/12583.txt @@ -0,0 +1,3 @@ +```release-note:feature +acl: Added an AWS IAM auth method that allows authenticating to Consul using AWS IAM identities +``` diff --git a/agent/consul/acl_authmethod.go b/agent/consul/acl_authmethod.go index 2e973c6a12..b901ce131d 100644 --- a/agent/consul/acl_authmethod.go +++ b/agent/consul/acl_authmethod.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/go-bexpr" // register these as a builtin auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/awsauth" _ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" _ "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" ) diff --git a/agent/consul/authmethod/awsauth/aws.go b/agent/consul/authmethod/awsauth/aws.go new file mode 100644 index 0000000000..32320e3f74 --- /dev/null +++ b/agent/consul/authmethod/awsauth/aws.go @@ -0,0 +1,193 @@ +package awsauth + +import ( + "context" + "fmt" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/internal/iamauth" + "github.com/hashicorp/go-hclog" +) + +const ( + authMethodType string = "aws-iam" + + IAMServerIDHeaderName string = "X-Consul-IAM-ServerID" + GetEntityMethodHeader string = "X-Consul-IAM-GetEntity-Method" + GetEntityURLHeader string = "X-Consul-IAM-GetEntity-URL" + GetEntityHeadersHeader string = "X-Consul-IAM-GetEntity-Headers" + GetEntityBodyHeader string = "X-Consul-IAM-GetEntity-Body" +) + +func init() { + // register this as an available auth method type + authmethod.Register(authMethodType, func(logger hclog.Logger, method *structs.ACLAuthMethod) (authmethod.Validator, error) { + v, err := NewValidator(logger, method) + if err != nil { + return nil, err + } + return v, nil + }) +} + +type Config struct { + // BoundIAMPrincipalARNs are the trusted AWS IAM principal ARNs that are permitted + // to login to the auth method. These can be the exact ARNs or wildcards. Wildcards + // are only supported if EnableIAMEntityDetails is true. + BoundIAMPrincipalARNs []string `json:",omitempty"` + + // EnableIAMEntityDetails will fetch the IAM User or IAM Role details to include + // in binding rules. Required if wildcard principal ARNs are used. + EnableIAMEntityDetails bool `json:",omitempty"` + + // IAMEntityTags are the specific IAM User or IAM Role tags to include as selectable + // fields in the binding rule attributes. Requires EnableIAMEntityDetails = true. + IAMEntityTags []string `json:",omitempty"` + + // ServerIDHeaderValue adds a X-Consul-IAM-ServerID header to each AWS API request. + // This helps protect against replay attacks. + ServerIDHeaderValue string `json:",omitempty"` + + // MaxRetries is the maximum number of retries on AWS API requests for recoverable errors. + MaxRetries int `json:",omitempty"` + // IAMEndpoint is the AWS IAM endpoint where iam:GetRole or iam:GetUser requests will be sent. + // Note that the Host header in a signed request cannot be changed. + IAMEndpoint string `json:",omitempty"` + // STSEndpoint is the AWS STS endpoint where sts:GetCallerIdentity requests will be sent. + // Note that the Host header in a signed request cannot be changed. + STSEndpoint string `json:",omitempty"` + // STSRegion is the region for the AWS STS service. This should only be set if STSEndpoint + // is set, and must match the region of the STSEndpoint. + STSRegion string `json:",omitempty"` + + // AllowedSTSHeaderValues is a list of additional allowed headers on the sts:GetCallerIdentity + // request in the bearer token. A default list of necessary headers is allowed in any case. + AllowedSTSHeaderValues []string `json:",omitempty"` +} + +func (c *Config) convertForLibrary() *iamauth.Config { + return &iamauth.Config{ + BoundIAMPrincipalARNs: c.BoundIAMPrincipalARNs, + EnableIAMEntityDetails: c.EnableIAMEntityDetails, + IAMEntityTags: c.IAMEntityTags, + ServerIDHeaderValue: c.ServerIDHeaderValue, + MaxRetries: c.MaxRetries, + IAMEndpoint: c.IAMEndpoint, + STSEndpoint: c.STSEndpoint, + STSRegion: c.STSRegion, + AllowedSTSHeaderValues: c.AllowedSTSHeaderValues, + + ServerIDHeaderName: IAMServerIDHeaderName, + GetEntityMethodHeader: GetEntityMethodHeader, + GetEntityURLHeader: GetEntityURLHeader, + GetEntityHeadersHeader: GetEntityHeadersHeader, + GetEntityBodyHeader: GetEntityBodyHeader, + } +} + +type Validator struct { + name string + config *iamauth.Config + logger hclog.Logger + + auth *iamauth.Authenticator +} + +func NewValidator(logger hclog.Logger, method *structs.ACLAuthMethod) (*Validator, error) { + if method.Type != authMethodType { + return nil, fmt.Errorf("%q is not an AWS IAM auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + iamConfig := config.convertForLibrary() + + auth, err := iamauth.NewAuthenticator(iamConfig, logger) + if err != nil { + return nil, err + } + + return &Validator{ + name: method.Name, + config: iamConfig, + logger: logger, + auth: auth, + }, nil +} + +// Name implements authmethod.Validator. +func (v *Validator) Name() string { return v.name } + +// Stop implements authmethod.Validator. +func (v *Validator) Stop() {} + +// ValidateLogin implements authmethod.Validator. +func (v *Validator) ValidateLogin(ctx context.Context, loginToken string) (*authmethod.Identity, error) { + details, err := v.auth.ValidateLogin(ctx, loginToken) + if err != nil { + return nil, err + } + + vars := map[string]string{ + "entity_name": details.EntityName, + "entity_id": details.EntityId, + "account_id": details.AccountId, + } + fields := &awsSelectableFields{ + EntityName: details.EntityName, + EntityId: details.EntityId, + AccountId: details.AccountId, + } + + if v.config.EnableIAMEntityDetails { + vars["entity_path"] = details.EntityPath + fields.EntityPath = details.EntityPath + fields.EntityTags = map[string]string{} + for _, tag := range v.config.IAMEntityTags { + vars["entity_tags."+tag] = details.EntityTags[tag] + fields.EntityTags[tag] = details.EntityTags[tag] + } + } + + result := &authmethod.Identity{ + SelectableFields: fields, + ProjectedVars: vars, + EnterpriseMeta: nil, + } + return result, nil + +} + +func (v *Validator) NewIdentity() *authmethod.Identity { + fields := &awsSelectableFields{ + EntityTags: map[string]string{}, + } + vars := map[string]string{ + "entity_name": "", + "entity_id": "", + "account_id": "", + } + if v.config.EnableIAMEntityDetails { + vars["entity_path"] = "" + for _, tag := range v.config.IAMEntityTags { + vars["entity_tags."+tag] = "" + fields.EntityTags[tag] = "" + } + } + return &authmethod.Identity{ + SelectableFields: fields, + ProjectedVars: vars, + } +} + +type awsSelectableFields struct { + EntityName string `bexpr:"entity_name"` + EntityId string `bexpr:"entity_id"` + AccountId string `bexpr:"account_id"` + + EntityPath string `bexpr:"entity_path"` + EntityTags map[string]string `bexpr:"entity_tags"` +} diff --git a/agent/consul/authmethod/awsauth/aws_test.go b/agent/consul/authmethod/awsauth/aws_test.go new file mode 100644 index 0000000000..8ee5076923 --- /dev/null +++ b/agent/consul/authmethod/awsauth/aws_test.go @@ -0,0 +1,342 @@ +package awsauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http/httptest" + "testing" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/internal/iamauth" + "github.com/hashicorp/consul/internal/iamauth/iamauthtest" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +func TestNewValidator(t *testing.T) { + f := iamauthtest.MakeFixture() + expConfig := &iamauth.Config{ + BoundIAMPrincipalARNs: []string{f.AssumedRoleARN}, + EnableIAMEntityDetails: true, + IAMEntityTags: []string{"tag-1"}, + ServerIDHeaderValue: "x-some-header", + MaxRetries: 3, + IAMEndpoint: "iam-endpoint", + STSEndpoint: "sts-endpoint", + STSRegion: "sts-region", + AllowedSTSHeaderValues: []string{"header-value"}, + ServerIDHeaderName: "X-Consul-IAM-ServerID", + GetEntityMethodHeader: "X-Consul-IAM-GetEntity-Method", + GetEntityURLHeader: "X-Consul-IAM-GetEntity-URL", + GetEntityHeadersHeader: "X-Consul-IAM-GetEntity-Headers", + GetEntityBodyHeader: "X-Consul-IAM-GetEntity-Body", + } + + type AM = *structs.ACLAuthMethod + // Create the auth method, with an optional modification function. + makeMethod := func(modifyFn func(AM)) AM { + config := map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.AssumedRoleARN}, + "EnableIAMEntityDetails": true, + "IAMEntityTags": []string{"tag-1"}, + "ServerIDHeaderValue": "x-some-header", + "MaxRetries": 3, + "IAMEndpoint": "iam-endpoint", + "STSEndpoint": "sts-endpoint", + "STSRegion": "sts-region", + "AllowedSTSHeaderValues": []string{"header-value"}, + } + + m := &structs.ACLAuthMethod{ + Name: "test-iam", + Type: "aws-iam", + Description: "aws iam auth", + Config: config, + } + if modifyFn != nil { + modifyFn(m) + } + return m + } + + cases := map[string]struct { + ok bool + modifyFn func(AM) + }{ + "success": {true, nil}, + "wrong type": {false, func(m AM) { m.Type = "not-iam" }}, + "extra config": {false, func(m AM) { m.Config["extraField"] = "123" }}, + "wrong config value type": {false, func(m AM) { m.Config["MaxRetries"] = []string{"1"} }}, + "missing bound principals": {false, func(m AM) { delete(m.Config, "BoundIAMPrincipalARNs") }}, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v, err := NewValidator(nil, makeMethod(c.modifyFn)) + if c.ok { + require.NoError(t, err) + require.NotNil(t, v) + require.Equal(t, "test-iam", v.name) + require.NotNil(t, v.auth) + require.Equal(t, expConfig, v.config) + } else { + require.Error(t, err) + require.Nil(t, v) + } + }) + } +} + +func TestValidateLogin(t *testing.T) { + f := iamauthtest.MakeFixture() + + cases := map[string]struct { + server *iamauthtest.Server + token string + config map[string]interface{} + expVars map[string]string + expFields []string + expError string + }{ + "success - role login": { + server: f.ServerForRole, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.CanonicalRoleARN}, + }, + expVars: map[string]string{ + "entity_id": f.EntityID, + "entity_name": f.RoleName, + "account_id": f.AccountID, + }, + expFields: []string{ + fmt.Sprintf(`entity_id == %q`, f.EntityID), + fmt.Sprintf(`entity_name == %q`, f.RoleName), + fmt.Sprintf(`account_id == %q`, f.AccountID), + }, + }, + "success - user login": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + }, + expVars: map[string]string{ + "entity_id": f.EntityID, + "entity_name": f.UserName, + "account_id": f.AccountID, + }, + expFields: []string{ + fmt.Sprintf(`entity_id == %q`, f.EntityID), + fmt.Sprintf(`entity_name == %q`, f.UserName), + fmt.Sprintf(`account_id == %q`, f.AccountID), + }, + }, + "success - role login with entity details": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + "EnableIAMEntityDetails": true, + }, + expVars: map[string]string{ + "entity_id": f.EntityID, + "entity_name": f.UserName, + "account_id": f.AccountID, + "entity_path": f.UserPath, + }, + expFields: []string{ + fmt.Sprintf(`entity_id == %q`, f.EntityID), + fmt.Sprintf(`entity_name == %q`, f.UserName), + fmt.Sprintf(`account_id == %q`, f.AccountID), + fmt.Sprintf(`entity_path == %q`, f.UserPath), + }, + }, + "success - user login with entity details": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + "EnableIAMEntityDetails": true, + }, + expVars: map[string]string{ + "entity_id": f.EntityID, + "entity_name": f.UserName, + "account_id": f.AccountID, + "entity_path": f.UserPath, + }, + expFields: []string{ + fmt.Sprintf(`entity_id == %q`, f.EntityID), + fmt.Sprintf(`entity_name == %q`, f.UserName), + fmt.Sprintf(`account_id == %q`, f.AccountID), + fmt.Sprintf(`entity_path == %q`, f.UserPath), + }, + }, + "invalid token": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + }, + token: `invalid`, + expError: "invalid token", + }, + "empty json token": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + }, + token: `{}`, + expError: "invalid token", + }, + "empty json fields in token": { + server: f.ServerForUser, + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": []string{f.UserARN}, + }, + token: `{"iam_http_request_method": "", +"iam_request_body": "", +"iam_request_headers": "", +"iam_request_url": "" +}`, + expError: "invalid token", + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v, _, token := setup(t, c.config, c.server) + if c.token != "" { + token = c.token + } + id, err := v.ValidateLogin(context.Background(), token) + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + require.Nil(t, id) + } else { + require.NoError(t, err) + authmethod.RequireIdentityMatch(t, id, c.expVars, c.expFields...) + } + }) + } +} + +func setup(t *testing.T, config map[string]interface{}, server *iamauthtest.Server) (*Validator, *httptest.Server, string) { + t.Helper() + + fakeAws := iamauthtest.NewTestServer(t, server) + + config["STSEndpoint"] = fakeAws.URL + "/sts" + config["STSRegion"] = "fake-region" + config["IAMEndpoint"] = fakeAws.URL + "/iam" + + method := &structs.ACLAuthMethod{ + Name: "test-method", + Type: "aws-iam", + Config: config, + } + nullLogger := hclog.NewNullLogger() + v, err := NewValidator(nullLogger, method) + require.NoError(t, err) + + // Generate the login token + tokenData, err := iamauth.GenerateLoginData(&iamauth.LoginInput{ + Creds: credentials.NewStaticCredentials("fake", "fake", ""), + IncludeIAMEntity: v.config.EnableIAMEntityDetails, + STSEndpoint: v.config.STSEndpoint, + STSRegion: v.config.STSRegion, + Logger: nullLogger, + ServerIDHeaderValue: v.config.ServerIDHeaderValue, + ServerIDHeaderName: v.config.ServerIDHeaderName, + GetEntityMethodHeader: v.config.GetEntityMethodHeader, + GetEntityURLHeader: v.config.GetEntityURLHeader, + GetEntityHeadersHeader: v.config.GetEntityHeadersHeader, + GetEntityBodyHeader: v.config.GetEntityBodyHeader, + }) + require.NoError(t, err) + + token, err := json.Marshal(tokenData) + require.NoError(t, err) + return v, fakeAws, string(token) +} + +func TestNewIdentity(t *testing.T) { + principals := []string{"arn:aws:sts::1234567890:assumed-role/my-role/some-session"} + cases := map[string]struct { + config map[string]interface{} + expVars map[string]string + expFilters []string + }{ + "entity details disabled": { + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": principals, + }, + expVars: map[string]string{ + "entity_name": "", + "entity_id": "", + "account_id": "", + }, + expFilters: []string{ + `entity_name == ""`, + `entity_id == ""`, + `account_id == ""`, + }, + }, + "entity details enabled": { + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": principals, + "EnableIAMEntityDetails": true, + }, + expVars: map[string]string{ + "entity_name": "", + "entity_id": "", + "account_id": "", + "entity_path": "", + }, + expFilters: []string{ + `entity_name == ""`, + `entity_id == ""`, + `account_id == ""`, + `entity_path == ""`, + }, + }, + "entity tags": { + config: map[string]interface{}{ + "BoundIAMPrincipalARNs": principals, + "EnableIAMEntityDetails": true, + "IAMEntityTags": []string{ + "test_tag", + "test_tag_2", + }, + }, + expVars: map[string]string{ + "entity_name": "", + "entity_id": "", + "account_id": "", + "entity_path": "", + "entity_tags.test_tag": "", + "entity_tags.test_tag_2": "", + }, + expFilters: []string{ + `entity_name == ""`, + `entity_id == ""`, + `account_id == ""`, + `entity_path == ""`, + `entity_tags.test_tag == ""`, + `entity_tags.test_tag_2 == ""`, + }, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + method := &structs.ACLAuthMethod{ + Name: "test-method", + Type: "aws-iam", + Config: c.config, + } + nullLogger := hclog.NewNullLogger() + v, err := NewValidator(nullLogger, method) + require.NoError(t, err) + + id := v.NewIdentity() + authmethod.RequireIdentityMatch(t, id, c.expVars, c.expFilters...) + }) + } +} diff --git a/command/login/aws.go b/command/login/aws.go new file mode 100644 index 0000000000..bae90c9439 --- /dev/null +++ b/command/login/aws.go @@ -0,0 +1,148 @@ +package login + +import ( + "encoding/json" + "flag" + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + + "github.com/hashicorp/consul/agent/consul/authmethod/awsauth" + "github.com/hashicorp/consul/internal/iamauth" + "github.com/hashicorp/go-hclog" +) + +type AWSLogin struct { + autoBearerToken bool + includeEntity bool + stsEndpoint string + region string + serverIDHeaderValue string + accessKeyId string + secretAccessKey string + sessionToken string +} + +func (a *AWSLogin) flags() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.BoolVar(&a.autoBearerToken, "aws-auto-bearer-token", false, + "Construct a bearer token and login to the AWS IAM auth method. This requires AWS credentials. "+ + "AWS credentials are automatically discovered from standard sources supported by the Go SDK for "+ + "AWS. Alternatively, explicit credentials can be passed using the -aws-acesss-key-id and "+ + "-aws-secret-access-key flags. [aws-iam only]") + + fs.BoolVar(&a.includeEntity, "aws-include-entity", false, + "Include a signed request to get the IAM role or IAM user in the bearer token. [aws-iam only]") + + fs.StringVar(&a.stsEndpoint, "aws-sts-endpoint", "", + "URL for AWS STS API calls. [aws-iam only]") + + fs.StringVar(&a.region, "aws-region", "", + "Region for AWS API calls. If set, should match the region of -aws-sts-endpoint. "+ + "If not provided, the region will be discovered from standard sources, such as "+ + "the AWS_REGION environment variable. [aws-iam only]") + + fs.StringVar(&a.serverIDHeaderValue, "aws-server-id-header-value", "", + "If set, an X-Consul-IAM-ServerID header is included in signed AWS API request(s) that form "+ + "the bearer token. This value must match the server-side configured value for the auth method "+ + "in order to login. This is optional and helps protect against replay attacks. [aws-iam only]") + + fs.StringVar(&a.accessKeyId, "aws-access-key-id", "", + "AWS access key id to use. Requires -aws-secret-access-key if specified. [aws-iam only]") + + fs.StringVar(&a.secretAccessKey, "aws-secret-access-key", "", + "AWS secret access key to use. Requires -aws-access-key-id if specified. [aws-iam only]") + + fs.StringVar(&a.sessionToken, "aws-session-token", "", + "AWS session token to use. Requires -aws-access-key-id and -aws-secret-access-key if "+ + "specified. [aws-iam only]") + return fs +} + +// checkFlags validates flags for the aws-iam auth method. +func (a *AWSLogin) checkFlags() error { + if !a.autoBearerToken { + if a.includeEntity || a.stsEndpoint != "" || a.region != "" || a.serverIDHeaderValue != "" || + a.accessKeyId != "" || a.secretAccessKey != "" || a.sessionToken != "" { + return fmt.Errorf("Missing '-aws-auto-bearer-token' flag") + } + } + if a.accessKeyId != "" && a.secretAccessKey == "" { + return fmt.Errorf("Missing '-aws-secret-access-key' flag") + } + if a.secretAccessKey != "" && a.accessKeyId == "" { + return fmt.Errorf("Missing '-aws-access-key-id' flag") + } + if a.sessionToken != "" && (a.accessKeyId == "" || a.secretAccessKey == "") { + return fmt.Errorf("Missing '-aws-access-key-id' and '-aws-secret-access-key' flags") + } + return nil +} + +// createAWSBearerToken generates a bearer token string for the AWS IAM auth method. +// It will discover AWS credentials which are used to sign AWS API requests. +// Alternatively, static credentials can be passed as flags. +// +// The bearer token contains a signed sts:GetCallerIdentity request. +// If aws-include-entity is specified, a signed iam:GetRole or iam:GetUser request is +// also included. The AWS credentials are used to retrieve the current user's role +// or user name for the iam:GetRole or iam:GetUser request. +func (a *AWSLogin) createAWSBearerToken() (string, error) { + cfg := aws.Config{ + Endpoint: aws.String(a.stsEndpoint), + Region: aws.String(a.region), + // More detailed error message to help debug credential discovery. + CredentialsChainVerboseErrors: aws.Bool(true), + } + + if a.accessKeyId != "" { + // Use creds from flags. + cfg.Credentials = credentials.NewStaticCredentials( + a.accessKeyId, a.secretAccessKey, a.sessionToken, + ) + } + + // Session loads creds from standard sources (env vars, file, EC2 metadata, ...) + sess, err := session.NewSessionWithOptions(session.Options{ + Config: cfg, + // Allow loading from config files by default: + // ~/.aws/config or AWS_CONFIG_FILE + // ~/.aws/credentials or AWS_SHARED_CREDENTIALS_FILE + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return "", err + } + if sess.Config.Region == nil || *sess.Config.Region == "" { + return "", fmt.Errorf("AWS region not found") + } + if sess.Config.Credentials == nil { + return "", fmt.Errorf("AWS credentials not found") + } + creds := sess.Config.Credentials + + loginData, err := iamauth.GenerateLoginData(&iamauth.LoginInput{ + Creds: creds, + IncludeIAMEntity: a.includeEntity, + STSEndpoint: a.stsEndpoint, + STSRegion: a.region, + Logger: hclog.New(nil), + ServerIDHeaderValue: a.serverIDHeaderValue, + ServerIDHeaderName: awsauth.IAMServerIDHeaderName, + GetEntityMethodHeader: awsauth.GetEntityMethodHeader, + GetEntityURLHeader: awsauth.GetEntityURLHeader, + GetEntityHeadersHeader: awsauth.GetEntityHeadersHeader, + GetEntityBodyHeader: awsauth.GetEntityBodyHeader, + }) + if err != nil { + return "", err + } + + loginDataJson, err := json.Marshal(loginData) + if err != nil { + return "", err + } + return string(loginDataJson), err +} diff --git a/command/login/login.go b/command/login/login.go index ded0958f9c..a8f58556ac 100644 --- a/command/login/login.go +++ b/command/login/login.go @@ -36,6 +36,8 @@ type cmd struct { tokenSinkFile string meta map[string]string + aws AWSLogin + enterpriseCmd } @@ -57,10 +59,10 @@ func (c *cmd) init() { c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta", "Metadata to set on the token, formatted as key=value. This flag "+ "may be specified multiple times to set multiple meta fields.") - c.initEnterpriseFlags() c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.aws.flags()) flags.Merge(c.flags, c.http.ClientFlags()) flags.Merge(c.flags, c.http.ServerFlags()) flags.Merge(c.flags, c.http.MultiTenancyFlags()) @@ -89,21 +91,38 @@ func (c *cmd) Run(args []string) int { } func (c *cmd) bearerTokenLogin() int { - if c.bearerTokenFile == "" { - c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag")) - return 1 - } - - data, err := ioutil.ReadFile(c.bearerTokenFile) - if err != nil { + if err := c.aws.checkFlags(); err != nil { c.UI.Error(err.Error()) return 1 } - c.bearerToken = strings.TrimSpace(string(data)) - if c.bearerToken == "" { - c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile)) + if c.aws.autoBearerToken { + if c.bearerTokenFile != "" { + c.UI.Error("Cannot use '-bearer-token-file' flag with '-aws-auto-bearer-token'") + return 1 + } + + if token, err := c.aws.createAWSBearerToken(); err != nil { + c.UI.Error(fmt.Sprintf("Error with aws-iam auth method: %s", err)) + return 1 + } else { + c.bearerToken = token + } + } else if c.bearerTokenFile == "" { + c.UI.Error("Missing required '-bearer-token-file' flag") return 1 + } else { + data, err := ioutil.ReadFile(c.bearerTokenFile) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + c.bearerToken = strings.TrimSpace(string(data)) + + if c.bearerToken == "" { + c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile)) + return 1 + } } // Ensure that we don't try to use a token when performing a login diff --git a/command/login/login_test.go b/command/login/login_test.go index 3d730548d6..7eba6a4037 100644 --- a/command/login/login_test.go +++ b/command/login/login_test.go @@ -1,6 +1,7 @@ package login import ( + "fmt" "io/ioutil" "os" "path/filepath" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/acl" "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" + "github.com/hashicorp/consul/internal/iamauth/iamauthtest" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/testrpc" ) @@ -39,18 +41,7 @@ func TestLoginCommand(t *testing.T) { testDir := testutil.TempDir(t, "acl") - a := agent.NewTestAgent(t, ` - primary_datacenter = "dc1" - acl { - enabled = true - tokens { - initial_management = "root" - } - }`) - - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") - + a := newTestAgent(t) client := a.Client() t.Run("method is required", func(t *testing.T) { @@ -102,6 +93,81 @@ func TestLoginCommand(t *testing.T) { require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag") }) + t.Run("bearer-token-file disallowed with aws-auto-bearer-token", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", "none.txt", + "-aws-auto-bearer-token", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Cannot use '-bearer-token-file' flag with '-aws-auto-bearer-token'") + }) + + t.Run("aws flags require aws-auto-bearer-token", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + baseArgs := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + } + + for _, extraArgs := range [][]string{ + {"-aws-include-entity"}, + {"-aws-sts-endpoint", "some-endpoint"}, + {"-aws-region", "some-region"}, + {"-aws-server-id-header-value", "some-value"}, + {"-aws-access-key-id", "some-key"}, + {"-aws-secret-access-key", "some-secret"}, + {"-aws-session-token", "some-token"}, + } { + ui := cli.NewMockUi() + code := New(ui).Run(append(baseArgs, extraArgs...)) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-auto-bearer-token' flag") + } + }) + + t.Run("aws-access-key-id and aws-secret-access-key require each other", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + baseArgs := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-aws-auto-bearer-token", + } + + ui := cli.NewMockUi() + code := New(ui).Run(append(baseArgs, "-aws-access-key-id", "some-key")) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-secret-access-key' flag") + + ui = cli.NewMockUi() + code = New(ui).Run(append(baseArgs, "-aws-secret-access-key", "some-key")) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-access-key-id' flag") + + ui = cli.NewMockUi() + code = New(ui).Run(append(baseArgs, "-aws-session-token", "some-token")) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), + "Missing '-aws-access-key-id' and '-aws-secret-access-key' flags") + + }) + bearerTokenFile := filepath.Join(testDir, "bearer.token") t.Run("bearer-token-file is empty", func(t *testing.T) { @@ -236,18 +302,7 @@ func TestLoginCommand_k8s(t *testing.T) { testDir := testutil.TempDir(t, "acl") - a := agent.NewTestAgent(t, ` - primary_datacenter = "dc1" - acl { - enabled = true - tokens { - initial_management = "root" - } - }`) - - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") - + a := newTestAgent(t) client := a.Client() tokenSinkFile := filepath.Join(testDir, "test.token") @@ -334,18 +389,7 @@ func TestLoginCommand_jwt(t *testing.T) { testDir := testutil.TempDir(t, "acl") - a := agent.NewTestAgent(t, ` - primary_datacenter = "dc1" - acl { - enabled = true - tokens { - initial_management = "root" - } - }`) - - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") - + a := newTestAgent(t) client := a.Client() tokenSinkFile := filepath.Join(testDir, "test.token") @@ -470,3 +514,178 @@ func TestLoginCommand_jwt(t *testing.T) { }) } } + +func TestLoginCommand_aws_iam(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + + // Formats an HIL template for a BindName, and the expected value for entity tags. + // Input: string{"a", "b"}, []string{"1", "2"} + // Return: "${entity_tags.a}-${entity_tags.b}", "1-2" + entityTagsBind := func(keys, values []string) (string, string) { + parts := []string{} + for _, k := range keys { + parts = append(parts, fmt.Sprintf("${entity_tags.%s}", k)) + } + return strings.Join(parts, "-"), strings.Join(values, "-") + } + + f := iamauthtest.MakeFixture() + roleTagsBindName, roleTagsBindValue := entityTagsBind(f.RoleTagKeys(), f.RoleTagValues()) + userTagsBindName, userTagsBindValue := entityTagsBind(f.UserTagKeys(), f.UserTagValues()) + + cases := map[string]struct { + awsServer *iamauthtest.Server + cmdArgs []string + config map[string]interface{} + bindingRule *api.ACLBindingRule + expServiceIdentity *api.ACLServiceIdentity + }{ + "success - login with role": { + awsServer: f.ServerForRole, + cmdArgs: []string{"-aws-auto-bearer-token"}, + config: map[string]interface{}{ + // Test that an assumed-role arn is translated to the canonical role arn. + "BoundIAMPrincipalARNs": []string{f.CanonicalRoleARN}, + }, + bindingRule: &api.ACLBindingRule{ + BindType: api.BindingRuleBindTypeService, + BindName: "${entity_name}-${entity_id}-${account_id}", + Selector: fmt.Sprintf(`entity_name==%q and entity_id==%q and account_id==%q`, + f.RoleName, f.EntityID, f.AccountID), + }, + expServiceIdentity: &api.ACLServiceIdentity{ + ServiceName: fmt.Sprintf("%s-%s-%s", f.RoleName, strings.ToLower(f.EntityID), f.AccountID), + }, + }, + "success - login with role and entity details enabled": { + awsServer: f.ServerForRole, + cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"}, + config: map[string]interface{}{ + // Test that we can login with full user path. + "BoundIAMPrincipalARNs": []string{f.RoleARN}, + "EnableIAMEntityDetails": true, + }, + bindingRule: &api.ACLBindingRule{ + BindType: api.BindingRuleBindTypeService, + // TODO: Path cannot be used as service name if it contains a '/' + BindName: "${entity_name}", + Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.RoleName, f.RolePath), + }, + expServiceIdentity: &api.ACLServiceIdentity{ServiceName: f.RoleName}, + }, + "success - login with role and role tags": { + awsServer: f.ServerForRole, + cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"}, + config: map[string]interface{}{ + // Test that we can login with a wildcard. + "BoundIAMPrincipalARNs": []string{f.RoleARNWildcard}, + "EnableIAMEntityDetails": true, + "IAMEntityTags": f.RoleTagKeys(), + }, + bindingRule: &api.ACLBindingRule{ + BindType: api.BindingRuleBindTypeService, + BindName: roleTagsBindName, + Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.RoleName, f.RolePath), + }, + expServiceIdentity: &api.ACLServiceIdentity{ServiceName: roleTagsBindValue}, + }, + "success - login with user and user tags": { + awsServer: f.ServerForUser, + cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"}, + config: map[string]interface{}{ + // Test that we can login with a wildcard. + "BoundIAMPrincipalARNs": []string{f.UserARNWildcard}, + "EnableIAMEntityDetails": true, + "IAMEntityTags": f.UserTagKeys(), + }, + bindingRule: &api.ACLBindingRule{ + BindType: api.BindingRuleBindTypeService, + BindName: "${entity_name}-" + userTagsBindName, + Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.UserName, f.UserPath), + }, + expServiceIdentity: &api.ACLServiceIdentity{ + ServiceName: fmt.Sprintf("%s-%s", f.UserName, userTagsBindValue), + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + a := newTestAgent(t) + client := a.Client() + + fakeAws := iamauthtest.NewTestServer(t, c.awsServer) + + c.config["STSEndpoint"] = fakeAws.URL + "/sts" + c.config["IAMEndpoint"] = fakeAws.URL + "/iam" + + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "iam-test", + Type: "aws-iam", + Config: c.config, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + c.bindingRule.AuthMethod = "iam-test" + _, _, err = client.ACL().BindingRuleCreate( + c.bindingRule, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + testDir := testutil.TempDir(t, "acl") + tokenSinkFile := filepath.Join(testDir, "test.token") + t.Cleanup(func() { _ = os.Remove(tokenSinkFile) }) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=iam-test", + "-token-sink-file", tokenSinkFile, + "-aws-sts-endpoint", fakeAws.URL + "/sts", + "-aws-region", "fake-region", + "-aws-access-key-id", "fake-key-id", + "-aws-secret-access-key", "fake-secret-key", + } + args = append(args, c.cmdArgs...) + code := cmd.Run(args) + require.Equal(t, 0, code, ui.ErrorWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + + // Validate correct BindName was interpolated. + tokenRead, _, err := client.ACL().TokenReadSelf(&api.QueryOptions{Token: token}) + require.NoError(t, err) + require.Len(t, tokenRead.ServiceIdentities, 1) + require.Equal(t, c.expServiceIdentity, tokenRead.ServiceIdentities[0]) + + }) + } +} + +func newTestAgent(t *testing.T) *agent.TestAgent { + a := agent.NewTestAgent(t, ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + initial_management = "root" + } + }`) + t.Cleanup(func() { _ = a.Shutdown() }) + testrpc.WaitForLeader(t, a.RPC, "dc1") + return a +} diff --git a/go.mod b/go.mod index 456623d40d..e99a098ba2 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/hashicorp/go-memdb v1.3.2 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-raftchunking v0.6.2 - github.com/hashicorp/go-retryablehttp v0.6.7 // indirect + github.com/hashicorp/go-retryablehttp v0.6.7 github.com/hashicorp/go-sockaddr v1.0.2 github.com/hashicorp/go-syslog v1.0.0 github.com/hashicorp/go-uuid v1.0.2 diff --git a/internal/iamauth/README.md b/internal/iamauth/README.md new file mode 100644 index 0000000000..a9880a3559 --- /dev/null +++ b/internal/iamauth/README.md @@ -0,0 +1,2 @@ +This is an internal package to house the AWS IAM auth method utilities for potential +future extraction from Consul. diff --git a/internal/iamauth/auth.go b/internal/iamauth/auth.go new file mode 100644 index 0000000000..aaf6bc6579 --- /dev/null +++ b/internal/iamauth/auth.go @@ -0,0 +1,311 @@ +package iamauth + +import ( + "context" + "encoding/xml" + "fmt" + "io/ioutil" + "net/http" + "regexp" + "strings" + "time" + + "github.com/hashicorp/consul/internal/iamauth/responses" + "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/lib/stringslice" + "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" +) + +const ( + // Retry configuration + retryWaitMin = 500 * time.Millisecond + retryWaitMax = 30 * time.Second +) + +type Authenticator struct { + config *Config + logger hclog.Logger +} + +type IdentityDetails struct { + EntityName string + EntityId string + AccountId string + + EntityPath string + EntityTags map[string]string +} + +func NewAuthenticator(config *Config, logger hclog.Logger) (*Authenticator, error) { + if err := config.Validate(); err != nil { + return nil, err + } + return &Authenticator{ + config: config, + logger: logger, + }, nil +} + +// ValidateLogin determines if the identity in the loginToken is permitted to login. +// If so, it returns details about the identity. Otherwise, an error is returned. +func (a *Authenticator) ValidateLogin(ctx context.Context, loginToken string) (*IdentityDetails, error) { + token, err := NewBearerToken(loginToken, a.config) + if err != nil { + return nil, err + } + + req, err := token.GetCallerIdentityRequest() + if err != nil { + return nil, err + } + + if a.config.ServerIDHeaderValue != "" { + err := validateHeaderValue(req.Header, a.config.ServerIDHeaderName, a.config.ServerIDHeaderValue) + if err != nil { + return nil, err + } + } + + callerIdentity, err := a.submitCallerIdentityRequest(ctx, req) + if err != nil { + return nil, err + } + a.logger.Debug("iamauth login attempt", "arn", callerIdentity.Arn) + + entity, err := responses.ParseArn(callerIdentity.Arn) + if err != nil { + return nil, err + } + + identityDetails := &IdentityDetails{ + EntityName: entity.FriendlyName, + // This could either be a "userID:SessionID" (in the case of an assumed role) or just a "userID" + // (in the case of an IAM user). + EntityId: strings.Split(callerIdentity.UserId, ":")[0], + AccountId: callerIdentity.Account, + } + clientArn := entity.CanonicalArn() + + // Fetch the IAM Role or IAM User, if configured. + // This requires the token to contain a signed iam:GetRole or iam:GetUser request. + if a.config.EnableIAMEntityDetails { + iamReq, err := token.GetEntityRequest() + if err != nil { + return nil, err + } + + if a.config.ServerIDHeaderValue != "" { + err := validateHeaderValue(iamReq.Header, a.config.ServerIDHeaderName, a.config.ServerIDHeaderValue) + if err != nil { + return nil, err + } + } + + iamEntityDetails, err := a.submitGetIAMEntityRequest(ctx, iamReq, token.entityRequestType) + if err != nil { + return nil, err + } + + // Only the CallerIdentity response is a guarantee of the client's identity. + // The role/user details must have a unique id match to the CallerIdentity before use. + if iamEntityDetails.EntityId() != identityDetails.EntityId { + return nil, fmt.Errorf("unique id mismatch in login token") + } + + // Use the full ARN with path from the Role/User details + clientArn = iamEntityDetails.EntityArn() + identityDetails.EntityPath = iamEntityDetails.EntityPath() + identityDetails.EntityTags = iamEntityDetails.EntityTags() + } + + if err := a.validateIdentity(clientArn); err != nil { + return nil, err + } + return identityDetails, nil +} + +// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1321-L1361 +func (a *Authenticator) validateIdentity(clientArn string) error { + if stringslice.Contains(a.config.BoundIAMPrincipalARNs, clientArn) { + // Matches one of BoundIAMPrincipalARNs, so it is trusted + return nil + } + if a.config.EnableIAMEntityDetails { + for _, principalArn := range a.config.BoundIAMPrincipalARNs { + if strings.HasSuffix(principalArn, "*") && lib.GlobbedStringsMatch(principalArn, clientArn) { + // Wildcard match, so it is trusted + return nil + } + } + } + return fmt.Errorf("IAM principal %s is not trusted", clientArn) +} + +func (a *Authenticator) submitCallerIdentityRequest(ctx context.Context, req *http.Request) (*responses.GetCallerIdentityResult, error) { + responseBody, err := a.submitRequest(ctx, req) + if err != nil { + return nil, err + } + callerIdentityResponse, err := parseGetCallerIdentityResponse(responseBody) + if err != nil { + return nil, fmt.Errorf("error parsing STS response") + } + + if n := len(callerIdentityResponse.GetCallerIdentityResult); n != 1 { + return nil, fmt.Errorf("received %d identities in STS response but expected 1", n) + } + return &callerIdentityResponse.GetCallerIdentityResult[0], nil +} + +func (a *Authenticator) submitGetIAMEntityRequest(ctx context.Context, req *http.Request, reqType string) (responses.IAMEntity, error) { + responseBody, err := a.submitRequest(ctx, req) + if err != nil { + return nil, err + } + iamResponse, err := parseGetIAMEntityResponse(responseBody, reqType) + if err != nil { + return nil, fmt.Errorf("error parsing IAM response: %s", err) + } + return iamResponse, nil + +} + +// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1636 +func (a *Authenticator) submitRequest(ctx context.Context, req *http.Request) (string, error) { + retryableReq, err := retryablehttp.FromRequest(req) + if err != nil { + return "", err + } + retryableReq = retryableReq.WithContext(ctx) + client := cleanhttp.DefaultClient() + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + retryingClient := &retryablehttp.Client{ + HTTPClient: client, + RetryWaitMin: retryWaitMin, + RetryWaitMax: retryWaitMax, + RetryMax: a.config.MaxRetries, + CheckRetry: retryablehttp.DefaultRetryPolicy, + Backoff: retryablehttp.DefaultBackoff, + } + + response, err := retryingClient.Do(retryableReq) + if err != nil { + return "", fmt.Errorf("error making request: %w", err) + } + if response != nil { + defer response.Body.Close() + } + // Validate that the response type is XML + if ct := response.Header.Get("Content-Type"); ct != "text/xml" { + return "", fmt.Errorf("response body is invalid") + } + + // we check for status code afterwards to also print out response body + responseBody, err := ioutil.ReadAll(response.Body) + if err != nil { + return "", err + } + if response.StatusCode != 200 { + return "", fmt.Errorf("received error code %d: %s", response.StatusCode, string(responseBody)) + } + return string(responseBody), nil + +} + +// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1625-L1634 +func parseGetCallerIdentityResponse(response string) (responses.GetCallerIdentityResponse, error) { + result := responses.GetCallerIdentityResponse{} + response = strings.TrimSpace(response) + if !strings.HasPrefix(response, " 2 { + return fmt.Errorf("found multiple SignedHeaders components") + } + signedHeaders := string(matches[1]) + return ensureHeaderIsSigned(signedHeaders, headerName) + } + // NOTE: If we support GET requests, then we need to parse the X-Amz-SignedHeaders + // argument out of the query string and search in there for the header value + return fmt.Errorf("missing Authorization header") +} + +func ensureHeaderIsSigned(signedHeaders, headerToSign string) error { + // Not doing a constant time compare here, the values aren't secret + for _, header := range strings.Split(signedHeaders, ";") { + if header == strings.ToLower(headerToSign) { + return nil + } + } + return fmt.Errorf("header wasn't signed") +} diff --git a/internal/iamauth/auth_test.go b/internal/iamauth/auth_test.go new file mode 100644 index 0000000000..736c3203a0 --- /dev/null +++ b/internal/iamauth/auth_test.go @@ -0,0 +1,123 @@ +package iamauth + +import ( + "context" + "encoding/json" + "testing" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/hashicorp/consul/internal/iamauth/iamauthtest" + "github.com/hashicorp/consul/internal/iamauth/responsestest" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +func TestValidateLogin(t *testing.T) { + f := iamauthtest.MakeFixture() + + var ( + serverForRoleMismatchedIds = &iamauthtest.Server{ + GetCallerIdentityResponse: f.ServerForRole.GetCallerIdentityResponse, + GetRoleResponse: responsestest.MakeGetRoleResponse(f.RoleARN, "AAAAsomenonmatchingid"), + } + serverForUserMismatchedIds = &iamauthtest.Server{ + GetCallerIdentityResponse: f.ServerForUser.GetCallerIdentityResponse, + GetUserResponse: responsestest.MakeGetUserResponse(f.UserARN, "AAAAsomenonmatchingid"), + } + ) + + cases := map[string]struct { + config *Config + server *iamauthtest.Server + expIdent *IdentityDetails + expError string + }{ + "no bound principals": { + expError: "not trusted", + server: f.ServerForRole, + config: &Config{}, + }, + "no matching principal": { + expError: "not trusted", + server: f.ServerForUser, + config: &Config{ + BoundIAMPrincipalARNs: []string{ + "arn:aws:iam::1234567890:user/some-other-role", + "arn:aws:iam::1234567890:user/some-other-user", + }, + }, + }, + "mismatched server id header": { + expError: `expected "some-non-matching-value" but got "server.id.example.com"`, + server: f.ServerForRole, + config: &Config{ + BoundIAMPrincipalARNs: []string{f.CanonicalRoleARN}, + ServerIDHeaderValue: "some-non-matching-value", + ServerIDHeaderName: "X-Test-ServerID", + }, + }, + "role unique id mismatch": { + expError: "unique id mismatch in login token", + // The RoleId in the GetRole response must match the UserId in the GetCallerIdentity response + // during login. If not, the RoleId cannot be used. + server: serverForRoleMismatchedIds, + config: &Config{ + BoundIAMPrincipalARNs: []string{f.RoleARN}, + EnableIAMEntityDetails: true, + }, + }, + "user unique id mismatch": { + expError: "unique id mismatch in login token", + server: serverForUserMismatchedIds, + config: &Config{ + BoundIAMPrincipalARNs: []string{f.UserARN}, + EnableIAMEntityDetails: true, + }, + }, + } + logger := hclog.New(nil) + for name, c := range cases { + t.Run(name, func(t *testing.T) { + fakeAws := iamauthtest.NewTestServer(t, c.server) + + c.config.STSEndpoint = fakeAws.URL + "/sts" + c.config.IAMEndpoint = fakeAws.URL + "/iam" + setTestHeaderNames(c.config) + + // This bypasses NewAuthenticator, which bypasses config.Validate(). + auth := &Authenticator{config: c.config, logger: logger} + + loginInput := &LoginInput{ + Creds: credentials.NewStaticCredentials("fake", "fake", ""), + IncludeIAMEntity: c.config.EnableIAMEntityDetails, + STSEndpoint: c.config.STSEndpoint, + STSRegion: "fake-region", + Logger: logger, + ServerIDHeaderValue: "server.id.example.com", + } + setLoginInputHeaderNames(loginInput) + loginData, err := GenerateLoginData(loginInput) + require.NoError(t, err) + loginBytes, err := json.Marshal(loginData) + require.NoError(t, err) + + ident, err := auth.ValidateLogin(context.Background(), string(loginBytes)) + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + require.Nil(t, ident) + } else { + require.NoError(t, err) + require.Equal(t, c.expIdent, ident) + } + }) + } +} + +func setLoginInputHeaderNames(in *LoginInput) { + in.ServerIDHeaderName = "X-Test-ServerID" + in.GetEntityMethodHeader = "X-Test-Method" + in.GetEntityURLHeader = "X-Test-URL" + in.GetEntityHeadersHeader = "X-Test-Headers" + in.GetEntityBodyHeader = "X-Test-Body" +} diff --git a/internal/iamauth/config.go b/internal/iamauth/config.go new file mode 100644 index 0000000000..a8a6b61d51 --- /dev/null +++ b/internal/iamauth/config.go @@ -0,0 +1,69 @@ +package iamauth + +import ( + "fmt" + "strings" + + awsArn "github.com/aws/aws-sdk-go/aws/arn" +) + +type Config struct { + BoundIAMPrincipalARNs []string + EnableIAMEntityDetails bool + IAMEntityTags []string + ServerIDHeaderValue string + MaxRetries int + IAMEndpoint string + STSEndpoint string + STSRegion string + AllowedSTSHeaderValues []string + + // Customizable header names + ServerIDHeaderName string + GetEntityMethodHeader string + GetEntityURLHeader string + GetEntityHeadersHeader string + GetEntityBodyHeader string +} + +func (c *Config) Validate() error { + if len(c.BoundIAMPrincipalARNs) == 0 { + return fmt.Errorf("BoundIAMPrincipalARNs is required and must have at least 1 entry") + } + + for _, arn := range c.BoundIAMPrincipalARNs { + if n := strings.Count(arn, "*"); n > 0 { + if !c.EnableIAMEntityDetails { + return fmt.Errorf("Must set EnableIAMEntityDetails=true to use wildcards in BoundIAMPrincipalARNs") + } + if n != 1 || !strings.HasSuffix(arn, "*") { + return fmt.Errorf("Only one wildcard is allowed at the end of the bound IAM principal ARN") + } + } + + if parsed, err := awsArn.Parse(arn); err != nil { + return fmt.Errorf("Invalid principal ARN: %q", arn) + } else if parsed.Service != "iam" && parsed.Service != "sts" { + return fmt.Errorf("Invalid principal ARN: %q", arn) + } + } + + if len(c.IAMEntityTags) > 0 && !c.EnableIAMEntityDetails { + return fmt.Errorf("Must set EnableIAMEntityDetails=true to use IAMUserTags") + } + + // If server id header checking is enabled, we need the header name. + if c.ServerIDHeaderValue != "" && c.ServerIDHeaderName == "" { + return fmt.Errorf("Must set ServerIDHeaderName to use a server ID value") + } + + if c.EnableIAMEntityDetails && (c.GetEntityBodyHeader == "" || + c.GetEntityHeadersHeader == "" || + c.GetEntityMethodHeader == "" || + c.GetEntityURLHeader == "") { + return fmt.Errorf("Must set all of GetEntityMethodHeader, GetEntityURLHeader, " + + "GetEntityHeadersHeader, and GetEntityBodyHeader when EnableIAMEntityDetails=true") + } + + return nil +} diff --git a/internal/iamauth/config_test.go b/internal/iamauth/config_test.go new file mode 100644 index 0000000000..d23dc992ae --- /dev/null +++ b/internal/iamauth/config_test.go @@ -0,0 +1,150 @@ +package iamauth + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfigValidate(t *testing.T) { + principalArn := "arn:aws:iam::000000000000:role/my-role" + + cases := map[string]struct { + expError string + configs []Config + + includeHeaderNames bool + }{ + "bound iam principals are required": { + expError: "BoundIAMPrincipalARNs is required and must have at least 1 entry", + configs: []Config{ + {BoundIAMPrincipalARNs: nil}, + {BoundIAMPrincipalARNs: []string{}}, + }, + }, + "entity tags require entity details": { + expError: "Must set EnableIAMEntityDetails=true to use IAMUserTags", + configs: []Config{ + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: false, + IAMEntityTags: []string{"some-tag"}, + }, + }, + }, + "entity details require all entity header names": { + expError: "Must set all of GetEntityMethodHeader, GetEntityURLHeader, " + + "GetEntityHeadersHeader, and GetEntityBodyHeader when EnableIAMEntityDetails=true", + configs: []Config{ + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: true, + }, + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: true, + GetEntityBodyHeader: "X-Test-Header", + }, + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: true, + GetEntityHeadersHeader: "X-Test-Header", + }, + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: true, + GetEntityURLHeader: "X-Test-Header", + }, + { + BoundIAMPrincipalARNs: []string{principalArn}, + EnableIAMEntityDetails: true, + GetEntityMethodHeader: "X-Test-Header", + }, + }, + }, + "wildcard principals require entity details": { + expError: "Must set EnableIAMEntityDetails=true to use wildcards in BoundIAMPrincipalARNs", + configs: []Config{ + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*"}}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/path/*"}}, + }, + }, + "only one wildcard suffix is allowed": { + expError: "Only one wildcard is allowed at the end of the bound IAM principal ARN", + configs: []Config{ + { + BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/**"}, + EnableIAMEntityDetails: true, + }, + { + BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/*"}, + EnableIAMEntityDetails: true, + }, + { + BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/path"}, + EnableIAMEntityDetails: true, + }, + { + BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/path/*"}, + EnableIAMEntityDetails: true, + }, + }, + }, + "invalid principal arns are disallowed": { + expError: fmt.Sprintf("Invalid principal ARN"), + configs: []Config{ + {BoundIAMPrincipalARNs: []string{""}}, + {BoundIAMPrincipalARNs: []string{" "}}, + {BoundIAMPrincipalARNs: []string{"*"}, EnableIAMEntityDetails: true}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam:role/my-role"}}, + }, + }, + "valid principal arns are allowed": { + includeHeaderNames: true, + configs: []Config{ + {BoundIAMPrincipalARNs: []string{"arn:aws:sts::000000000000:assumed-role/my-role/some-session-name"}}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/my-user"}}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/my-role"}}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:*"}, EnableIAMEntityDetails: true}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*"}, EnableIAMEntityDetails: true}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/path/*"}, EnableIAMEntityDetails: true}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/*"}, EnableIAMEntityDetails: true}, + {BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/path/*"}, EnableIAMEntityDetails: true}, + }, + }, + "server id header value requires service id header name": { + expError: "Must set ServerIDHeaderName to use a server ID value", + configs: []Config{ + { + BoundIAMPrincipalARNs: []string{principalArn}, + ServerIDHeaderValue: "consul.test.example.com", + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + for _, conf := range c.configs { + if c.includeHeaderNames { + setTestHeaderNames(&conf) + } + err := conf.Validate() + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + } else { + require.NoError(t, err) + } + } + }) + } +} + +func setTestHeaderNames(conf *Config) { + conf.GetEntityMethodHeader = "X-Test-Method" + conf.GetEntityURLHeader = "X-Test-URL" + conf.GetEntityHeadersHeader = "X-Test-Headers" + conf.GetEntityBodyHeader = "X-Test-Body" +} diff --git a/internal/iamauth/iamauthtest/testing.go b/internal/iamauth/iamauthtest/testing.go new file mode 100644 index 0000000000..4cb8519a92 --- /dev/null +++ b/internal/iamauth/iamauthtest/testing.go @@ -0,0 +1,187 @@ +package iamauthtest + +import ( + "encoding/xml" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + + "github.com/hashicorp/consul/internal/iamauth/responses" + "github.com/hashicorp/consul/internal/iamauth/responsestest" +) + +// NewTestServer returns a fake AWS API server for local tests: +// It supports the following paths: +// /sts returns STS API responses +// /iam returns IAM API responses +func NewTestServer(t *testing.T, s *Server) *httptest.Server { + server := httptest.NewUnstartedServer(s) + t.Cleanup(server.Close) + server.Start() + return server +} + +// Server contains configuration for the fake AWS API server. +type Server struct { + GetCallerIdentityResponse responses.GetCallerIdentityResponse + GetRoleResponse responses.GetRoleResponse + GetUserResponse responses.GetUserResponse +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + writeError(w, http.StatusBadRequest, r) + return + } + + switch { + case strings.HasPrefix(r.URL.Path, "/sts"): + writeXML(w, s.GetCallerIdentityResponse) + case strings.HasPrefix(r.URL.Path, "/iam"): + if bodyBytes, err := io.ReadAll(r.Body); err == nil { + body := string(bodyBytes) + switch { + case strings.Contains(body, "Action=GetRole"): + writeXML(w, s.GetRoleResponse) + return + case strings.Contains(body, "Action=GetUser"): + writeXML(w, s.GetUserResponse) + return + } + } + writeError(w, http.StatusBadRequest, r) + default: + writeError(w, http.StatusNotFound, r) + } +} + +func writeXML(w http.ResponseWriter, val interface{}) { + str, err := xml.MarshalIndent(val, "", " ") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, err.Error()) + return + } + w.Header().Add("Content-Type", "text/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, string(str)) +} + +func writeError(w http.ResponseWriter, code int, r *http.Request) { + w.WriteHeader(code) + msg := fmt.Sprintf("%s %s", r.Method, r.URL) + fmt.Fprintf(w, ` + + Fake AWS Server Error: %s + +`, msg) +} + +type Fixture struct { + AssumedRoleARN string + CanonicalRoleARN string + RoleARN string + RoleARNWildcard string + RoleName string + RolePath string + RoleTags map[string]string + + EntityID string + EntityIDWithSession string + AccountID string + + UserARN string + UserARNWildcard string + UserName string + UserPath string + UserTags map[string]string + + ServerForRole *Server + ServerForUser *Server +} + +func MakeFixture() Fixture { + f := Fixture{ + AssumedRoleARN: "arn:aws:sts::1234567890:assumed-role/my-role/some-session", + CanonicalRoleARN: "arn:aws:iam::1234567890:role/my-role", + RoleARN: "arn:aws:iam::1234567890:role/some/path/my-role", + RoleARNWildcard: "arn:aws:iam::1234567890:role/some/path/*", + RoleName: "my-role", + RolePath: "some/path", + RoleTags: map[string]string{ + "service-name": "my-service", + "env": "my-env", + }, + + EntityID: "AAAsomeuniqueid", + EntityIDWithSession: "AAAsomeuniqueid:some-session", + AccountID: "1234567890", + + UserARN: "arn:aws:iam::1234567890:user/my-user", + UserARNWildcard: "arn:aws:iam::1234567890:user/*", + UserName: "my-user", + UserPath: "", + UserTags: map[string]string{"user-group": "my-group"}, + } + + f.ServerForRole = &Server{ + GetCallerIdentityResponse: responsestest.MakeGetCallerIdentityResponse( + f.AssumedRoleARN, f.EntityIDWithSession, f.AccountID, + ), + GetRoleResponse: responsestest.MakeGetRoleResponse( + f.RoleARN, f.EntityID, toTags(f.RoleTags)..., + ), + } + + f.ServerForUser = &Server{ + GetCallerIdentityResponse: responsestest.MakeGetCallerIdentityResponse( + f.UserARN, f.EntityID, f.AccountID, + ), + GetUserResponse: responsestest.MakeGetUserResponse( + f.UserARN, f.EntityID, toTags(f.UserTags)..., + ), + } + + return f +} + +func (f *Fixture) RoleTagKeys() []string { return keys(f.RoleTags) } +func (f *Fixture) UserTagKeys() []string { return keys(f.UserTags) } +func (f *Fixture) RoleTagValues() []string { return values(f.RoleTags) } +func (f *Fixture) UserTagValues() []string { return values(f.UserTags) } + +// toTags converts the map to a slice of responses.Tag +func toTags(tags map[string]string) []responses.Tag { + result := []responses.Tag{} + for k, v := range tags { + result = append(result, responses.Tag{ + Key: k, + Value: v, + }) + } + return result + +} + +// keys returns the keys in sorted order +func keys(tags map[string]string) []string { + result := []string{} + for k := range tags { + result = append(result, k) + } + sort.Strings(result) + return result +} + +// values returns values in tags, ordered by sorted keys +func values(tags map[string]string) []string { + result := []string{} + for _, k := range keys(tags) { // ensures sorted by key + result = append(result, tags[k]) + } + return result +} diff --git a/internal/iamauth/responses/arn.go b/internal/iamauth/responses/arn.go new file mode 100644 index 0000000000..ea5e541d30 --- /dev/null +++ b/internal/iamauth/responses/arn.go @@ -0,0 +1,94 @@ +package responses + +import ( + "fmt" + "strings" +) + +// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1722-L1744 +type ParsedArn struct { + Partition string + AccountNumber string + Type string + Path string + FriendlyName string + SessionInfo string +} + +// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1482-L1530 +// However, instance profiles are not support in Consul. +func ParseArn(iamArn string) (*ParsedArn, error) { + // iamArn should look like one of the following: + // 1. arn:aws:iam:::/ + // 2. arn:aws:sts:::assumed-role// + // if we get something like 2, then we want to transform that back to what + // most people would expect, which is arn:aws:iam:::role/ + var entity ParsedArn + fullParts := strings.Split(iamArn, ":") + if len(fullParts) != 6 { + return nil, fmt.Errorf("unrecognized arn: contains %d colon-separated parts, expected 6", len(fullParts)) + } + if fullParts[0] != "arn" { + return nil, fmt.Errorf("unrecognized arn: does not begin with \"arn:\"") + } + // normally aws, but could be aws-cn or aws-us-gov + entity.Partition = fullParts[1] + if entity.Partition == "" { + return nil, fmt.Errorf("unrecognized arn: %q is missing the partition", iamArn) + } + if fullParts[2] != "iam" && fullParts[2] != "sts" { + return nil, fmt.Errorf("unrecognized service: %v, not one of iam or sts", fullParts[2]) + } + // fullParts[3] is the region, which doesn't matter for AWS IAM entities + entity.AccountNumber = fullParts[4] + if entity.AccountNumber == "" { + return nil, fmt.Errorf("unrecognized arn: %q is missing the account number", iamArn) + } + // fullParts[5] would now be something like user/ or assumed-role// + parts := strings.Split(fullParts[5], "/") + if len(parts) < 2 { + return nil, fmt.Errorf("unrecognized arn: %q contains fewer than 2 slash-separated parts", fullParts[5]) + } + entity.Type = parts[0] + entity.Path = strings.Join(parts[1:len(parts)-1], "/") + entity.FriendlyName = parts[len(parts)-1] + // now, entity.FriendlyName should either be or + switch entity.Type { + case "assumed-role": + // Check for three parts for assumed role ARNs + if len(parts) < 3 { + return nil, fmt.Errorf("unrecognized arn: %q contains fewer than 3 slash-separated parts", fullParts[5]) + } + // Assumed roles don't have paths and have a slightly different format + // parts[2] is + entity.Path = "" + entity.FriendlyName = parts[1] + entity.SessionInfo = parts[2] + case "user": + case "role": + // case "instance-profile": + default: + return nil, fmt.Errorf("unrecognized principal type: %q", entity.Type) + } + + if entity.FriendlyName == "" { + return nil, fmt.Errorf("unrecognized arn: %q is missing the resource name", iamArn) + } + + return &entity, nil +} + +// CanonicalArn returns the canonical ARN for referring to an IAM entity +func (p *ParsedArn) CanonicalArn() string { + entityType := p.Type + // canonicalize "assumed-role" into "role" + if entityType == "assumed-role" { + entityType = "role" + } + // Annoyingly, the assumed-role entity type doesn't have the Path of the role which was assumed + // So, we "canonicalize" it by just completely dropping the path. The other option would be to + // make an AWS API call to look up the role by FriendlyName, which introduces more complexity to + // code and test, and it also breaks backwards compatibility in an area where we would really want + // it + return fmt.Sprintf("arn:%s:iam::%s:%s/%s", p.Partition, p.AccountNumber, entityType, p.FriendlyName) +} diff --git a/internal/iamauth/responses/responses.go b/internal/iamauth/responses/responses.go new file mode 100644 index 0000000000..e050b77342 --- /dev/null +++ b/internal/iamauth/responses/responses.go @@ -0,0 +1,92 @@ +package responses + +import "encoding/xml" + +type GetCallerIdentityResponse struct { + XMLName xml.Name `xml:"GetCallerIdentityResponse"` + GetCallerIdentityResult []GetCallerIdentityResult `xml:"GetCallerIdentityResult"` + ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"` +} + +type GetCallerIdentityResult struct { + Arn string `xml:"Arn"` + UserId string `xml:"UserId"` + Account string `xml:"Account"` +} + +type ResponseMetadata struct { + RequestId string `xml:"RequestId"` +} + +// IAMEntity is an interface for getting details from an IAM Role or User. +type IAMEntity interface { + EntityPath() string + EntityArn() string + EntityName() string + EntityId() string + EntityTags() map[string]string +} + +var _ IAMEntity = (*Role)(nil) +var _ IAMEntity = (*User)(nil) + +type GetRoleResponse struct { + XMLName xml.Name `xml:"GetRoleResponse"` + GetRoleResult []GetRoleResult `xml:"GetRoleResult"` + ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"` +} + +type GetRoleResult struct { + Role Role `xml:"Role"` +} + +type Role struct { + Arn string `xml:"Arn"` + Path string `xml:"Path"` + RoleId string `xml:"RoleId"` + RoleName string `xml:"RoleName"` + Tags []Tag `xml:"Tags"` +} + +func (r *Role) EntityPath() string { return r.Path } +func (r *Role) EntityArn() string { return r.Arn } +func (r *Role) EntityName() string { return r.RoleName } +func (r *Role) EntityId() string { return r.RoleId } +func (r *Role) EntityTags() map[string]string { return tagsToMap(r.Tags) } + +type GetUserResponse struct { + XMLName xml.Name `xml:"GetUserResponse"` + GetUserResult []GetUserResult `xml:"GetUserResult"` + ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"` +} + +type GetUserResult struct { + User User `xml:"User"` +} + +type User struct { + Arn string `xml:"Arn"` + Path string `xml:"Path"` + UserId string `xml:"UserId"` + UserName string `xml:"UserName"` + Tags []Tag `xml:"Tags"` +} + +func (u *User) EntityPath() string { return u.Path } +func (u *User) EntityArn() string { return u.Arn } +func (u *User) EntityName() string { return u.UserName } +func (u *User) EntityId() string { return u.UserId } +func (u *User) EntityTags() map[string]string { return tagsToMap(u.Tags) } + +type Tag struct { + Key string `xml:"Key"` + Value string `xml:"Value"` +} + +func tagsToMap(tags []Tag) map[string]string { + result := map[string]string{} + for _, tag := range tags { + result[tag.Key] = tag.Value + } + return result +} diff --git a/internal/iamauth/responses/responses_test.go b/internal/iamauth/responses/responses_test.go new file mode 100644 index 0000000000..df4a9c1e33 --- /dev/null +++ b/internal/iamauth/responses/responses_test.go @@ -0,0 +1,157 @@ +package responses + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseArn(t *testing.T) { + cases := map[string]struct { + arn string + expArn *ParsedArn + }{ + "assumed-role": { + arn: "arn:aws:sts::000000000000:assumed-role/my-role/session-name", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "assumed-role", + Path: "", + FriendlyName: "my-role", + SessionInfo: "session-name", + }, + }, + "role": { + arn: "arn:aws:iam::000000000000:role/my-role", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "role", + Path: "", + FriendlyName: "my-role", + SessionInfo: "", + }, + }, + "user": { + arn: "arn:aws:iam::000000000000:user/my-user", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "user", + Path: "", + FriendlyName: "my-user", + SessionInfo: "", + }, + }, + "role with path": { + arn: "arn:aws:iam::000000000000:role/path/my-role", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "role", + Path: "path", + FriendlyName: "my-role", + SessionInfo: "", + }, + }, + "role with path 2": { + arn: "arn:aws:iam::000000000000:role/path/to/my-role", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "role", + Path: "path/to", + FriendlyName: "my-role", + SessionInfo: "", + }, + }, + "role with path 3": { + arn: "arn:aws:iam::000000000000:role/some/path/to/my-role", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "role", + Path: "some/path/to", + FriendlyName: "my-role", + SessionInfo: "", + }, + }, + "user with path": { + arn: "arn:aws:iam::000000000000:user/path/my-user", + expArn: &ParsedArn{ + Partition: "aws", + AccountNumber: "000000000000", + Type: "user", + Path: "path", + FriendlyName: "my-user", + SessionInfo: "", + }, + }, + + // Invalid cases + "empty string": {arn: ""}, + "wildcard": {arn: "*"}, + "missing prefix": {arn: ":aws:sts::000000000000:assumed-role/my-role/session-name"}, + "missing partition": {arn: "arn::sts::000000000000:assumed-role/my-role/session-name"}, + "missing service": {arn: "arn:aws:::000000000000:assumed-role/my-role/session-name"}, + "missing separator": {arn: "arn:aws:sts:000000000000:assumed-role/my-role/session-name"}, + "missing account id": {arn: "arn:aws:sts:::assumed-role/my-role/session-name"}, + "missing resource": {arn: "arn:aws:sts::000000000000:"}, + "assumed-role missing parts": {arn: "arn:aws:sts::000000000000:assumed-role/my-role"}, + "role missing parts": {arn: "arn:aws:sts::000000000000:role"}, + "role missing parts 2": {arn: "arn:aws:sts::000000000000:role/"}, + "user missing parts": {arn: "arn:aws:sts::000000000000:user"}, + "user missing parts 2": {arn: "arn:aws:sts::000000000000:user/"}, + "unsupported service": {arn: "arn:aws:ecs:us-east-1:000000000000:task/my-task/00000000000000000000000000000000"}, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + parsed, err := ParseArn(c.arn) + if c.expArn != nil { + require.NoError(t, err) + require.Equal(t, c.expArn, parsed) + } else { + require.Error(t, err) + require.Nil(t, parsed) + } + }) + } +} + +func TestCanonicalArn(t *testing.T) { + cases := map[string]struct { + arn string + expArn string + }{ + "assumed-role arn": { + arn: "arn:aws:sts::000000000000:assumed-role/my-role/session-name", + expArn: "arn:aws:iam::000000000000:role/my-role", + }, + "role arn": { + arn: "arn:aws:iam::000000000000:role/my-role", + expArn: "arn:aws:iam::000000000000:role/my-role", + }, + "role arn with path": { + arn: "arn:aws:iam::000000000000:role/path/to/my-role", + expArn: "arn:aws:iam::000000000000:role/my-role", + }, + "user arn": { + arn: "arn:aws:iam::000000000000:user/my-user", + expArn: "arn:aws:iam::000000000000:user/my-user", + }, + "user arn with path": { + arn: "arn:aws:iam::000000000000:user/path/to/my-user", + expArn: "arn:aws:iam::000000000000:user/my-user", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + parsed, err := ParseArn(c.arn) + require.NoError(t, err) + require.Equal(t, c.expArn, parsed.CanonicalArn()) + }) + } +} diff --git a/internal/iamauth/responsestest/testing.go b/internal/iamauth/responsestest/testing.go new file mode 100644 index 0000000000..683308677b --- /dev/null +++ b/internal/iamauth/responsestest/testing.go @@ -0,0 +1,81 @@ +package responsestest + +import ( + "strings" + + "github.com/hashicorp/consul/internal/iamauth/responses" +) + +func MakeGetCallerIdentityResponse(arn, userId, accountId string) responses.GetCallerIdentityResponse { + // Sanity check the UserId for unit tests. + parsed := parseArn(arn) + switch parsed.Type { + case "assumed-role": + if !strings.Contains(userId, ":") { + panic("UserId for assumed-role in GetCallerIdentity response must be ':'") + } + default: + if strings.Contains(userId, ":") { + panic("UserId in GetCallerIdentity must not contain ':'") + } + } + + return responses.GetCallerIdentityResponse{ + GetCallerIdentityResult: []responses.GetCallerIdentityResult{ + { + Arn: arn, + UserId: userId, + Account: accountId, + }, + }, + } +} + +func MakeGetRoleResponse(arn, id string, tags ...responses.Tag) responses.GetRoleResponse { + if strings.Contains(id, ":") { + panic("RoleId in GetRole response must not contain ':'") + } + parsed := parseArn(arn) + return responses.GetRoleResponse{ + GetRoleResult: []responses.GetRoleResult{ + { + Role: responses.Role{ + Arn: arn, + Path: parsed.Path, + RoleId: id, + RoleName: parsed.FriendlyName, + Tags: tags, + }, + }, + }, + } +} + +func MakeGetUserResponse(arn, id string, tags ...responses.Tag) responses.GetUserResponse { + if strings.Contains(id, ":") { + panic("UserId in GetUser resposne must not contain ':'") + } + parsed := parseArn(arn) + return responses.GetUserResponse{ + GetUserResult: []responses.GetUserResult{ + { + User: responses.User{ + Arn: arn, + Path: parsed.Path, + UserId: id, + UserName: parsed.FriendlyName, + Tags: tags, + }, + }, + }, + } +} + +func parseArn(arn string) *responses.ParsedArn { + parsed, err := responses.ParseArn(arn) + if err != nil { + // For testing, just fail immediately. + panic(err) + } + return parsed +} diff --git a/internal/iamauth/token.go b/internal/iamauth/token.go new file mode 100644 index 0000000000..91994b5103 --- /dev/null +++ b/internal/iamauth/token.go @@ -0,0 +1,343 @@ +package iamauth + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/textproto" + "net/url" + "strings" + + "github.com/hashicorp/consul/lib/stringslice" +) + +const ( + amzHeaderPrefix = "X-Amz-" + defaultIAMEndpoint = "https://iam.amazonaws.com" + defaultSTSEndpoint = "https://sts.amazonaws.com" +) + +var defaultAllowedSTSRequestHeaders = []string{ + "X-Amz-Algorithm", + "X-Amz-Content-Sha256", + "X-Amz-Credential", + "X-Amz-Date", + "X-Amz-Security-Token", + "X-Amz-Signature", + "X-Amz-SignedHeaders", +} + +// BearerToken is a login "token" for an IAM auth method. It is a signed +// sts:GetCallerIdentity request in JSON format. Optionally, it can include a +// signed embedded iam:GetRole or iam:GetUser request in the headers. +type BearerToken struct { + config *Config + + getCallerIdentityMethod string + getCallerIdentityURL string + getCallerIdentityHeader http.Header + getCallerIdentityBody string + + getIAMEntityMethod string + getIAMEntityURL string + getIAMEntityHeader http.Header + getIAMEntityBody string + + entityRequestType string + parsedCallerIdentityURL *url.URL + parsedIAMEntityURL *url.URL +} + +var _ json.Unmarshaler = (*BearerToken)(nil) + +func NewBearerToken(loginToken string, config *Config) (*BearerToken, error) { + token := &BearerToken{config: config} + if err := json.Unmarshal([]byte(loginToken), &token); err != nil { + return nil, fmt.Errorf("invalid token: %s", err) + } + + if err := token.validate(); err != nil { + return nil, err + } + + if config.EnableIAMEntityDetails { + method, err := token.getHeader(token.config.GetEntityMethodHeader) + if err != nil { + return nil, err + } + + rawUrl, err := token.getHeader(token.config.GetEntityURLHeader) + if err != nil { + return nil, err + } + + headerJson, err := token.getHeader(token.config.GetEntityHeadersHeader) + if err != nil { + return nil, err + } + + var header http.Header + if err := json.Unmarshal([]byte(headerJson), &header); err != nil { + return nil, err + } + + body, err := token.getHeader(token.config.GetEntityBodyHeader) + if err != nil { + return nil, err + } + + parsedUrl, err := parseUrl(rawUrl) + if err != nil { + return nil, err + } + + token.getIAMEntityMethod = method + token.getIAMEntityBody = body + token.getIAMEntityURL = rawUrl + token.getIAMEntityHeader = header + token.parsedIAMEntityURL = parsedUrl + + reqType, err := token.validateIAMEntityBody() + if err != nil { + return nil, err + } + token.entityRequestType = reqType + } + return token, nil +} + +// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1178 +func (t *BearerToken) validate() error { + if t.getCallerIdentityMethod != "POST" { + return fmt.Errorf("iam_http_request_method must be POST") + } + if err := t.validateGetCallerIdentityBody(); err != nil { + return err + } + if err := t.validateAllowedSTSHeaderValues(); err != nil { + return err + } + return nil +} + +// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1439 +func (t *BearerToken) validateGetCallerIdentityBody() error { + allowedValues := url.Values{ + "Action": []string{"GetCallerIdentity"}, + // Will assume for now that future versions don't change + // the semantics + "Version": nil, // any value is allowed + } + if _, err := parseRequestBody(t.getCallerIdentityBody, allowedValues); err != nil { + return fmt.Errorf("iam_request_body error: %s", err) + } + + return nil +} + +func (t *BearerToken) validateIAMEntityBody() (string, error) { + allowedValues := url.Values{ + "Action": []string{"GetRole", "GetUser"}, + "RoleName": nil, // any value is allowed + "UserName": nil, + "Version": nil, + } + body, err := parseRequestBody(t.getIAMEntityBody, allowedValues) + if err != nil { + return "", fmt.Errorf("iam_request_headers[%s] error: %s", t.config.GetEntityBodyHeader, err) + } + + // Disallow GetRole+UserName and GetUser+RoleName. + action := body["Action"][0] + _, hasRoleName := body["RoleName"] + _, hasUserName := body["UserName"] + if action == "GetUser" && hasUserName && !hasRoleName { + return action, nil + } else if action == "GetRole" && hasRoleName && !hasUserName { + return action, nil + } + return "", fmt.Errorf("iam_request_headers[%q] error: invalid request body %q", t.config.GetEntityBodyHeader, t.getIAMEntityBody) +} + +// parseRequestBody parses the AWS STS or IAM request body, such as 'Action=GetRole&RoleName=my-role'. +// It returns the parsed values, or an error if there are unexpected fields based on allowedValues. +// +// A key-value pair in the body is allowed if: +// - It is a single value (i.e. no bodies like 'Action=1&Action=2') +// - allowedValues[key] is an empty slice or nil (any value is allowed for the key) +// - allowedValues[key] is non-empty and contains the exact value +// This always requires an 'Action' field is present and non-empty. +func parseRequestBody(body string, allowedValues url.Values) (url.Values, error) { + qs, err := url.ParseQuery(body) + if err != nil { + return nil, err + } + + // Action field is always required. + if _, ok := qs["Action"]; !ok || len(qs["Action"]) == 0 || qs["Action"][0] == "" { + return nil, fmt.Errorf(`missing field "Action"`) + } + + // Ensure the body does not have extra fields and each + // field in the body matches the allowed values. + for k, v := range qs { + exp, ok := allowedValues[k] + if k != "Action" && !ok { + return nil, fmt.Errorf("unexpected field %q", k) + } + + if len(exp) == 0 { + // empty indicates any value is okay + continue + } else if len(v) != 1 || !stringslice.Contains(exp, v[0]) { + return nil, fmt.Errorf("unexpected value %s=%v", k, v) + } + } + + return qs, nil +} + +// https://github.com/hashicorp/vault/blob/861454e0ed1390d67ddaf1a53c1798e5e291728c/builtin/credential/aws/path_config_client.go#L349 +func (t *BearerToken) validateAllowedSTSHeaderValues() error { + for k := range t.getCallerIdentityHeader { + h := textproto.CanonicalMIMEHeaderKey(k) + if strings.HasPrefix(h, amzHeaderPrefix) && + !stringslice.Contains(defaultAllowedSTSRequestHeaders, h) && + !stringslice.Contains(t.config.AllowedSTSHeaderValues, h) { + return fmt.Errorf("invalid request header: %s", h) + } + } + return nil +} + +// UnmarshalJSON unmarshals the bearer token details which contains an HTTP +// request (a signed sts:GetCallerIdentity request). +func (t *BearerToken) UnmarshalJSON(data []byte) error { + var rawData struct { + Method string `json:"iam_http_request_method"` + UrlBase64 string `json:"iam_request_url"` + HeadersBase64 string `json:"iam_request_headers"` + BodyBase64 string `json:"iam_request_body"` + } + + if err := json.Unmarshal(data, &rawData); err != nil { + return err + } + + rawUrl, err := base64.StdEncoding.DecodeString(rawData.UrlBase64) + if err != nil { + return err + } + + headersJson, err := base64.StdEncoding.DecodeString(rawData.HeadersBase64) + if err != nil { + return err + } + + var headers http.Header + // This is a JSON-string in JSON + if err := json.Unmarshal(headersJson, &headers); err != nil { + return err + } + + body, err := base64.StdEncoding.DecodeString(rawData.BodyBase64) + if err != nil { + return err + } + + t.getCallerIdentityMethod = rawData.Method + t.getCallerIdentityBody = string(body) + t.getCallerIdentityHeader = headers + t.getCallerIdentityURL = string(rawUrl) + + parsedUrl, err := parseUrl(t.getCallerIdentityURL) + if err != nil { + return err + } + t.parsedCallerIdentityURL = parsedUrl + return nil +} + +func parseUrl(s string) (*url.URL, error) { + u, err := url.Parse(s) + if err != nil { + return nil, err + } + // url.Parse doesn't error on empty string + if u == nil || u.Scheme == "" || u.Host == "" || u.Path == "" { + return nil, fmt.Errorf("url is invalid: %q", s) + } + return u, nil +} + +// GetCallerIdentityRequest returns the sts:GetCallerIdentity request decoded +// from the bearer token. +func (t *BearerToken) GetCallerIdentityRequest() (*http.Request, error) { + // NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy + // The protection against this is that this method will only call the endpoint specified in the + // client config (defaulting to sts.amazonaws.com), so it would require an admin to override + // the endpoint to talk to alternate web addresses + endpoint := defaultSTSEndpoint + if t.config.STSEndpoint != "" { + endpoint = t.config.STSEndpoint + } + + return buildHttpRequest( + t.getCallerIdentityMethod, + endpoint, + t.parsedCallerIdentityURL, + t.getCallerIdentityBody, + t.getCallerIdentityHeader, + ) +} + +// GetEntityRequest returns the iam:GetUser or iam:GetRole request from the request details, +// if present, embedded in the headers of the sts:GetCallerIdentity request. +func (t *BearerToken) GetEntityRequest() (*http.Request, error) { + endpoint := defaultIAMEndpoint + if t.config.IAMEndpoint != "" { + endpoint = t.config.IAMEndpoint + } + + return buildHttpRequest( + t.getIAMEntityMethod, + endpoint, + t.parsedIAMEntityURL, + t.getIAMEntityBody, + t.getIAMEntityHeader, + ) +} + +// getHeader returns the header from s.GetCallerIdentityHeader, or an error if +// the header is not found or is not a single value. +func (t *BearerToken) getHeader(name string) (string, error) { + values := t.getCallerIdentityHeader.Values(name) + if len(values) == 0 { + return "", fmt.Errorf("missing header %q", name) + } + if len(values) != 1 { + return "", fmt.Errorf("invalid value for header %q (expected 1 item)", name) + } + return values[0], nil +} + +// buildHttpRequest returns an HTTP request from the given details. +// This supports sending to a custom endpoint, but always preserves the +// Host header and URI path, which are signed and cannot be modified. +// There's a deeper explanation of this in the Vault source code. +// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1569 +func buildHttpRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*http.Request, error) { + targetUrl := fmt.Sprintf("%s%s", endpoint, parsedUrl.RequestURI()) + request, err := http.NewRequest(method, targetUrl, strings.NewReader(body)) + if err != nil { + return nil, err + } + request.Host = parsedUrl.Host + for k, vals := range headers { + for _, val := range vals { + request.Header.Add(k, val) + } + } + return request, nil +} diff --git a/internal/iamauth/token_test.go b/internal/iamauth/token_test.go new file mode 100644 index 0000000000..4de7ba7157 --- /dev/null +++ b/internal/iamauth/token_test.go @@ -0,0 +1,364 @@ +package iamauth + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBearerToken(t *testing.T) { + cases := map[string]struct { + tokenStr string + config Config + expToken BearerToken + expError string + }{ + "valid token": { + tokenStr: validBearerTokenJson, + expToken: validBearerTokenParsed, + }, + "valid token with role": { + tokenStr: validBearerTokenWithRoleJson, + config: Config{ + EnableIAMEntityDetails: true, + GetEntityMethodHeader: "X-Consul-IAM-GetEntity-Method", + GetEntityURLHeader: "X-Consul-IAM-GetEntity-URL", + GetEntityHeadersHeader: "X-Consul-IAM-GetEntity-Headers", + GetEntityBodyHeader: "X-Consul-IAM-GetEntity-Body", + }, + expToken: validBearerTokenWithRoleParsed, + }, + + "empty json": { + tokenStr: `{}`, + expError: "unexpected end of JSON input", + }, + "missing iam_request_method field": { + tokenStr: tokenJsonMissingMethodField, + expError: "iam_http_request_method must be POST", + }, + "missing iam_request_url field": { + tokenStr: tokenJsonMissingUrlField, + expError: "url is invalid", + }, + "missing iam_request_headers field": { + tokenStr: tokenJsonMissingHeadersField, + expError: "unexpected end of JSON input", + }, + "missing iam_request_body field": { + tokenStr: tokenJsonMissingBodyField, + expError: "iam_request_body error", + }, + "invalid json": { + tokenStr: `{`, + expError: "unexpected end of JSON input", + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + token, err := NewBearerToken(c.tokenStr, &c.config) + t.Logf("token = %+v", token) + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + require.Nil(t, token) + } else { + require.NoError(t, err) + c.expToken.config = &c.config + require.Equal(t, &c.expToken, token) + } + }) + } +} + +func TestParseRequestBody(t *testing.T) { + cases := map[string]struct { + body string + allowedValues url.Values + expValues url.Values + expError string + }{ + "one allowed field": { + body: "Action=GetCallerIdentity&Version=1234", + allowedValues: url.Values{"Version": []string{"1234"}}, + expValues: url.Values{ + "Action": []string{"GetCallerIdentity"}, + "Version": []string{"1234"}, + }, + }, + "many allowed fields": { + body: "Action=GetRole&RoleName=my-role&Version=1234", + allowedValues: url.Values{ + "Action": []string{"GetUser", "GetRole"}, + "UserName": nil, + "RoleName": nil, + "Version": nil, + }, + expValues: url.Values{ + "Action": []string{"GetRole"}, + "RoleName": []string{"my-role"}, + "Version": []string{"1234"}, + }, + }, + "action only": { + body: "Action=GetRole", + allowedValues: nil, + expValues: url.Values{"Action": []string{"GetRole"}}, + }, + + "empty body": { + expValues: url.Values{}, + expError: `missing field "Action"`, + }, + "disallowed field": { + body: "Action=GetRole&Version=1234&Extra=Abc", + allowedValues: url.Values{"Action": nil, "Version": nil}, + expError: `unexpected field "Extra"`, + }, + "mismatched action": { + body: "Action=GetRole", + allowedValues: url.Values{"Action": []string{"GetUser"}}, + expError: `unexpected value Action=[GetRole]`, + }, + "mismatched field": { + body: "Action=GetRole&Extra=1234", + allowedValues: url.Values{"Action": nil, "Extra": []string{"abc"}}, + expError: `unexpected value Extra=[1234]`, + }, + "multi-valued field": { + body: "Action=GetRole&Action=GetUser", + allowedValues: url.Values{"Action": []string{"GetRole", "GetUser"}}, + // only one value is allowed. + expError: `unexpected value Action=[GetRole GetUser]`, + }, + "empty action": { + body: "Action=", + allowedValues: nil, + expError: `missing field "Action"`, + }, + "missing action": { + body: "Version=1234", + allowedValues: url.Values{"Action": []string{"GetRole"}}, + expError: `missing field "Action"`, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + values, err := parseRequestBody(c.body, c.allowedValues) + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + require.Nil(t, values) + } else { + require.NoError(t, err) + require.Equal(t, c.expValues, values) + } + }) + } +} + +func TestValidateGetCallerIdentityBody(t *testing.T) { + cases := map[string]struct { + body string + expError string + }{ + "valid": {"Action=GetCallerIdentity&Version=1234", ""}, + "valid 2": {"Action=GetCallerIdentity", ""}, + "empty action": { + "Action=", + `iam_request_body error: missing field "Action"`, + }, + "invalid action": { + "Action=GetRole", + `iam_request_body error: unexpected value Action=[GetRole]`, + }, + "missing action": { + "Version=1234", + `iam_request_body error: missing field "Action"`, + }, + "empty": { + "", + `iam_request_body error: missing field "Action"`, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + token := &BearerToken{getCallerIdentityBody: c.body} + err := token.validateGetCallerIdentityBody() + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateIAMEntityBody(t *testing.T) { + cases := map[string]struct { + body string + expReqType string + expError string + }{ + "valid role": { + body: "Action=GetRole&RoleName=my-role&Version=1234", + expReqType: "GetRole", + }, + "valid role without version": { + body: "Action=GetRole&RoleName=my-role", + expReqType: "GetRole", + }, + "valid user": { + body: "Action=GetUser&UserName=my-role&Version=1234", + expReqType: "GetUser", + }, + "valid user without version": { + body: "Action=GetUser&UserName=my-role", + expReqType: "GetUser", + }, + + "invalid action": { + body: "Action=GetCallerIdentity", + expError: `unexpected value Action=[GetCallerIdentity]`, + }, + "role missing action": { + body: "RoleName=my-role&Version=1234", + expError: `missing field "Action"`, + }, + "user missing action": { + body: "UserName=my-role&Version=1234", + expError: `missing field "Action"`, + }, + "empty": { + body: "", + expError: `missing field "Action"`, + }, + "empty action": { + body: "Action=", + expError: `missing field "Action"`, + }, + "role with user name": { + body: "Action=GetRole&UserName=my-role&Version=1234", + expError: `invalid request body`, + }, + "user with role name": { + body: "Action=GetUser&RoleName=my-role&Version=1234", + expError: `invalid request body`, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + token := &BearerToken{ + config: &Config{}, + getIAMEntityBody: c.body, + } + reqType, err := token.validateIAMEntityBody() + if c.expError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.expError) + require.Equal(t, "", reqType) + } else { + require.NoError(t, err) + require.Equal(t, c.expReqType, reqType) + } + }) + } +} + +var ( + validBearerTokenJson = `{ + "iam_http_request_method":"POST", + "iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==", + "iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==", + "iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8=" +}` + + validBearerTokenParsed = BearerToken{ + getCallerIdentityMethod: "POST", + getCallerIdentityURL: "https://sts.amazonaws.com/", + getCallerIdentityHeader: http.Header{ + "Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake/20220322/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=efc320b972d07b38b65eb24256805e03149da586d804f8c6364ce98debe080b1"}, + "Content-Length": []string{"43"}, + "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, + "User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"}, + "X-Amz-Date": []string{"20220322T211103Z"}, + "X-Amz-Security-Token": []string{"fake"}, + }, + getCallerIdentityBody: "Action=GetCallerIdentity&Version=2011-06-15", + parsedCallerIdentityURL: &url.URL{ + Scheme: "https", + Host: "sts.amazonaws.com", + Path: "/", + }, + } + + validBearerTokenWithRoleJson = `{"iam_http_request_method":"POST","iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==","iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLWtleS1pZC8yMDIyMDMyMi9mYWtlLXJlZ2lvbi9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1jb25zdWwtaWFtLWdldGVudGl0eS1ib2R5O3gtY29uc3VsLWlhbS1nZXRlbnRpdHktaGVhZGVyczt4LWNvbnN1bC1pYW0tZ2V0ZW50aXR5LW1ldGhvZDt4LWNvbnN1bC1pYW0tZ2V0ZW50aXR5LXVybCwgU2lnbmF0dXJlPTU2MWFjMzFiNWFkMDFjMTI0YzU0YzE2OGY3NmVhNmJmZDY0NWI4ZWM1MzQ1ZjgzNTc3MjljOWFhMGI0NzEzMzciXSwiQ29udGVudC1MZW5ndGgiOlsiNDMiXSwiQ29udGVudC1UeXBlIjpbImFwcGxpY2F0aW9uL3gtd3d3LWZvcm0tdXJsZW5jb2RlZDsgY2hhcnNldD11dGYtOCJdLCJVc2VyLUFnZW50IjpbImF3cy1zZGstZ28vMS40Mi4zNCAoZ28xLjE3LjU7IGRhcndpbjsgYW1kNjQpIl0sIlgtQW16LURhdGUiOlsiMjAyMjAzMjJUMjI1NzQyWiJdLCJYLUNvbnN1bC1JYW0tR2V0ZW50aXR5LUJvZHkiOlsiQWN0aW9uPUdldFJvbGVcdTAwMjZSb2xlTmFtZT1teS1yb2xlXHUwMDI2VmVyc2lvbj0yMDEwLTA1LTA4Il0sIlgtQ29uc3VsLUlhbS1HZXRlbnRpdHktSGVhZGVycyI6WyJ7XCJBdXRob3JpemF0aW9uXCI6W1wiQVdTNC1ITUFDLVNIQTI1NiBDcmVkZW50aWFsPWZha2Uta2V5LWlkLzIwMjIwMzIyL3VzLWVhc3QtMS9pYW0vYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGUsIFNpZ25hdHVyZT1hYTJhMTlkMGEzMDVkNzRiYmQwMDk3NzZiY2E4ODBlNTNjZmE5OTFlNDgzZTQwMzk0NzE4MWE0MWNjNDgyOTQwXCJdLFwiQ29udGVudC1MZW5ndGhcIjpbXCI1MFwiXSxcIkNvbnRlbnQtVHlwZVwiOltcImFwcGxpY2F0aW9uL3gtd3d3LWZvcm0tdXJsZW5jb2RlZDsgY2hhcnNldD11dGYtOFwiXSxcIlVzZXItQWdlbnRcIjpbXCJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KVwiXSxcIlgtQW16LURhdGVcIjpbXCIyMDIyMDMyMlQyMjU3NDJaXCJdfSJdLCJYLUNvbnN1bC1JYW0tR2V0ZW50aXR5LU1ldGhvZCI6WyJQT1NUIl0sIlgtQ29uc3VsLUlhbS1HZXRlbnRpdHktVXJsIjpbImh0dHBzOi8vaWFtLmFtYXpvbmF3cy5jb20vIl19","iam_request_url":"aHR0cDovLzEyNy4wLjAuMTo2MzY5Ni9zdHMv"}` + + validBearerTokenWithRoleParsed = BearerToken{ + getCallerIdentityMethod: "POST", + getCallerIdentityURL: "http://127.0.0.1:63696/sts/", + getCallerIdentityHeader: http.Header{ + "Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/fake-region/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-consul-iam-getentity-body;x-consul-iam-getentity-headers;x-consul-iam-getentity-method;x-consul-iam-getentity-url, Signature=561ac31b5ad01c124c54c168f76ea6bfd645b8ec5345f8357729c9aa0b471337"}, + "Content-Length": []string{"43"}, + "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, + "User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"}, + "X-Amz-Date": []string{"20220322T225742Z"}, + "X-Consul-Iam-Getentity-Body": []string{"Action=GetRole&RoleName=my-role&Version=2010-05-08"}, + "X-Consul-Iam-Getentity-Headers": []string{`{"Authorization":["AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/us-east-1/iam/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aa2a19d0a305d74bbd009776bca880e53cfa991e483e403947181a41cc482940"],"Content-Length":["50"],"Content-Type":["application/x-www-form-urlencoded; charset=utf-8"],"User-Agent":["aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"],"X-Amz-Date":["20220322T225742Z"]}`}, + "X-Consul-Iam-Getentity-Method": []string{"POST"}, + "X-Consul-Iam-Getentity-Url": []string{"https://iam.amazonaws.com/"}, + }, + getCallerIdentityBody: "Action=GetCallerIdentity&Version=2011-06-15", + + // Fields parsed from headers above + getIAMEntityMethod: "POST", + getIAMEntityURL: "https://iam.amazonaws.com/", + getIAMEntityHeader: http.Header{ + "Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/us-east-1/iam/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aa2a19d0a305d74bbd009776bca880e53cfa991e483e403947181a41cc482940"}, + "Content-Length": []string{"50"}, + "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, + "User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"}, + "X-Amz-Date": []string{"20220322T225742Z"}, + }, + getIAMEntityBody: "Action=GetRole&RoleName=my-role&Version=2010-05-08", + entityRequestType: "GetRole", + + parsedCallerIdentityURL: &url.URL{ + Scheme: "http", + Host: "127.0.0.1:63696", + Path: "/sts/", + }, + parsedIAMEntityURL: &url.URL{ + Scheme: "https", + Host: "iam.amazonaws.com", + Path: "/", + }, + } + + tokenJsonMissingMethodField = `{ + "iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==", + "iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==", + "iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8=" +}` + + tokenJsonMissingBodyField = `{ + "iam_http_request_method":"POST", + "iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==", + "iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8=" +}` + + tokenJsonMissingHeadersField = `{ + "iam_http_request_method":"POST", + "iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==", + "iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8=" +}` + + tokenJsonMissingUrlField = `{ + "iam_http_request_method":"POST", + "iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==", + "iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==" +}` +) diff --git a/internal/iamauth/util.go b/internal/iamauth/util.go new file mode 100644 index 0000000000..bfd5f22d77 --- /dev/null +++ b/internal/iamauth/util.go @@ -0,0 +1,158 @@ +package iamauth + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/hashicorp/consul/internal/iamauth/responses" + "github.com/hashicorp/go-hclog" +) + +type LoginInput struct { + Creds *credentials.Credentials + IncludeIAMEntity bool + STSEndpoint string + STSRegion string + + Logger hclog.Logger + + ServerIDHeaderValue string + // Customizable header names + ServerIDHeaderName string + GetEntityMethodHeader string + GetEntityURLHeader string + GetEntityHeadersHeader string + GetEntityBodyHeader string +} + +// GenerateLoginData populates the necessary data to send for the bearer token. +// https://github.com/hashicorp/go-secure-stdlib/blob/main/awsutil/generate_credentials.go#L232-L301 +func GenerateLoginData(in *LoginInput) (map[string]interface{}, error) { + cfg := aws.Config{ + Credentials: in.Creds, + Region: aws.String(in.STSRegion), + } + if in.STSEndpoint != "" { + cfg.Endpoint = aws.String(in.STSEndpoint) + } else { + cfg.EndpointResolver = endpoints.ResolverFunc(stsSigningResolver) + } + + stsSession, err := session.NewSessionWithOptions(session.Options{Config: cfg}) + if err != nil { + return nil, err + } + + svc := sts.New(stsSession) + stsRequest, _ := svc.GetCallerIdentityRequest(nil) + + // Include the iam:GetRole or iam:GetUser request in headers. + if in.IncludeIAMEntity { + entityRequest, err := formatSignedEntityRequest(svc, in) + if err != nil { + return nil, err + } + + headersJson, err := json.Marshal(entityRequest.HTTPRequest.Header) + if err != nil { + return nil, err + } + requestBody, err := ioutil.ReadAll(entityRequest.HTTPRequest.Body) + if err != nil { + return nil, err + } + + stsRequest.HTTPRequest.Header.Add(in.GetEntityMethodHeader, entityRequest.HTTPRequest.Method) + stsRequest.HTTPRequest.Header.Add(in.GetEntityURLHeader, entityRequest.HTTPRequest.URL.String()) + stsRequest.HTTPRequest.Header.Add(in.GetEntityHeadersHeader, string(headersJson)) + stsRequest.HTTPRequest.Header.Add(in.GetEntityBodyHeader, string(requestBody)) + } + + // Inject the required auth header value, if supplied, and then sign the request including that header + if in.ServerIDHeaderValue != "" { + stsRequest.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue) + } + + stsRequest.Sign() + + // Now extract out the relevant parts of the request + headersJson, err := json.Marshal(stsRequest.HTTPRequest.Header) + if err != nil { + return nil, err + } + requestBody, err := ioutil.ReadAll(stsRequest.HTTPRequest.Body) + if err != nil { + return nil, err + } + + return map[string]interface{}{ + "iam_http_request_method": stsRequest.HTTPRequest.Method, + "iam_request_url": base64.StdEncoding.EncodeToString([]byte(stsRequest.HTTPRequest.URL.String())), + "iam_request_headers": base64.StdEncoding.EncodeToString(headersJson), + "iam_request_body": base64.StdEncoding.EncodeToString(requestBody), + }, nil +} + +// STS is a really weird service that used to only have global endpoints but now has regional endpoints as well. +// For backwards compatibility, even if you request a region other than us-east-1, it'll still sign for us-east-1. +// See, e.g., https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code +// So we have to shim in this EndpointResolver to force it to sign for the right region +func stsSigningResolver(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + defaultEndpoint, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...) + if err != nil { + return defaultEndpoint, err + } + defaultEndpoint.SigningRegion = region + return defaultEndpoint, nil +} + +func formatSignedEntityRequest(svc *sts.STS, in *LoginInput) (*request.Request, error) { + // We need to retrieve the IAM user or role for the iam:GetRole or iam:GetUser request. + // GetCallerIdentity returns this and requires no permissions. + resp, err := svc.GetCallerIdentity(nil) + if err != nil { + return nil, err + } + + arn, err := responses.ParseArn(*resp.Arn) + if err != nil { + return nil, err + } + + iamSession, err := session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Credentials: svc.Config.Credentials, + }, + }) + if err != nil { + return nil, err + } + iamSvc := iam.New(iamSession) + + var req *request.Request + switch arn.Type { + case "role", "assumed-role": + req, _ = iamSvc.GetRoleRequest(&iam.GetRoleInput{RoleName: &arn.FriendlyName}) + case "user": + req, _ = iamSvc.GetUserRequest(&iam.GetUserInput{UserName: &arn.FriendlyName}) + default: + return nil, fmt.Errorf("entity %s is not an IAM role or IAM user", arn.Type) + } + + // Inject the required auth header value, if supplied, and then sign the request including that header + if in.ServerIDHeaderValue != "" { + req.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue) + } + + req.Sign() + return req, nil +} diff --git a/lib/glob.go b/lib/glob.go new file mode 100644 index 0000000000..969e3ab25c --- /dev/null +++ b/lib/glob.go @@ -0,0 +1,24 @@ +package lib + +import "strings" + +// GlobbedStringsMatch compares item to val with support for a leading and/or +// trailing wildcard '*' in item. +func GlobbedStringsMatch(item, val string) bool { + if len(item) < 2 { + return val == item + } + + hasPrefix := strings.HasPrefix(item, "*") + hasSuffix := strings.HasSuffix(item, "*") + + if hasPrefix && hasSuffix { + return strings.Contains(val, item[1:len(item)-1]) + } else if hasPrefix { + return strings.HasSuffix(val, item[1:]) + } else if hasSuffix { + return strings.HasPrefix(val, item[:len(item)-1]) + } + + return val == item +} diff --git a/lib/glob_test.go b/lib/glob_test.go new file mode 100644 index 0000000000..6c29f5ef19 --- /dev/null +++ b/lib/glob_test.go @@ -0,0 +1,37 @@ +package lib + +import "testing" + +func TestGlobbedStringsMatch(t *testing.T) { + tests := []struct { + item string + val string + expect bool + }{ + {"", "", true}, + {"*", "*", true}, + {"**", "**", true}, + {"*t", "t", true}, + {"*t", "test", true}, + {"t*", "test", true}, + {"*test", "test", true}, + {"*test", "a test", true}, + {"test", "a test", false}, + {"*test", "tests", false}, + {"test*", "test", true}, + {"test*", "testsss", true}, + {"test**", "testsss", false}, + {"test**", "test*", true}, + {"**test", "*test", true}, + {"TEST", "test", false}, + {"test", "test", true}, + } + + for _, tt := range tests { + actual := GlobbedStringsMatch(tt.item, tt.val) + + if actual != tt.expect { + t.Fatalf("Bad testcase %#v, expected %t, got %t", tt, tt.expect, actual) + } + } +}