mirror of https://github.com/prometheus/prometheus
Added Azure OAuth support (#12572)
* Added Azure OAuth support Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Added missing comment Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Addressing comment Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Fixed lint issue Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Fix test Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Addressing comments Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Added documentation and updated unit tests Signed-off-by: rakshith210 <rakshith.me@gmail.com> * Addressing comments Signed-off-by: rakshith210 <rakshith.me@gmail.com> --------- Signed-off-by: rakshith210 <rakshith.me@gmail.com>pull/12937/head
parent
0331bcc7c9
commit
cdad64002a
|
@ -3537,7 +3537,13 @@ azuread:
|
|||
|
||||
# Azure User-assigned Managed identity.
|
||||
[ managed_identity:
|
||||
[ client_id: <string> ]
|
||||
[ client_id: <string> ] ]
|
||||
|
||||
# Azure OAuth.
|
||||
[ oauth:
|
||||
[ client_id: <string> ]
|
||||
[ client_secret: <string> ]
|
||||
[ tenant_id: <string> ] ]
|
||||
|
||||
# Configures the remote write request's TLS settings.
|
||||
tls_config:
|
||||
|
|
|
@ -22,7 +22,10 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/regexp"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/google/uuid"
|
||||
|
@ -46,11 +49,26 @@ type ManagedIdentityConfig struct {
|
|||
ClientID string `yaml:"client_id,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthConfig is used to store azure oauth config values.
|
||||
type OAuthConfig struct {
|
||||
// ClientID is the clientId of the azure active directory application that is being used to authenticate.
|
||||
ClientID string `yaml:"client_id,omitempty"`
|
||||
|
||||
// ClientSecret is the clientSecret of the azure active directory application that is being used to authenticate.
|
||||
ClientSecret string `yaml:"client_secret,omitempty"`
|
||||
|
||||
// TenantID is the tenantId of the azure active directory application that is being used to authenticate.
|
||||
TenantID string `yaml:"tenant_id,omitempty"`
|
||||
}
|
||||
|
||||
// AzureADConfig is used to store the config values.
|
||||
type AzureADConfig struct { // nolint:revive
|
||||
// ManagedIdentity is the managed identity that is being used to authenticate.
|
||||
ManagedIdentity *ManagedIdentityConfig `yaml:"managed_identity,omitempty"`
|
||||
|
||||
// OAuth is the oauth config that is being used to authenticate.
|
||||
OAuth *OAuthConfig `yaml:"oauth,omitempty"`
|
||||
|
||||
// Cloud is the Azure cloud in which the service is running. Example: AzurePublic/AzureGovernment/AzureChina.
|
||||
Cloud string `yaml:"cloud,omitempty"`
|
||||
}
|
||||
|
@ -84,18 +102,47 @@ func (c *AzureADConfig) Validate() error {
|
|||
return fmt.Errorf("must provide a cloud in the Azure AD config")
|
||||
}
|
||||
|
||||
if c.ManagedIdentity == nil {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity in the Azure AD config")
|
||||
if c.ManagedIdentity == nil && c.OAuth == nil {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity or Azure OAuth in the Azure AD config")
|
||||
}
|
||||
|
||||
if c.ManagedIdentity.ClientID == "" {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
|
||||
if c.ManagedIdentity != nil && c.OAuth != nil {
|
||||
return fmt.Errorf("cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config")
|
||||
}
|
||||
|
||||
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the provided Azure Managed Identity client_id provided is invalid")
|
||||
if c.ManagedIdentity != nil {
|
||||
if c.ManagedIdentity.ClientID == "" {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
|
||||
}
|
||||
|
||||
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the provided Azure Managed Identity client_id is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
if c.OAuth != nil {
|
||||
if c.OAuth.ClientID == "" {
|
||||
return fmt.Errorf("must provide an Azure OAuth client_id in the Azure AD config")
|
||||
}
|
||||
if c.OAuth.ClientSecret == "" {
|
||||
return fmt.Errorf("must provide an Azure OAuth client_secret in the Azure AD config")
|
||||
}
|
||||
if c.OAuth.TenantID == "" {
|
||||
return fmt.Errorf("must provide an Azure OAuth tenant_id in the Azure AD config")
|
||||
}
|
||||
|
||||
var err error
|
||||
_, err = uuid.Parse(c.OAuth.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the provided Azure OAuth client_id is invalid")
|
||||
}
|
||||
_, err = regexp.MatchString("^[0-9a-zA-Z-.]+$", c.OAuth.TenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the provided Azure OAuth tenant_id is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -146,21 +193,54 @@ func (rt *azureADRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
|
|||
|
||||
// newTokenCredential returns a TokenCredential of different kinds like Azure Managed Identity and Azure AD application.
|
||||
func newTokenCredential(cfg *AzureADConfig) (azcore.TokenCredential, error) {
|
||||
cred, err := newManagedIdentityTokenCredential(cfg.ManagedIdentity.ClientID)
|
||||
var cred azcore.TokenCredential
|
||||
var err error
|
||||
cloudConfiguration, err := getCloudConfiguration(cfg.Cloud)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientOpts := &azcore.ClientOptions{
|
||||
Cloud: cloudConfiguration,
|
||||
}
|
||||
|
||||
if cfg.ManagedIdentity != nil {
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: cfg.ManagedIdentity.ClientID,
|
||||
}
|
||||
cred, err = newManagedIdentityTokenCredential(clientOpts, managedIdentityConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.OAuth != nil {
|
||||
oAuthConfig := &OAuthConfig{
|
||||
ClientID: cfg.OAuth.ClientID,
|
||||
ClientSecret: cfg.OAuth.ClientSecret,
|
||||
TenantID: cfg.OAuth.TenantID,
|
||||
}
|
||||
cred, err = newOAuthTokenCredential(clientOpts, oAuthConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// newManagedIdentityTokenCredential returns new Managed Identity token credential.
|
||||
func newManagedIdentityTokenCredential(managedIdentityClientID string) (azcore.TokenCredential, error) {
|
||||
clientID := azidentity.ClientID(managedIdentityClientID)
|
||||
opts := &azidentity.ManagedIdentityCredentialOptions{ID: clientID}
|
||||
func newManagedIdentityTokenCredential(clientOpts *azcore.ClientOptions, managedIdentityConfig *ManagedIdentityConfig) (azcore.TokenCredential, error) {
|
||||
clientID := azidentity.ClientID(managedIdentityConfig.ClientID)
|
||||
opts := &azidentity.ManagedIdentityCredentialOptions{ClientOptions: *clientOpts, ID: clientID}
|
||||
return azidentity.NewManagedIdentityCredential(opts)
|
||||
}
|
||||
|
||||
// newOAuthTokenCredential returns new OAuth token credential
|
||||
func newOAuthTokenCredential(clientOpts *azcore.ClientOptions, oAuthConfig *OAuthConfig) (azcore.TokenCredential, error) {
|
||||
opts := &azidentity.ClientSecretCredentialOptions{ClientOptions: *clientOpts}
|
||||
return azidentity.NewClientSecretCredential(oAuthConfig.TenantID, oAuthConfig.ClientID, oAuthConfig.ClientSecret, opts)
|
||||
}
|
||||
|
||||
// newTokenProvider helps to fetch accessToken for different types of credential. This also takes care of
|
||||
// refreshing the accessToken before expiry. This accessToken is attached to the Authorization header while making requests.
|
||||
func newTokenProvider(cfg *AzureADConfig, cred azcore.TokenCredential) (*tokenProvider, error) {
|
||||
|
@ -245,3 +325,17 @@ func getAudience(cloud string) (string, error) {
|
|||
return "", errors.New("Cloud is not specified or is incorrect: " + cloud)
|
||||
}
|
||||
}
|
||||
|
||||
// getCloudConfiguration returns the cloud Configuration which contains AAD endpoint for different clouds
|
||||
func getCloudConfiguration(c string) (cloud.Configuration, error) {
|
||||
switch strings.ToLower(c) {
|
||||
case strings.ToLower(AzureChina):
|
||||
return cloud.AzureChina, nil
|
||||
case strings.ToLower(AzureGovernment):
|
||||
return cloud.AzureGovernment, nil
|
||||
case strings.ToLower(AzurePublic):
|
||||
return cloud.AzurePublic, nil
|
||||
default:
|
||||
return cloud.Configuration{}, errors.New("Cloud is not specified or is incorrect: " + c)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,17 +26,20 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
dummyAudience = "dummyAudience"
|
||||
dummyClientID = "00000000-0000-0000-0000-000000000000"
|
||||
testTokenString = "testTokenString"
|
||||
dummyAudience = "dummyAudience"
|
||||
dummyClientID = "00000000-0000-0000-0000-000000000000"
|
||||
dummyClientSecret = "Cl1ent$ecret!"
|
||||
dummyTenantID = "00000000-a12b-3cd4-e56f-000000000000"
|
||||
testTokenString = "testTokenString"
|
||||
)
|
||||
|
||||
var testTokenExpiry = time.Now().Add(10 * time.Second)
|
||||
var testTokenExpiry = time.Now().Add(5 * time.Second)
|
||||
|
||||
type AzureAdTestSuite struct {
|
||||
suite.Suite
|
||||
|
@ -62,47 +65,64 @@ func TestAzureAd(t *testing.T) {
|
|||
}
|
||||
|
||||
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() {
|
||||
var gotReq *http.Request
|
||||
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
cases := []struct {
|
||||
cfg *AzureADConfig
|
||||
}{
|
||||
// AzureAd roundtripper with Managedidentity.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
},
|
||||
},
|
||||
},
|
||||
// AzureAd roundtripper with OAuth.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
OAuth: &OAuthConfig{
|
||||
ClientID: dummyClientID,
|
||||
ClientSecret: dummyClientSecret,
|
||||
TenantID: dummyTenantID,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
var gotReq *http.Request
|
||||
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
}
|
||||
|
||||
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
|
||||
|
||||
tokenProvider, err := newTokenProvider(c.cfg, ad.mockCredential)
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
rt := &azureADRoundTripper{
|
||||
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotReq = req
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}),
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
|
||||
cli := &http.Client{Transport: rt}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
_, err = cli.Do(req)
|
||||
ad.Assert().NoError(err)
|
||||
ad.Assert().NotNil(gotReq)
|
||||
|
||||
origReq := gotReq
|
||||
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
|
||||
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
|
||||
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
|
||||
|
||||
tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential)
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
rt := &azureADRoundTripper{
|
||||
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotReq = req
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}),
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
|
||||
cli := &http.Client{Transport: rt}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
_, err = cli.Do(req)
|
||||
ad.Assert().NoError(err)
|
||||
ad.Assert().NotNil(gotReq)
|
||||
|
||||
origReq := gotReq
|
||||
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
|
||||
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
|
||||
|
@ -117,42 +137,54 @@ func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
|
|||
return &cfg, nil
|
||||
}
|
||||
|
||||
func testGoodConfig(t *testing.T, filename string) {
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error parsing %s: %s", filename, err)
|
||||
func TestAzureAdConfig(t *testing.T) {
|
||||
cases := []struct {
|
||||
filename string
|
||||
err string
|
||||
}{
|
||||
// Missing managedidentiy or oauth field.
|
||||
{
|
||||
filename: "testdata/azuread_bad_configmissing.yaml",
|
||||
err: "must provide an Azure Managed Identity or Azure OAuth in the Azure AD config",
|
||||
},
|
||||
// Invalid managedidentity client id.
|
||||
{
|
||||
filename: "testdata/azuread_bad_invalidclientid.yaml",
|
||||
err: "the provided Azure Managed Identity client_id is invalid",
|
||||
},
|
||||
// Missing tenant id in oauth config.
|
||||
{
|
||||
filename: "testdata/azuread_bad_invalidoauthconfig.yaml",
|
||||
err: "must provide an Azure OAuth tenant_id in the Azure AD config",
|
||||
},
|
||||
// Invalid config when both managedidentity and oauth is provided.
|
||||
{
|
||||
filename: "testdata/azuread_bad_twoconfig.yaml",
|
||||
err: "cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config",
|
||||
},
|
||||
// Valid config with missing optionally cloud field.
|
||||
{
|
||||
filename: "testdata/azuread_good_cloudmissing.yaml",
|
||||
},
|
||||
// Valid managed identity config.
|
||||
{
|
||||
filename: "testdata/azuread_good_managedidentity.yaml",
|
||||
},
|
||||
// Valid Oauth config.
|
||||
{
|
||||
filename: "testdata/azuread_good_oauth.yaml",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoodAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_good.yaml"
|
||||
testGoodConfig(t, filename)
|
||||
}
|
||||
|
||||
func TestGoodCloudMissingAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_good_cloudmissing.yaml"
|
||||
testGoodConfig(t, filename)
|
||||
}
|
||||
|
||||
func TestBadClientIdMissingAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_bad_clientidmissing.yaml"
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") {
|
||||
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadInvalidClientIdAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_bad_invalidclientid.yaml"
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") {
|
||||
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
|
||||
for _, c := range cases {
|
||||
_, err := loadAzureAdConfig(c.filename)
|
||||
if c.err != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
||||
}
|
||||
require.EqualError(t, err, c.err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,75 +205,90 @@ func TestTokenProvider(t *testing.T) {
|
|||
suite.Run(t, new(TokenProviderTestSuite))
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() {
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
func (s *TokenProviderTestSuite) TestNewTokenProvider() {
|
||||
cases := []struct {
|
||||
cfg *AzureADConfig
|
||||
err string
|
||||
}{
|
||||
// Invalid tokenProvider for managedidentity.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "PublicAzure",
|
||||
ManagedIdentity: &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
},
|
||||
},
|
||||
err: "Cloud is not specified or is incorrect: ",
|
||||
},
|
||||
// Invalid tokenProvider for oauth.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "PublicAzure",
|
||||
OAuth: &OAuthConfig{
|
||||
ClientID: dummyClientID,
|
||||
ClientSecret: dummyClientSecret,
|
||||
TenantID: dummyTenantID,
|
||||
},
|
||||
},
|
||||
err: "Cloud is not specified or is incorrect: ",
|
||||
},
|
||||
// Valid tokenProvider for managedidentity.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
},
|
||||
},
|
||||
},
|
||||
// Valid tokenProvider for oauth.
|
||||
{
|
||||
cfg: &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
OAuth: &OAuthConfig{
|
||||
ClientID: dummyClientID,
|
||||
ClientSecret: dummyClientSecret,
|
||||
TenantID: dummyTenantID,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mockGetTokenCallCounter := 1
|
||||
for _, c := range cases {
|
||||
if c.err != "" {
|
||||
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "PublicAzure",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
s.Assert().Nil(actualTokenProvider)
|
||||
s.Assert().NotNil(actualErr)
|
||||
s.Assert().ErrorContains(actualErr, c.err)
|
||||
} else {
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
}
|
||||
|
||||
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
|
||||
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider)
|
||||
s.Assert().Nil(actualErr)
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 5s.
|
||||
// Hence, the 4 seconds wait to check if the token is refreshed.
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2*mockGetTokenCallCounter)
|
||||
mockGetTokenCallCounter += 1
|
||||
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
|
||||
s.Assert().Nil(err)
|
||||
s.Assert().NotEqual(accessToken, testTokenString)
|
||||
}
|
||||
}
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().Nil(actualTokenProvider)
|
||||
s.Assert().NotNil(actualErr)
|
||||
s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error())
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() {
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider)
|
||||
s.Assert().Nil(actualErr)
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() {
|
||||
// setup
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
}
|
||||
|
||||
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
|
||||
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider)
|
||||
s.Assert().Nil(actualErr)
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s.
|
||||
// Hence, the 6 seconds wait to check if the token is refreshed.
|
||||
time.Sleep(6 * time.Second)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2)
|
||||
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
|
||||
s.Assert().Nil(err)
|
||||
s.Assert().NotEqual(accessToken, testTokenString)
|
||||
}
|
||||
|
||||
func getToken() azcore.AccessToken {
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
cloud: AzurePublic
|
||||
oauth:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
||||
client_secret: Cl1ent$ecret!
|
|
@ -0,0 +1,7 @@
|
|||
cloud: AzurePublic
|
||||
managed_identity:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
||||
oauth:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
||||
client_secret: Cl1ent$ecret!
|
||||
tenant_id: 00000000-a12b-3cd4-e56f-000000000000
|
|
@ -0,0 +1,5 @@
|
|||
cloud: AzurePublic
|
||||
oauth:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
||||
client_secret: Cl1ent$ecret!
|
||||
tenant_id: 00000000-a12b-3cd4-e56f-000000000000
|
Loading…
Reference in New Issue