mirror of https://github.com/prometheus/prometheus
Add Azure AD package for remote write (#11944)
* Add Azure AD package for remote write * Made AzurePublic default and updated configuration.md * Updated config structure and removed getToken at initialization * Changed passing context from request Signed-off-by: Rakshith Padmanabha <rapadman@microsoft.com> Signed-off-by: rakshith210 <rakshith.me@gmail.com>pull/12420/head
parent
a8772a4178
commit
b1675e23af
|
@ -34,6 +34,7 @@ import (
|
|||
"github.com/prometheus/prometheus/discovery"
|
||||
"github.com/prometheus/prometheus/model/labels"
|
||||
"github.com/prometheus/prometheus/model/relabel"
|
||||
"github.com/prometheus/prometheus/storage/remote/azuread"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -907,6 +908,7 @@ type RemoteWriteConfig struct {
|
|||
QueueConfig QueueConfig `yaml:"queue_config,omitempty"`
|
||||
MetadataConfig MetadataConfig `yaml:"metadata_config,omitempty"`
|
||||
SigV4Config *sigv4.SigV4Config `yaml:"sigv4,omitempty"`
|
||||
AzureADConfig *azuread.AzureADConfig `yaml:"azuread,omitempty"`
|
||||
}
|
||||
|
||||
// SetDirectory joins any relative file paths with dir.
|
||||
|
@ -943,8 +945,12 @@ func (c *RemoteWriteConfig) UnmarshalYAML(unmarshal func(interface{}) error) err
|
|||
httpClientConfigAuthEnabled := c.HTTPClientConfig.BasicAuth != nil ||
|
||||
c.HTTPClientConfig.Authorization != nil || c.HTTPClientConfig.OAuth2 != nil
|
||||
|
||||
if httpClientConfigAuthEnabled && c.SigV4Config != nil {
|
||||
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, & sigv4 must be configured")
|
||||
if httpClientConfigAuthEnabled && (c.SigV4Config != nil || c.AzureADConfig != nil) {
|
||||
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured")
|
||||
}
|
||||
|
||||
if c.SigV4Config != nil && c.AzureADConfig != nil {
|
||||
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -965,7 +971,7 @@ func validateHeadersForTracing(headers map[string]string) error {
|
|||
func validateHeaders(headers map[string]string) error {
|
||||
for header := range headers {
|
||||
if strings.ToLower(header) == "authorization" {
|
||||
return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter")
|
||||
return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter")
|
||||
}
|
||||
if _, ok := reservedHeaders[strings.ToLower(header)]; ok {
|
||||
return fmt.Errorf("%s is a reserved header. It must not be changed", header)
|
||||
|
|
|
@ -1727,7 +1727,7 @@ var expectedErrors = []struct {
|
|||
},
|
||||
{
|
||||
filename: "remote_write_authorization_header.bad.yml",
|
||||
errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter`,
|
||||
errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter`,
|
||||
},
|
||||
{
|
||||
filename: "remote_write_url_missing.bad.yml",
|
||||
|
|
|
@ -3466,7 +3466,7 @@ authorization:
|
|||
[ credentials_file: <filename> ]
|
||||
|
||||
# Optionally configures AWS's Signature Verification 4 signing process to
|
||||
# sign requests. Cannot be set at the same time as basic_auth, authorization, or oauth2.
|
||||
# sign requests. Cannot be set at the same time as basic_auth, authorization, oauth2, or azuread.
|
||||
# To use the default credentials from the AWS SDK, use `sigv4: {}`.
|
||||
sigv4:
|
||||
# The AWS region. If blank, the region from the default credentials chain
|
||||
|
@ -3485,10 +3485,20 @@ sigv4:
|
|||
[ role_arn: <string> ]
|
||||
|
||||
# Optional OAuth 2.0 configuration.
|
||||
# Cannot be used at the same time as basic_auth, authorization, or sigv4.
|
||||
# Cannot be used at the same time as basic_auth, authorization, sigv4, or azuread.
|
||||
oauth2:
|
||||
[ <oauth2> ]
|
||||
|
||||
# Optional AzureAD configuration.
|
||||
# Cannot be used at the same time as basic_auth, authorization, oauth2, or sigv4.
|
||||
azuread:
|
||||
# The Azure Cloud. Options are 'AzurePublic', 'AzureChina', or 'AzureGovernment'.
|
||||
[ cloud: <string> | default = AzurePublic ]
|
||||
|
||||
# Azure User-assigned Managed identity.
|
||||
[ managed_identity:
|
||||
[ client_id: <string> ]
|
||||
|
||||
# Configures the remote write request's TLS settings.
|
||||
tls_config:
|
||||
[ <tls_config> ]
|
||||
|
|
9
go.mod
9
go.mod
|
@ -4,6 +4,8 @@ go 1.19
|
|||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1
|
||||
github.com/Azure/go-autorest/autorest v0.11.28
|
||||
github.com/Azure/go-autorest/autorest/adal v0.9.23
|
||||
github.com/alecthomas/kingpin/v2 v2.3.2
|
||||
|
@ -83,10 +85,15 @@ require (
|
|||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.2.3 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 // indirect
|
||||
github.com/rogpeppe/go-internal v1.10.0 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230526203410-71b5a4ffd15e // indirect
|
||||
|
@ -135,7 +142,7 @@ require (
|
|||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.7.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
|
|
12
go.sum
12
go.sum
|
@ -38,6 +38,12 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
|
|||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible h1:HzKLt3kIwMm4KeJYTdx9EbjRYTySD/t8i1Ee/W5EGXw=
|
||||
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 h1:gVXuXcWd1i4C2Ruxe321aU+IKGaStvGB/S90PUPB/W8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
|
||||
|
@ -60,6 +66,8 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z
|
|||
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
|
||||
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
|
||||
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 h1:oPdPEZFSbl7oSPEAIPMPBMUmiL+mqgzBJwM/9qYcwNg=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1/go.mod h1:4qFor3D/HDsvBME35Xy9rwW9DecL+M2sNw1ybjPtwA0=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||
|
@ -515,6 +523,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
|
||||
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
|
||||
github.com/linode/linodego v1.16.1 h1:5otq57M4PdHycPERRfSFZ0s1yz1ETVWGjCp3hh7+F9w=
|
||||
|
@ -630,6 +640,8 @@ github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAv
|
|||
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
|
||||
github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI=
|
||||
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
azuread package
|
||||
=========================================
|
||||
|
||||
azuread provides an http.RoundTripper that attaches an Azure AD accessToken
|
||||
to remote write requests.
|
||||
|
||||
This module is considered internal to Prometheus, without any stability
|
||||
guarantees for external usage.
|
|
@ -0,0 +1,247 @@
|
|||
// Copyright 2023 The Prometheus Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// Clouds.
|
||||
AzureChina = "AzureChina"
|
||||
AzureGovernment = "AzureGovernment"
|
||||
AzurePublic = "AzurePublic"
|
||||
|
||||
// Audiences.
|
||||
IngestionChinaAudience = "https://monitor.azure.cn//.default"
|
||||
IngestionGovernmentAudience = "https://monitor.azure.us//.default"
|
||||
IngestionPublicAudience = "https://monitor.azure.com//.default"
|
||||
)
|
||||
|
||||
// ManagedIdentityConfig is used to store managed identity config values
|
||||
type ManagedIdentityConfig struct {
|
||||
// ClientID is the clientId of the managed identity that is being used to authenticate.
|
||||
ClientID string `yaml:"client_id,omitempty"`
|
||||
}
|
||||
|
||||
// AzureADConfig is used to store the config values.
|
||||
type AzureADConfig struct { // nolint:revive
|
||||
// ManagedIdentity is the managed identity that is being used to authenticate.
|
||||
ManagedIdentity *ManagedIdentityConfig `yaml:"managed_identity,omitempty"`
|
||||
|
||||
// Cloud is the Azure cloud in which the service is running. Example: AzurePublic/AzureGovernment/AzureChina.
|
||||
Cloud string `yaml:"cloud,omitempty"`
|
||||
}
|
||||
|
||||
// azureADRoundTripper is used to store the roundtripper and the tokenprovider.
|
||||
type azureADRoundTripper struct {
|
||||
next http.RoundTripper
|
||||
tokenProvider *tokenProvider
|
||||
}
|
||||
|
||||
// tokenProvider is used to store and retrieve Azure AD accessToken.
|
||||
type tokenProvider struct {
|
||||
// token is member used to store the current valid accessToken.
|
||||
token string
|
||||
// mu guards access to token.
|
||||
mu sync.Mutex
|
||||
// refreshTime is used to store the refresh time of the current valid accessToken.
|
||||
refreshTime time.Time
|
||||
// credentialClient is the Azure AD credential client that is being used to retrieve accessToken.
|
||||
credentialClient azcore.TokenCredential
|
||||
options *policy.TokenRequestOptions
|
||||
}
|
||||
|
||||
// Validate validates config values provided.
|
||||
func (c *AzureADConfig) Validate() error {
|
||||
if c.Cloud == "" {
|
||||
c.Cloud = AzurePublic
|
||||
}
|
||||
|
||||
if c.Cloud != AzureChina && c.Cloud != AzureGovernment && c.Cloud != AzurePublic {
|
||||
return fmt.Errorf("must provide a cloud in the Azure AD config")
|
||||
}
|
||||
|
||||
if c.ManagedIdentity == nil {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity in the Azure AD config")
|
||||
}
|
||||
|
||||
if c.ManagedIdentity.ClientID == "" {
|
||||
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
|
||||
}
|
||||
|
||||
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the provided Azure Managed Identity client_id provided is invalid")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML unmarshal the Azure AD config yaml.
|
||||
func (c *AzureADConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type plain AzureADConfig
|
||||
*c = AzureADConfig{}
|
||||
if err := unmarshal((*plain)(c)); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Validate()
|
||||
}
|
||||
|
||||
// NewAzureADRoundTripper creates round tripper adding Azure AD authorization to calls.
|
||||
func NewAzureADRoundTripper(cfg *AzureADConfig, next http.RoundTripper) (http.RoundTripper, error) {
|
||||
if next == nil {
|
||||
next = http.DefaultTransport
|
||||
}
|
||||
|
||||
cred, err := newTokenCredential(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenProvider, err := newTokenProvider(cfg, cred)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt := &azureADRoundTripper{
|
||||
next: next,
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
// RoundTrip sets Authorization header for requests.
|
||||
func (rt *azureADRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
accessToken, err := rt.tokenProvider.getAccessToken(req.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bearerAccessToken := "Bearer " + accessToken
|
||||
req.Header.Set("Authorization", bearerAccessToken)
|
||||
|
||||
return rt.next.RoundTrip(req)
|
||||
}
|
||||
|
||||
// newTokenCredential returns a TokenCredential of different kinds like Azure Managed Identity and Azure AD application.
|
||||
func newTokenCredential(cfg *AzureADConfig) (azcore.TokenCredential, error) {
|
||||
cred, err := newManagedIdentityTokenCredential(cfg.ManagedIdentity.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// newManagedIdentityTokenCredential returns new Managed Identity token credential.
|
||||
func newManagedIdentityTokenCredential(managedIdentityClientID string) (azcore.TokenCredential, error) {
|
||||
clientID := azidentity.ClientID(managedIdentityClientID)
|
||||
opts := &azidentity.ManagedIdentityCredentialOptions{ID: clientID}
|
||||
return azidentity.NewManagedIdentityCredential(opts)
|
||||
}
|
||||
|
||||
// newTokenProvider helps to fetch accessToken for different types of credential. This also takes care of
|
||||
// refreshing the accessToken before expiry. This accessToken is attached to the Authorization header while making requests.
|
||||
func newTokenProvider(cfg *AzureADConfig, cred azcore.TokenCredential) (*tokenProvider, error) {
|
||||
audience, err := getAudience(cfg.Cloud)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenProvider := &tokenProvider{
|
||||
credentialClient: cred,
|
||||
options: &policy.TokenRequestOptions{Scopes: []string{audience}},
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
// getAccessToken returns the current valid accessToken.
|
||||
func (tokenProvider *tokenProvider) getAccessToken(ctx context.Context) (string, error) {
|
||||
tokenProvider.mu.Lock()
|
||||
defer tokenProvider.mu.Unlock()
|
||||
if tokenProvider.valid() {
|
||||
return tokenProvider.token, nil
|
||||
}
|
||||
err := tokenProvider.getToken(ctx)
|
||||
if err != nil {
|
||||
return "", errors.New("Failed to get access token: " + err.Error())
|
||||
}
|
||||
return tokenProvider.token, nil
|
||||
}
|
||||
|
||||
// valid checks if the token in the token provider is valid and not expired.
|
||||
func (tokenProvider *tokenProvider) valid() bool {
|
||||
if len(tokenProvider.token) == 0 {
|
||||
return false
|
||||
}
|
||||
if tokenProvider.refreshTime.After(time.Now().UTC()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getToken retrieves a new accessToken and stores the newly retrieved token in the tokenProvider.
|
||||
func (tokenProvider *tokenProvider) getToken(ctx context.Context) error {
|
||||
accessToken, err := tokenProvider.credentialClient.GetToken(ctx, *tokenProvider.options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(accessToken.Token) == 0 {
|
||||
return errors.New("access token is empty")
|
||||
}
|
||||
|
||||
tokenProvider.token = accessToken.Token
|
||||
err = tokenProvider.updateRefreshTime(accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateRefreshTime handles logic to set refreshTime. The refreshTime is set at half the duration of the actual token expiry.
|
||||
func (tokenProvider *tokenProvider) updateRefreshTime(accessToken azcore.AccessToken) error {
|
||||
tokenExpiryTimestamp := accessToken.ExpiresOn.UTC()
|
||||
deltaExpirytime := time.Now().Add(time.Until(tokenExpiryTimestamp) / 2)
|
||||
if deltaExpirytime.After(time.Now().UTC()) {
|
||||
tokenProvider.refreshTime = deltaExpirytime
|
||||
} else {
|
||||
return errors.New("access token expiry is less than the current time")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAudience returns audiences for different clouds.
|
||||
func getAudience(cloud string) (string, error) {
|
||||
switch strings.ToLower(cloud) {
|
||||
case strings.ToLower(AzureChina):
|
||||
return IngestionChinaAudience, nil
|
||||
case strings.ToLower(AzureGovernment):
|
||||
return IngestionGovernmentAudience, nil
|
||||
case strings.ToLower(AzurePublic):
|
||||
return IngestionPublicAudience, nil
|
||||
default:
|
||||
return "", errors.New("Cloud is not specified or is incorrect: " + cloud)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
// Copyright 2023 The Prometheus Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
dummyAudience = "dummyAudience"
|
||||
dummyClientID = "00000000-0000-0000-0000-000000000000"
|
||||
testTokenString = "testTokenString"
|
||||
)
|
||||
|
||||
var testTokenExpiry = time.Now().Add(10 * time.Second)
|
||||
|
||||
type AzureAdTestSuite struct {
|
||||
suite.Suite
|
||||
mockCredential *mockCredential
|
||||
}
|
||||
|
||||
type TokenProviderTestSuite struct {
|
||||
suite.Suite
|
||||
mockCredential *mockCredential
|
||||
}
|
||||
|
||||
// mockCredential mocks azidentity TokenCredential interface.
|
||||
type mockCredential struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ad *AzureAdTestSuite) BeforeTest(_, _ string) {
|
||||
ad.mockCredential = new(mockCredential)
|
||||
}
|
||||
|
||||
func TestAzureAd(t *testing.T) {
|
||||
suite.Run(t, new(AzureAdTestSuite))
|
||||
}
|
||||
|
||||
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() {
|
||||
var gotReq *http.Request
|
||||
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
}
|
||||
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
|
||||
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
|
||||
|
||||
tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential)
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
rt := &azureADRoundTripper{
|
||||
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotReq = req
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}),
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
|
||||
cli := &http.Client{Transport: rt}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
|
||||
ad.Assert().NoError(err)
|
||||
|
||||
_, err = cli.Do(req)
|
||||
ad.Assert().NoError(err)
|
||||
ad.Assert().NotNil(gotReq)
|
||||
|
||||
origReq := gotReq
|
||||
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
|
||||
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := AzureADConfig{}
|
||||
if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func testGoodConfig(t *testing.T, filename string) {
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error parsing %s: %s", filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoodAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_good.yaml"
|
||||
testGoodConfig(t, filename)
|
||||
}
|
||||
|
||||
func TestGoodCloudMissingAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_good_cloudmissing.yaml"
|
||||
testGoodConfig(t, filename)
|
||||
}
|
||||
|
||||
func TestBadClientIdMissingAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_bad_clientidmissing.yaml"
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") {
|
||||
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadInvalidClientIdAzureAdConfig(t *testing.T) {
|
||||
filename := "testdata/azuread_bad_invalidclientid.yaml"
|
||||
_, err := loadAzureAdConfig(filename)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") {
|
||||
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
|
||||
args := m.MethodCalled("GetToken", ctx, options)
|
||||
if args.Get(0) == nil {
|
||||
return azcore.AccessToken{}, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(azcore.AccessToken), nil
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) BeforeTest(_, _ string) {
|
||||
s.mockCredential = new(mockCredential)
|
||||
}
|
||||
|
||||
func TestTokenProvider(t *testing.T) {
|
||||
suite.Run(t, new(TokenProviderTestSuite))
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() {
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "PublicAzure",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().Nil(actualTokenProvider)
|
||||
s.Assert().NotNil(actualErr)
|
||||
s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error())
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() {
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider)
|
||||
s.Assert().Nil(actualErr)
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
}
|
||||
|
||||
func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() {
|
||||
// setup
|
||||
managedIdentityConfig := &ManagedIdentityConfig{
|
||||
ClientID: dummyClientID,
|
||||
}
|
||||
|
||||
azureAdConfig := &AzureADConfig{
|
||||
Cloud: "AzurePublic",
|
||||
ManagedIdentity: managedIdentityConfig,
|
||||
}
|
||||
testToken := &azcore.AccessToken{
|
||||
Token: testTokenString,
|
||||
ExpiresOn: testTokenExpiry,
|
||||
}
|
||||
|
||||
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
|
||||
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
|
||||
|
||||
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider)
|
||||
s.Assert().Nil(actualErr)
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s.
|
||||
// Hence, the 6 seconds wait to check if the token is refreshed.
|
||||
time.Sleep(6 * time.Second)
|
||||
|
||||
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
|
||||
|
||||
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2)
|
||||
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
|
||||
s.Assert().Nil(err)
|
||||
s.Assert().NotEqual(accessToken, testTokenString)
|
||||
}
|
||||
|
||||
func getToken() azcore.AccessToken {
|
||||
return azcore.AccessToken{
|
||||
Token: uuid.New().String(),
|
||||
ExpiresOn: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
cloud: AzurePublic
|
|
@ -0,0 +1,3 @@
|
|||
cloud: AzurePublic
|
||||
managed_identity:
|
||||
client_id: foo-foobar-bar-foo-00000000
|
|
@ -0,0 +1,3 @@
|
|||
cloud: AzurePublic
|
||||
managed_identity:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
|
@ -0,0 +1,2 @@
|
|||
managed_identity:
|
||||
client_id: 00000000-0000-0000-0000-000000000000
|
|
@ -36,6 +36,7 @@ import (
|
|||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/prometheus/prometheus/prompb"
|
||||
"github.com/prometheus/prometheus/storage/remote/azuread"
|
||||
)
|
||||
|
||||
const maxErrMsgLen = 1024
|
||||
|
@ -97,6 +98,7 @@ type ClientConfig struct {
|
|||
Timeout model.Duration
|
||||
HTTPClientConfig config_util.HTTPClientConfig
|
||||
SigV4Config *sigv4.SigV4Config
|
||||
AzureADConfig *azuread.AzureADConfig
|
||||
Headers map[string]string
|
||||
RetryOnRateLimit bool
|
||||
}
|
||||
|
@ -146,6 +148,13 @@ func NewWriteClient(name string, conf *ClientConfig) (WriteClient, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if conf.AzureADConfig != nil {
|
||||
t, err = azuread.NewAzureADRoundTripper(conf.AzureADConfig, httpClient.Transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Headers) > 0 {
|
||||
t = newInjectHeadersRoundTripper(conf.Headers, t)
|
||||
}
|
||||
|
|
|
@ -158,6 +158,7 @@ func (rws *WriteStorage) ApplyConfig(conf *config.Config) error {
|
|||
Timeout: rwConf.RemoteTimeout,
|
||||
HTTPClientConfig: rwConf.HTTPClientConfig,
|
||||
SigV4Config: rwConf.SigV4Config,
|
||||
AzureADConfig: rwConf.AzureADConfig,
|
||||
Headers: rwConf.Headers,
|
||||
RetryOnRateLimit: rwConf.QueueConfig.RetryOnRateLimit,
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue