Clear auth config when gcp credentials fail

Specific use case is when utilizing multiple
gcp accounts, the user may provide credentials
for the wrong account.

This change ensures the incorrect credentials
are not cached in auth config, and logs an
appropriate message.
pull/6/head
Matt Tyler 2017-05-29 16:58:09 +08:00
parent a57c33bd28
commit b92016769e
2 changed files with 79 additions and 2 deletions

View File

@ -124,7 +124,7 @@ func newGCPAuthProvider(_ string, gcpConfig map[string]string, persister restcli
}
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}}
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister}
}
func (g *gcpAuthProvider) Login() error { return nil }
@ -284,11 +284,25 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
type conditionalTransport struct {
oauthTransport *oauth2.Transport
persister restclient.AuthProviderConfigPersister
}
func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) != 0 {
return t.oauthTransport.Base.RoundTrip(req)
}
return t.oauthTransport.RoundTrip(req)
res, err := t.oauthTransport.RoundTrip(req)
if err != nil {
return nil, err
}
if res.StatusCode == 401 {
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
emptyCache := make(map[string]string)
t.persister.Persist(emptyCache)
}
return res, nil
}

View File

@ -18,6 +18,7 @@ package gcp
import (
"fmt"
"net/http"
"os"
"os/exec"
"reflect"
@ -323,3 +324,65 @@ func TestCachedTokenSource(t *testing.T) {
t.Errorf("got cache %v, want %v", got, cache)
}
}
type MockTransport struct {
res *http.Response
}
func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.res, nil
}
func TestClearingCredentials(t *testing.T) {
fakeExpiry := time.Now().Add(time.Hour)
cache := map[string]string{
"access-token": "fakeToken",
"expiry": fakeExpiry.String(),
}
cts := cachedTokenSource{
source: nil,
accessToken: cache["access-token"],
expiry: fakeExpiry,
persister: nil,
cache: nil,
}
tests := []struct {
name string
res http.Response
cache map[string]string
}{
{
"Unauthorized",
http.Response{StatusCode: 401},
make(map[string]string),
},
{
"Authorized",
http.Response{StatusCode: 200},
cache,
},
}
persister := &fakePersister{}
req := http.Request{Header: http.Header{}}
for _, tc := range tests {
authProvider := gcpAuthProvider{&cts, persister}
fakeTransport := MockTransport{&tc.res}
transport := (authProvider.WrapTransport(&fakeTransport))
persister.Persist(cache)
transport.RoundTrip(&req)
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) {
t.Errorf("got cache %v, want %v", got, tc.cache)
}
}
}