Browse Source

Added Azure OAuth support (#12572)

* Added Azure OAuth support

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Added missing comment

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Addressing comment

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Fixed lint issue

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Fix test

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Addressing comments

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Added documentation and updated unit tests

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

* Addressing comments

Signed-off-by: rakshith210 <rakshith.me@gmail.com>

---------

Signed-off-by: rakshith210 <rakshith.me@gmail.com>
pull/12937/head
rakshith210 1 year ago committed by GitHub
parent
commit
cdad64002a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      docs/configuration/configuration.md
  2. 116
      storage/remote/azuread/azuread.go
  3. 323
      storage/remote/azuread/azuread_test.go
  4. 0
      storage/remote/azuread/testdata/azuread_bad_configmissing.yaml
  5. 4
      storage/remote/azuread/testdata/azuread_bad_invalidoauthconfig.yaml
  6. 7
      storage/remote/azuread/testdata/azuread_bad_twoconfig.yaml
  7. 0
      storage/remote/azuread/testdata/azuread_good_managedidentity.yaml
  8. 5
      storage/remote/azuread/testdata/azuread_good_oauth.yaml

8
docs/configuration/configuration.md

@ -3537,7 +3537,13 @@ azuread:
# Azure User-assigned Managed identity.
[ managed_identity:
[ client_id: <string> ]
[ client_id: <string> ] ]
# Azure OAuth.
[ oauth:
[ client_id: <string> ]
[ client_secret: <string> ]
[ tenant_id: <string> ] ]
# Configures the remote write request's TLS settings.
tls_config:

116
storage/remote/azuread/azuread.go

@ -22,7 +22,10 @@ import (
"sync"
"time"
"github.com/grafana/regexp"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/google/uuid"
@ -46,11 +49,26 @@ type ManagedIdentityConfig struct {
ClientID string `yaml:"client_id,omitempty"`
}
// OAuthConfig is used to store azure oauth config values.
type OAuthConfig struct {
// ClientID is the clientId of the azure active directory application that is being used to authenticate.
ClientID string `yaml:"client_id,omitempty"`
// ClientSecret is the clientSecret of the azure active directory application that is being used to authenticate.
ClientSecret string `yaml:"client_secret,omitempty"`
// TenantID is the tenantId of the azure active directory application that is being used to authenticate.
TenantID string `yaml:"tenant_id,omitempty"`
}
// AzureADConfig is used to store the config values.
type AzureADConfig struct { // nolint:revive
// ManagedIdentity is the managed identity that is being used to authenticate.
ManagedIdentity *ManagedIdentityConfig `yaml:"managed_identity,omitempty"`
// OAuth is the oauth config that is being used to authenticate.
OAuth *OAuthConfig `yaml:"oauth,omitempty"`
// Cloud is the Azure cloud in which the service is running. Example: AzurePublic/AzureGovernment/AzureChina.
Cloud string `yaml:"cloud,omitempty"`
}
@ -84,18 +102,47 @@ func (c *AzureADConfig) Validate() error {
return fmt.Errorf("must provide a cloud in the Azure AD config")
}
if c.ManagedIdentity == nil {
return fmt.Errorf("must provide an Azure Managed Identity in the Azure AD config")
if c.ManagedIdentity == nil && c.OAuth == nil {
return fmt.Errorf("must provide an Azure Managed Identity or Azure OAuth in the Azure AD config")
}
if c.ManagedIdentity.ClientID == "" {
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
if c.ManagedIdentity != nil && c.OAuth != nil {
return fmt.Errorf("cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config")
}
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
if err != nil {
return fmt.Errorf("the provided Azure Managed Identity client_id provided is invalid")
if c.ManagedIdentity != nil {
if c.ManagedIdentity.ClientID == "" {
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
}
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
if err != nil {
return fmt.Errorf("the provided Azure Managed Identity client_id is invalid")
}
}
if c.OAuth != nil {
if c.OAuth.ClientID == "" {
return fmt.Errorf("must provide an Azure OAuth client_id in the Azure AD config")
}
if c.OAuth.ClientSecret == "" {
return fmt.Errorf("must provide an Azure OAuth client_secret in the Azure AD config")
}
if c.OAuth.TenantID == "" {
return fmt.Errorf("must provide an Azure OAuth tenant_id in the Azure AD config")
}
var err error
_, err = uuid.Parse(c.OAuth.ClientID)
if err != nil {
return fmt.Errorf("the provided Azure OAuth client_id is invalid")
}
_, err = regexp.MatchString("^[0-9a-zA-Z-.]+$", c.OAuth.TenantID)
if err != nil {
return fmt.Errorf("the provided Azure OAuth tenant_id is invalid")
}
}
return nil
}
@ -146,21 +193,54 @@ func (rt *azureADRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
// newTokenCredential returns a TokenCredential of different kinds like Azure Managed Identity and Azure AD application.
func newTokenCredential(cfg *AzureADConfig) (azcore.TokenCredential, error) {
cred, err := newManagedIdentityTokenCredential(cfg.ManagedIdentity.ClientID)
var cred azcore.TokenCredential
var err error
cloudConfiguration, err := getCloudConfiguration(cfg.Cloud)
if err != nil {
return nil, err
}
clientOpts := &azcore.ClientOptions{
Cloud: cloudConfiguration,
}
if cfg.ManagedIdentity != nil {
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: cfg.ManagedIdentity.ClientID,
}
cred, err = newManagedIdentityTokenCredential(clientOpts, managedIdentityConfig)
if err != nil {
return nil, err
}
}
if cfg.OAuth != nil {
oAuthConfig := &OAuthConfig{
ClientID: cfg.OAuth.ClientID,
ClientSecret: cfg.OAuth.ClientSecret,
TenantID: cfg.OAuth.TenantID,
}
cred, err = newOAuthTokenCredential(clientOpts, oAuthConfig)
if err != nil {
return nil, err
}
}
return cred, nil
}
// newManagedIdentityTokenCredential returns new Managed Identity token credential.
func newManagedIdentityTokenCredential(managedIdentityClientID string) (azcore.TokenCredential, error) {
clientID := azidentity.ClientID(managedIdentityClientID)
opts := &azidentity.ManagedIdentityCredentialOptions{ID: clientID}
func newManagedIdentityTokenCredential(clientOpts *azcore.ClientOptions, managedIdentityConfig *ManagedIdentityConfig) (azcore.TokenCredential, error) {
clientID := azidentity.ClientID(managedIdentityConfig.ClientID)
opts := &azidentity.ManagedIdentityCredentialOptions{ClientOptions: *clientOpts, ID: clientID}
return azidentity.NewManagedIdentityCredential(opts)
}
// newOAuthTokenCredential returns new OAuth token credential
func newOAuthTokenCredential(clientOpts *azcore.ClientOptions, oAuthConfig *OAuthConfig) (azcore.TokenCredential, error) {
opts := &azidentity.ClientSecretCredentialOptions{ClientOptions: *clientOpts}
return azidentity.NewClientSecretCredential(oAuthConfig.TenantID, oAuthConfig.ClientID, oAuthConfig.ClientSecret, opts)
}
// newTokenProvider helps to fetch accessToken for different types of credential. This also takes care of
// refreshing the accessToken before expiry. This accessToken is attached to the Authorization header while making requests.
func newTokenProvider(cfg *AzureADConfig, cred azcore.TokenCredential) (*tokenProvider, error) {
@ -245,3 +325,17 @@ func getAudience(cloud string) (string, error) {
return "", errors.New("Cloud is not specified or is incorrect: " + cloud)
}
}
// getCloudConfiguration returns the cloud Configuration which contains AAD endpoint for different clouds
func getCloudConfiguration(c string) (cloud.Configuration, error) {
switch strings.ToLower(c) {
case strings.ToLower(AzureChina):
return cloud.AzureChina, nil
case strings.ToLower(AzureGovernment):
return cloud.AzureGovernment, nil
case strings.ToLower(AzurePublic):
return cloud.AzurePublic, nil
default:
return cloud.Configuration{}, errors.New("Cloud is not specified or is incorrect: " + c)
}
}

323
storage/remote/azuread/azuread_test.go

@ -26,17 +26,20 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gopkg.in/yaml.v2"
)
const (
dummyAudience = "dummyAudience"
dummyClientID = "00000000-0000-0000-0000-000000000000"
testTokenString = "testTokenString"
dummyAudience = "dummyAudience"
dummyClientID = "00000000-0000-0000-0000-000000000000"
dummyClientSecret = "Cl1ent$ecret!"
dummyTenantID = "00000000-a12b-3cd4-e56f-000000000000"
testTokenString = "testTokenString"
)
var testTokenExpiry = time.Now().Add(10 * time.Second)
var testTokenExpiry = time.Now().Add(5 * time.Second)
type AzureAdTestSuite struct {
suite.Suite
@ -62,47 +65,64 @@ func TestAzureAd(t *testing.T) {
}
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() {
var gotReq *http.Request
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
cases := []struct {
cfg *AzureADConfig
}{
// AzureAd roundtripper with Managedidentity.
{
cfg: &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: &ManagedIdentityConfig{
ClientID: dummyClientID,
},
},
},
// AzureAd roundtripper with OAuth.
{
cfg: &AzureADConfig{
Cloud: "AzurePublic",
OAuth: &OAuthConfig{
ClientID: dummyClientID,
ClientSecret: dummyClientSecret,
TenantID: dummyTenantID,
},
},
},
}
for _, c := range cases {
var gotReq *http.Request
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
}
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
tokenProvider, err := newTokenProvider(c.cfg, ad.mockCredential)
ad.Assert().NoError(err)
tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential)
ad.Assert().NoError(err)
rt := &azureADRoundTripper{
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotReq = req
return &http.Response{StatusCode: http.StatusOK}, nil
}),
tokenProvider: tokenProvider,
}
rt := &azureADRoundTripper{
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotReq = req
return &http.Response{StatusCode: http.StatusOK}, nil
}),
tokenProvider: tokenProvider,
}
cli := &http.Client{Transport: rt}
cli := &http.Client{Transport: rt}
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
ad.Assert().NoError(err)
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
ad.Assert().NoError(err)
_, err = cli.Do(req)
ad.Assert().NoError(err)
ad.Assert().NotNil(gotReq)
_, err = cli.Do(req)
ad.Assert().NoError(err)
ad.Assert().NotNil(gotReq)
origReq := gotReq
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
origReq := gotReq
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
}
}
func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
@ -117,42 +137,54 @@ func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
return &cfg, nil
}
func testGoodConfig(t *testing.T, filename string) {
_, err := loadAzureAdConfig(filename)
if err != nil {
t.Fatalf("Unexpected error parsing %s: %s", filename, err)
}
}
func TestGoodAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_good.yaml"
testGoodConfig(t, filename)
}
func TestGoodCloudMissingAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_good_cloudmissing.yaml"
testGoodConfig(t, filename)
}
func TestBadClientIdMissingAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_bad_clientidmissing.yaml"
_, err := loadAzureAdConfig(filename)
if err == nil {
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
}
if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") {
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
func TestAzureAdConfig(t *testing.T) {
cases := []struct {
filename string
err string
}{
// Missing managedidentiy or oauth field.
{
filename: "testdata/azuread_bad_configmissing.yaml",
err: "must provide an Azure Managed Identity or Azure OAuth in the Azure AD config",
},
// Invalid managedidentity client id.
{
filename: "testdata/azuread_bad_invalidclientid.yaml",
err: "the provided Azure Managed Identity client_id is invalid",
},
// Missing tenant id in oauth config.
{
filename: "testdata/azuread_bad_invalidoauthconfig.yaml",
err: "must provide an Azure OAuth tenant_id in the Azure AD config",
},
// Invalid config when both managedidentity and oauth is provided.
{
filename: "testdata/azuread_bad_twoconfig.yaml",
err: "cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config",
},
// Valid config with missing optionally cloud field.
{
filename: "testdata/azuread_good_cloudmissing.yaml",
},
// Valid managed identity config.
{
filename: "testdata/azuread_good_managedidentity.yaml",
},
// Valid Oauth config.
{
filename: "testdata/azuread_good_oauth.yaml",
},
}
}
func TestBadInvalidClientIdAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_bad_invalidclientid.yaml"
_, err := loadAzureAdConfig(filename)
if err == nil {
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
}
if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") {
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
for _, c := range cases {
_, err := loadAzureAdConfig(c.filename)
if c.err != "" {
if err == nil {
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
}
require.EqualError(t, err, c.err)
} else {
require.NoError(t, err)
}
}
}
@ -173,75 +205,90 @@ func TestTokenProvider(t *testing.T) {
suite.Run(t, new(TokenProviderTestSuite))
}
func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() {
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "PublicAzure",
ManagedIdentity: managedIdentityConfig,
}
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().Nil(actualTokenProvider)
s.Assert().NotNil(actualErr)
s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error())
}
func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() {
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
}
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().NotNil(actualTokenProvider)
s.Assert().Nil(actualErr)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
}
func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() {
// setup
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
func (s *TokenProviderTestSuite) TestNewTokenProvider() {
cases := []struct {
cfg *AzureADConfig
err string
}{
// Invalid tokenProvider for managedidentity.
{
cfg: &AzureADConfig{
Cloud: "PublicAzure",
ManagedIdentity: &ManagedIdentityConfig{
ClientID: dummyClientID,
},
},
err: "Cloud is not specified or is incorrect: ",
},
// Invalid tokenProvider for oauth.
{
cfg: &AzureADConfig{
Cloud: "PublicAzure",
OAuth: &OAuthConfig{
ClientID: dummyClientID,
ClientSecret: dummyClientSecret,
TenantID: dummyTenantID,
},
},
err: "Cloud is not specified or is incorrect: ",
},
// Valid tokenProvider for managedidentity.
{
cfg: &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: &ManagedIdentityConfig{
ClientID: dummyClientID,
},
},
},
// Valid tokenProvider for oauth.
{
cfg: &AzureADConfig{
Cloud: "AzurePublic",
OAuth: &OAuthConfig{
ClientID: dummyClientID,
ClientSecret: dummyClientSecret,
TenantID: dummyTenantID,
},
},
},
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
mockGetTokenCallCounter := 1
for _, c := range cases {
if c.err != "" {
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
s.Assert().Nil(actualTokenProvider)
s.Assert().NotNil(actualErr)
s.Assert().ErrorContains(actualErr, c.err)
} else {
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
}
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
s.Assert().NotNil(actualTokenProvider)
s.Assert().Nil(actualErr)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 5s.
// Hence, the 4 seconds wait to check if the token is refreshed.
time.Sleep(4 * time.Second)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2*mockGetTokenCallCounter)
mockGetTokenCallCounter += 1
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
s.Assert().Nil(err)
s.Assert().NotEqual(accessToken, testTokenString)
}
}
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
}
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().NotNil(actualTokenProvider)
s.Assert().Nil(actualErr)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s.
// Hence, the 6 seconds wait to check if the token is refreshed.
time.Sleep(6 * time.Second)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2)
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
s.Assert().Nil(err)
s.Assert().NotEqual(accessToken, testTokenString)
}
func getToken() azcore.AccessToken {

0
storage/remote/azuread/testdata/azuread_bad_clientidmissing.yaml → storage/remote/azuread/testdata/azuread_bad_configmissing.yaml vendored

4
storage/remote/azuread/testdata/azuread_bad_invalidoauthconfig.yaml vendored

@ -0,0 +1,4 @@
cloud: AzurePublic
oauth:
client_id: 00000000-0000-0000-0000-000000000000
client_secret: Cl1ent$ecret!

7
storage/remote/azuread/testdata/azuread_bad_twoconfig.yaml vendored

@ -0,0 +1,7 @@
cloud: AzurePublic
managed_identity:
client_id: 00000000-0000-0000-0000-000000000000
oauth:
client_id: 00000000-0000-0000-0000-000000000000
client_secret: Cl1ent$ecret!
tenant_id: 00000000-a12b-3cd4-e56f-000000000000

0
storage/remote/azuread/testdata/azuread_good.yaml → storage/remote/azuread/testdata/azuread_good_managedidentity.yaml vendored

5
storage/remote/azuread/testdata/azuread_good_oauth.yaml vendored

@ -0,0 +1,5 @@
cloud: AzurePublic
oauth:
client_id: 00000000-0000-0000-0000-000000000000
client_secret: Cl1ent$ecret!
tenant_id: 00000000-a12b-3cd4-e56f-000000000000
Loading…
Cancel
Save