|
|
|
@ -26,17 +26,20 @@ import (
|
|
|
|
|
"github.com/google/uuid" |
|
|
|
|
"github.com/prometheus/client_golang/prometheus/promhttp" |
|
|
|
|
"github.com/stretchr/testify/mock" |
|
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
|
"github.com/stretchr/testify/suite" |
|
|
|
|
"gopkg.in/yaml.v2" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
const ( |
|
|
|
|
dummyAudience = "dummyAudience" |
|
|
|
|
dummyClientID = "00000000-0000-0000-0000-000000000000" |
|
|
|
|
testTokenString = "testTokenString" |
|
|
|
|
dummyAudience = "dummyAudience" |
|
|
|
|
dummyClientID = "00000000-0000-0000-0000-000000000000" |
|
|
|
|
dummyClientSecret = "Cl1ent$ecret!" |
|
|
|
|
dummyTenantID = "00000000-a12b-3cd4-e56f-000000000000" |
|
|
|
|
testTokenString = "testTokenString" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var testTokenExpiry = time.Now().Add(10 * time.Second) |
|
|
|
|
var testTokenExpiry = time.Now().Add(5 * time.Second) |
|
|
|
|
|
|
|
|
|
type AzureAdTestSuite struct { |
|
|
|
|
suite.Suite |
|
|
|
@ -62,47 +65,64 @@ func TestAzureAd(t *testing.T) {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() { |
|
|
|
|
var gotReq *http.Request |
|
|
|
|
|
|
|
|
|
testToken := &azcore.AccessToken{ |
|
|
|
|
Token: testTokenString, |
|
|
|
|
ExpiresOn: testTokenExpiry, |
|
|
|
|
cases := []struct { |
|
|
|
|
cfg *AzureADConfig |
|
|
|
|
}{ |
|
|
|
|
// AzureAd roundtripper with Managedidentity.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
ManagedIdentity: &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
// AzureAd roundtripper with OAuth.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
OAuth: &OAuthConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
ClientSecret: dummyClientSecret, |
|
|
|
|
TenantID: dummyTenantID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
for _, c := range cases { |
|
|
|
|
var gotReq *http.Request |
|
|
|
|
|
|
|
|
|
managedIdentityConfig := &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
} |
|
|
|
|
testToken := &azcore.AccessToken{ |
|
|
|
|
Token: testTokenString, |
|
|
|
|
ExpiresOn: testTokenExpiry, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
azureAdConfig := &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
ManagedIdentity: managedIdentityConfig, |
|
|
|
|
} |
|
|
|
|
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil) |
|
|
|
|
|
|
|
|
|
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil) |
|
|
|
|
tokenProvider, err := newTokenProvider(c.cfg, ad.mockCredential) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
|
|
|
|
|
tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
|
|
|
|
|
rt := &azureADRoundTripper{ |
|
|
|
|
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
|
|
|
|
gotReq = req |
|
|
|
|
return &http.Response{StatusCode: http.StatusOK}, nil |
|
|
|
|
}), |
|
|
|
|
tokenProvider: tokenProvider, |
|
|
|
|
} |
|
|
|
|
rt := &azureADRoundTripper{ |
|
|
|
|
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
|
|
|
|
gotReq = req |
|
|
|
|
return &http.Response{StatusCode: http.StatusOK}, nil |
|
|
|
|
}), |
|
|
|
|
tokenProvider: tokenProvider, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
cli := &http.Client{Transport: rt} |
|
|
|
|
cli := &http.Client{Transport: rt} |
|
|
|
|
|
|
|
|
|
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
|
|
|
|
|
_, err = cli.Do(req) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
ad.Assert().NotNil(gotReq) |
|
|
|
|
_, err = cli.Do(req) |
|
|
|
|
ad.Assert().NoError(err) |
|
|
|
|
ad.Assert().NotNil(gotReq) |
|
|
|
|
|
|
|
|
|
origReq := gotReq |
|
|
|
|
ad.Assert().NotEmpty(origReq.Header.Get("Authorization")) |
|
|
|
|
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization")) |
|
|
|
|
origReq := gotReq |
|
|
|
|
ad.Assert().NotEmpty(origReq.Header.Get("Authorization")) |
|
|
|
|
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization")) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loadAzureAdConfig(filename string) (*AzureADConfig, error) { |
|
|
|
@ -117,42 +137,54 @@ func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
|
|
|
|
|
return &cfg, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func testGoodConfig(t *testing.T, filename string) { |
|
|
|
|
_, err := loadAzureAdConfig(filename) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatalf("Unexpected error parsing %s: %s", filename, err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestGoodAzureAdConfig(t *testing.T) { |
|
|
|
|
filename := "testdata/azuread_good.yaml" |
|
|
|
|
testGoodConfig(t, filename) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestGoodCloudMissingAzureAdConfig(t *testing.T) { |
|
|
|
|
filename := "testdata/azuread_good_cloudmissing.yaml" |
|
|
|
|
testGoodConfig(t, filename) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestBadClientIdMissingAzureAdConfig(t *testing.T) { |
|
|
|
|
filename := "testdata/azuread_bad_clientidmissing.yaml" |
|
|
|
|
_, err := loadAzureAdConfig(filename) |
|
|
|
|
if err == nil { |
|
|
|
|
t.Fatalf("Did not receive expected error unmarshaling bad azuread config") |
|
|
|
|
} |
|
|
|
|
if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") { |
|
|
|
|
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error()) |
|
|
|
|
func TestAzureAdConfig(t *testing.T) { |
|
|
|
|
cases := []struct { |
|
|
|
|
filename string |
|
|
|
|
err string |
|
|
|
|
}{ |
|
|
|
|
// Missing managedidentiy or oauth field.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_bad_configmissing.yaml", |
|
|
|
|
err: "must provide an Azure Managed Identity or Azure OAuth in the Azure AD config", |
|
|
|
|
}, |
|
|
|
|
// Invalid managedidentity client id.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_bad_invalidclientid.yaml", |
|
|
|
|
err: "the provided Azure Managed Identity client_id is invalid", |
|
|
|
|
}, |
|
|
|
|
// Missing tenant id in oauth config.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_bad_invalidoauthconfig.yaml", |
|
|
|
|
err: "must provide an Azure OAuth tenant_id in the Azure AD config", |
|
|
|
|
}, |
|
|
|
|
// Invalid config when both managedidentity and oauth is provided.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_bad_twoconfig.yaml", |
|
|
|
|
err: "cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config", |
|
|
|
|
}, |
|
|
|
|
// Valid config with missing optionally cloud field.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_good_cloudmissing.yaml", |
|
|
|
|
}, |
|
|
|
|
// Valid managed identity config.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_good_managedidentity.yaml", |
|
|
|
|
}, |
|
|
|
|
// Valid Oauth config.
|
|
|
|
|
{ |
|
|
|
|
filename: "testdata/azuread_good_oauth.yaml", |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestBadInvalidClientIdAzureAdConfig(t *testing.T) { |
|
|
|
|
filename := "testdata/azuread_bad_invalidclientid.yaml" |
|
|
|
|
_, err := loadAzureAdConfig(filename) |
|
|
|
|
if err == nil { |
|
|
|
|
t.Fatalf("Did not receive expected error unmarshaling bad azuread config") |
|
|
|
|
} |
|
|
|
|
if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") { |
|
|
|
|
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error()) |
|
|
|
|
for _, c := range cases { |
|
|
|
|
_, err := loadAzureAdConfig(c.filename) |
|
|
|
|
if c.err != "" { |
|
|
|
|
if err == nil { |
|
|
|
|
t.Fatalf("Did not receive expected error unmarshaling bad azuread config") |
|
|
|
|
} |
|
|
|
|
require.EqualError(t, err, c.err) |
|
|
|
|
} else { |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -173,75 +205,90 @@ func TestTokenProvider(t *testing.T) {
|
|
|
|
|
suite.Run(t, new(TokenProviderTestSuite)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() { |
|
|
|
|
managedIdentityConfig := &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
azureAdConfig := &AzureADConfig{ |
|
|
|
|
Cloud: "PublicAzure", |
|
|
|
|
ManagedIdentity: managedIdentityConfig, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) |
|
|
|
|
|
|
|
|
|
s.Assert().Nil(actualTokenProvider) |
|
|
|
|
s.Assert().NotNil(actualErr) |
|
|
|
|
s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() { |
|
|
|
|
managedIdentityConfig := &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
azureAdConfig := &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
ManagedIdentity: managedIdentityConfig, |
|
|
|
|
} |
|
|
|
|
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) |
|
|
|
|
|
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) |
|
|
|
|
|
|
|
|
|
s.Assert().NotNil(actualTokenProvider) |
|
|
|
|
s.Assert().Nil(actualErr) |
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() { |
|
|
|
|
// setup
|
|
|
|
|
managedIdentityConfig := &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
func (s *TokenProviderTestSuite) TestNewTokenProvider() { |
|
|
|
|
cases := []struct { |
|
|
|
|
cfg *AzureADConfig |
|
|
|
|
err string |
|
|
|
|
}{ |
|
|
|
|
// Invalid tokenProvider for managedidentity.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "PublicAzure", |
|
|
|
|
ManagedIdentity: &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
err: "Cloud is not specified or is incorrect: ", |
|
|
|
|
}, |
|
|
|
|
// Invalid tokenProvider for oauth.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "PublicAzure", |
|
|
|
|
OAuth: &OAuthConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
ClientSecret: dummyClientSecret, |
|
|
|
|
TenantID: dummyTenantID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
err: "Cloud is not specified or is incorrect: ", |
|
|
|
|
}, |
|
|
|
|
// Valid tokenProvider for managedidentity.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
ManagedIdentity: &ManagedIdentityConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
// Valid tokenProvider for oauth.
|
|
|
|
|
{ |
|
|
|
|
cfg: &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
OAuth: &OAuthConfig{ |
|
|
|
|
ClientID: dummyClientID, |
|
|
|
|
ClientSecret: dummyClientSecret, |
|
|
|
|
TenantID: dummyTenantID, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
azureAdConfig := &AzureADConfig{ |
|
|
|
|
Cloud: "AzurePublic", |
|
|
|
|
ManagedIdentity: managedIdentityConfig, |
|
|
|
|
mockGetTokenCallCounter := 1 |
|
|
|
|
for _, c := range cases { |
|
|
|
|
if c.err != "" { |
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential) |
|
|
|
|
|
|
|
|
|
s.Assert().Nil(actualTokenProvider) |
|
|
|
|
s.Assert().NotNil(actualErr) |
|
|
|
|
s.Assert().ErrorContains(actualErr, c.err) |
|
|
|
|
} else { |
|
|
|
|
testToken := &azcore.AccessToken{ |
|
|
|
|
Token: testTokenString, |
|
|
|
|
ExpiresOn: testTokenExpiry, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once(). |
|
|
|
|
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) |
|
|
|
|
|
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential) |
|
|
|
|
|
|
|
|
|
s.Assert().NotNil(actualTokenProvider) |
|
|
|
|
s.Assert().Nil(actualErr) |
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) |
|
|
|
|
|
|
|
|
|
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 5s.
|
|
|
|
|
// Hence, the 4 seconds wait to check if the token is refreshed.
|
|
|
|
|
time.Sleep(4 * time.Second) |
|
|
|
|
|
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) |
|
|
|
|
|
|
|
|
|
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2*mockGetTokenCallCounter) |
|
|
|
|
mockGetTokenCallCounter += 1 |
|
|
|
|
accessToken, err := actualTokenProvider.getAccessToken(context.Background()) |
|
|
|
|
s.Assert().Nil(err) |
|
|
|
|
s.Assert().NotEqual(accessToken, testTokenString) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
testToken := &azcore.AccessToken{ |
|
|
|
|
Token: testTokenString, |
|
|
|
|
ExpiresOn: testTokenExpiry, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once(). |
|
|
|
|
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) |
|
|
|
|
|
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) |
|
|
|
|
|
|
|
|
|
s.Assert().NotNil(actualTokenProvider) |
|
|
|
|
s.Assert().Nil(actualErr) |
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) |
|
|
|
|
|
|
|
|
|
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s.
|
|
|
|
|
// Hence, the 6 seconds wait to check if the token is refreshed.
|
|
|
|
|
time.Sleep(6 * time.Second) |
|
|
|
|
|
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) |
|
|
|
|
|
|
|
|
|
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2) |
|
|
|
|
accessToken, err := actualTokenProvider.getAccessToken(context.Background()) |
|
|
|
|
s.Assert().Nil(err) |
|
|
|
|
s.Assert().NotEqual(accessToken, testTokenString) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getToken() azcore.AccessToken { |
|
|
|
|