Allow the use of bearer_token or bearer_token_file for MarathonSD

pull/2462/head
Michael Kraus 2017-03-02 09:44:20 +01:00
parent 0a7fb56b16
commit 47bdcf0f67
3 changed files with 39 additions and 9 deletions

View File

@ -241,6 +241,7 @@ func resolveFilepaths(baseDir string, cfg *Config) {
kcfg.TLSConfig.KeyFile = join(kcfg.TLSConfig.KeyFile)
}
for _, mcfg := range cfg.MarathonSDConfigs {
mcfg.BearerTokenFile = join(mcfg.BearerTokenFile)
mcfg.TLSConfig.CAFile = join(mcfg.TLSConfig.CAFile)
mcfg.TLSConfig.CertFile = join(mcfg.TLSConfig.CertFile)
mcfg.TLSConfig.KeyFile = join(mcfg.TLSConfig.KeyFile)
@ -920,6 +921,8 @@ type MarathonSDConfig struct {
Timeout model.Duration `yaml:"timeout,omitempty"`
RefreshInterval model.Duration `yaml:"refresh_interval,omitempty"`
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
BearerToken string `yaml:"bearer_token,omitempty"`
BearerTokenFile string `yaml:"bearer_token_file,omitempty"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
@ -939,6 +942,12 @@ func (c *MarathonSDConfig) UnmarshalYAML(unmarshal func(interface{}) error) erro
if len(c.Servers) == 0 {
return fmt.Errorf("Marathon SD config must contain at least one Marathon server")
}
if len(c.BearerToken) > 0 && len(c.BearerTokenFile) > 0 {
return fmt.Errorf("at most one of bearer_token & bearer_token_file must be configured")
}
if len(c.BearerToken) == 0 && len(c.BearerTokenFile) == 0 {
return fmt.Errorf("at most one of bearer_token & bearer_token_file must be configured")
}
return nil
}

View File

@ -20,6 +20,7 @@ import (
"math/rand"
"net"
"net/http"
"strings"
"time"
"golang.org/x/net/context"
@ -77,6 +78,7 @@ type Discovery struct {
refreshInterval time.Duration
lastRefresh map[string]*config.TargetGroup
appsClient AppListClient
token string
}
// Initialize sets up the discovery for usage.
@ -86,6 +88,15 @@ func NewDiscovery(conf *config.MarathonSDConfig) (*Discovery, error) {
return nil, err
}
token := conf.BearerToken
if conf.BearerTokenFile != "" {
bf, err := ioutil.ReadFile(conf.BearerTokenFile)
if err != nil {
return nil, err
}
token = strings.TrimSpace(string(bf))
}
client := &http.Client{
Timeout: time.Duration(conf.Timeout),
Transport: &http.Transport{
@ -98,6 +109,7 @@ func NewDiscovery(conf *config.MarathonSDConfig) (*Discovery, error) {
servers: conf.Servers,
refreshInterval: time.Duration(conf.RefreshInterval),
appsClient: fetchApps,
token: token,
}, nil
}
@ -160,7 +172,7 @@ func (md *Discovery) updateServices(ctx context.Context, ch chan<- []*config.Tar
func (md *Discovery) fetchTargetGroups() (map[string]*config.TargetGroup, error) {
url := RandomAppsURL(md.servers)
apps, err := md.appsClient(md.client, url)
apps, err := md.appsClient(md.client, url, md.token)
if err != nil {
return nil, err
}
@ -201,11 +213,20 @@ type AppList struct {
}
// AppListClient defines a function that can be used to get an application list from marathon.
type AppListClient func(client *http.Client, url string) (*AppList, error)
type AppListClient func(client *http.Client, url, token string) (*AppList, error)
// fetchApps requests a list of applications from a marathon server.
func fetchApps(client *http.Client, url string) (*AppList, error) {
resp, err := client.Get(url)
func fetchApps(client *http.Client, url, token string) (*AppList, error) {
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if token != "" {
request.Header.Set("Authorization", "token="+token)
}
resp, err := client.Do(request)
if err != nil {
return nil, err
}

View File

@ -44,7 +44,7 @@ func TestMarathonSDHandleError(t *testing.T) {
var (
errTesting = errors.New("testing failure")
ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) { return nil, errTesting }
client = func(client *http.Client, url, token string) (*AppList, error) { return nil, errTesting }
)
if err := testUpdateServices(client, ch); err != errTesting {
t.Fatalf("Expected error: %s", err)
@ -59,7 +59,7 @@ func TestMarathonSDHandleError(t *testing.T) {
func TestMarathonSDEmptyList(t *testing.T) {
var (
ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) { return &AppList{}, nil }
client = func(client *http.Client, url, token string) (*AppList, error) { return &AppList{}, nil }
)
if err := testUpdateServices(client, ch); err != nil {
t.Fatalf("Got error: %s", err)
@ -130,7 +130,7 @@ func TestMarathonSDRemoveApp(t *testing.T) {
if err != nil {
t.Fatalf("%s", err)
}
md.appsClient = func(client *http.Client, url string) (*AppList, error) {
md.appsClient = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestAppList(marathonValidLabel, 1), nil
}
go func() {
@ -165,7 +165,7 @@ func TestMarathonSDRunAndStop(t *testing.T) {
if err != nil {
t.Fatalf("%s", err)
}
md.appsClient = func(client *http.Client, url string) (*AppList, error) {
md.appsClient = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestAppList(marathonValidLabel, 1), nil
}
ctx, cancel := context.WithCancel(context.Background())
@ -213,7 +213,7 @@ func marathonTestZeroTaskPortAppList(labels map[string]string, runningTasks int)
func TestMarathonZeroTaskPorts(t *testing.T) {
var (
ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) {
client = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestZeroTaskPortAppList(marathonValidLabel, 1), nil
}
)