mirror of https://github.com/prometheus/prometheus
300 lines
7.9 KiB
Go
300 lines
7.9 KiB
Go
// Copyright 2023 The Prometheus Authors
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package azuread
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
"github.com/google/uuid"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
"gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// Fixed identifiers, secrets and token values shared by the tests in this file.
const (
	dummyTenantID     = "00000000-a12b-3cd4-e56f-000000000000"
	dummyClientID     = "00000000-0000-0000-0000-000000000000"
	dummyClientSecret = "Cl1ent$ecret!"
	dummyAudience     = "dummyAudience"

	// testTokenString is the access-token value handed out by the mocked
	// credential for the initial token in the tests below.
	testTokenString = "testTokenString"
)

// testTokenExpiry is evaluated once, when the package is initialised, as five
// seconds after that moment. Tests running later in the same process see an
// expiry that is correspondingly closer to (or already in) the past.
var testTokenExpiry = time.Now().Add(time.Second * 5)
|
|
|
|
type AzureAdTestSuite struct {
|
|
suite.Suite
|
|
mockCredential *mockCredential
|
|
}
|
|
|
|
type TokenProviderTestSuite struct {
|
|
suite.Suite
|
|
mockCredential *mockCredential
|
|
}
|
|
|
|
// mockCredential mocks azidentity TokenCredential interface.
|
|
type mockCredential struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (ad *AzureAdTestSuite) BeforeTest(_, _ string) {
|
|
ad.mockCredential = new(mockCredential)
|
|
}
|
|
|
|
func TestAzureAd(t *testing.T) {
|
|
suite.Run(t, new(AzureAdTestSuite))
|
|
}
|
|
|
|
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() {
|
|
cases := []struct {
|
|
cfg *AzureADConfig
|
|
}{
|
|
// AzureAd roundtripper with Managedidentity.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "AzurePublic",
|
|
ManagedIdentity: &ManagedIdentityConfig{
|
|
ClientID: dummyClientID,
|
|
},
|
|
},
|
|
},
|
|
// AzureAd roundtripper with OAuth.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "AzurePublic",
|
|
OAuth: &OAuthConfig{
|
|
ClientID: dummyClientID,
|
|
ClientSecret: dummyClientSecret,
|
|
TenantID: dummyTenantID,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, c := range cases {
|
|
var gotReq *http.Request
|
|
|
|
testToken := &azcore.AccessToken{
|
|
Token: testTokenString,
|
|
ExpiresOn: testTokenExpiry,
|
|
}
|
|
|
|
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
|
|
|
|
tokenProvider, err := newTokenProvider(c.cfg, ad.mockCredential)
|
|
ad.Assert().NoError(err)
|
|
|
|
rt := &azureADRoundTripper{
|
|
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
gotReq = req
|
|
return &http.Response{StatusCode: http.StatusOK}, nil
|
|
}),
|
|
tokenProvider: tokenProvider,
|
|
}
|
|
|
|
cli := &http.Client{Transport: rt}
|
|
|
|
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
|
|
ad.Assert().NoError(err)
|
|
|
|
_, err = cli.Do(req)
|
|
ad.Assert().NoError(err)
|
|
ad.Assert().NotNil(gotReq)
|
|
|
|
origReq := gotReq
|
|
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
|
|
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
|
|
}
|
|
}
|
|
|
|
func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
|
|
content, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg := AzureADConfig{}
|
|
if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func TestAzureAdConfig(t *testing.T) {
|
|
cases := []struct {
|
|
filename string
|
|
err string
|
|
}{
|
|
// Missing managedidentiy or oauth field.
|
|
{
|
|
filename: "testdata/azuread_bad_configmissing.yaml",
|
|
err: "must provide an Azure Managed Identity or Azure OAuth in the Azure AD config",
|
|
},
|
|
// Invalid managedidentity client id.
|
|
{
|
|
filename: "testdata/azuread_bad_invalidclientid.yaml",
|
|
err: "the provided Azure Managed Identity client_id is invalid",
|
|
},
|
|
// Missing tenant id in oauth config.
|
|
{
|
|
filename: "testdata/azuread_bad_invalidoauthconfig.yaml",
|
|
err: "must provide an Azure OAuth tenant_id in the Azure AD config",
|
|
},
|
|
// Invalid config when both managedidentity and oauth is provided.
|
|
{
|
|
filename: "testdata/azuread_bad_twoconfig.yaml",
|
|
err: "cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config",
|
|
},
|
|
// Valid config with missing optionally cloud field.
|
|
{
|
|
filename: "testdata/azuread_good_cloudmissing.yaml",
|
|
},
|
|
// Valid managed identity config.
|
|
{
|
|
filename: "testdata/azuread_good_managedidentity.yaml",
|
|
},
|
|
// Valid Oauth config.
|
|
{
|
|
filename: "testdata/azuread_good_oauth.yaml",
|
|
},
|
|
}
|
|
for _, c := range cases {
|
|
_, err := loadAzureAdConfig(c.filename)
|
|
if c.err != "" {
|
|
if err == nil {
|
|
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
|
}
|
|
require.EqualError(t, err, c.err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
|
|
args := m.MethodCalled("GetToken", ctx, options)
|
|
if args.Get(0) == nil {
|
|
return azcore.AccessToken{}, args.Error(1)
|
|
}
|
|
|
|
return args.Get(0).(azcore.AccessToken), nil
|
|
}
|
|
|
|
func (s *TokenProviderTestSuite) BeforeTest(_, _ string) {
|
|
s.mockCredential = new(mockCredential)
|
|
}
|
|
|
|
func TestTokenProvider(t *testing.T) {
|
|
suite.Run(t, new(TokenProviderTestSuite))
|
|
}
|
|
|
|
func (s *TokenProviderTestSuite) TestNewTokenProvider() {
|
|
cases := []struct {
|
|
cfg *AzureADConfig
|
|
err string
|
|
}{
|
|
// Invalid tokenProvider for managedidentity.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "PublicAzure",
|
|
ManagedIdentity: &ManagedIdentityConfig{
|
|
ClientID: dummyClientID,
|
|
},
|
|
},
|
|
err: "Cloud is not specified or is incorrect: ",
|
|
},
|
|
// Invalid tokenProvider for oauth.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "PublicAzure",
|
|
OAuth: &OAuthConfig{
|
|
ClientID: dummyClientID,
|
|
ClientSecret: dummyClientSecret,
|
|
TenantID: dummyTenantID,
|
|
},
|
|
},
|
|
err: "Cloud is not specified or is incorrect: ",
|
|
},
|
|
// Valid tokenProvider for managedidentity.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "AzurePublic",
|
|
ManagedIdentity: &ManagedIdentityConfig{
|
|
ClientID: dummyClientID,
|
|
},
|
|
},
|
|
},
|
|
// Valid tokenProvider for oauth.
|
|
{
|
|
cfg: &AzureADConfig{
|
|
Cloud: "AzurePublic",
|
|
OAuth: &OAuthConfig{
|
|
ClientID: dummyClientID,
|
|
ClientSecret: dummyClientSecret,
|
|
TenantID: dummyTenantID,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
mockGetTokenCallCounter := 1
|
|
for _, c := range cases {
|
|
if c.err != "" {
|
|
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
|
|
|
|
s.Assert().Nil(actualTokenProvider)
|
|
s.Assert().NotNil(actualErr)
|
|
s.Assert().ErrorContains(actualErr, c.err)
|
|
} else {
|
|
testToken := &azcore.AccessToken{
|
|
Token: testTokenString,
|
|
ExpiresOn: testTokenExpiry,
|
|
}
|
|
|
|
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
|
|
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
|
|
|
actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential)
|
|
|
|
s.Assert().NotNil(actualTokenProvider)
|
|
s.Assert().Nil(actualErr)
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
|
|
|
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 5s.
|
|
// Hence, the 4 seconds wait to check if the token is refreshed.
|
|
time.Sleep(4 * time.Second)
|
|
|
|
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
|
|
|
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2*mockGetTokenCallCounter)
|
|
mockGetTokenCallCounter += 1
|
|
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
|
|
s.Assert().Nil(err)
|
|
s.Assert().NotEqual(accessToken, testTokenString)
|
|
}
|
|
}
|
|
}
|
|
|
|
func getToken() azcore.AccessToken {
|
|
return azcore.AccessToken{
|
|
Token: uuid.New().String(),
|
|
ExpiresOn: time.Now().Add(10 * time.Second),
|
|
}
|
|
}
|