Merge pull request #66314 from jlowdermilk/cmdtokensource-reset

gcp client auth plugin: persist default cache on unauthorized
pull/8/head
k8s-ci-robot 2018-09-14 00:49:21 -07:00 committed by GitHub
commit c04fe8c27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 23 deletions

View File

@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
} }
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister} var resetCache map[string]string
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
resetCache = cts.baseCache()
} else {
resetCache = make(map[string]string)
}
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
} }
func (g *gcpAuthProvider) Login() error { return nil } func (g *gcpAuthProvider) Login() error { return nil }
@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
return ret return ret
} }
// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
func (t *cachedTokenSource) baseCache() map[string]string {
t.lk.Lock()
defer t.lk.Unlock()
ret := map[string]string{}
for k, v := range t.cache {
ret[k] = v
}
delete(ret, "access-token")
delete(ret, "expiry")
return ret
}
type commandTokenSource struct { type commandTokenSource struct {
cmd string cmd string
args []string args []string
@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
type conditionalTransport struct { type conditionalTransport struct {
oauthTransport *oauth2.Transport oauthTransport *oauth2.Transport
persister restclient.AuthProviderConfigPersister persister restclient.AuthProviderConfigPersister
resetCache map[string]string
} }
var _ net.RoundTripperWrapper = &conditionalTransport{} var _ net.RoundTripperWrapper = &conditionalTransport{}
@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err
if res.StatusCode == 401 { if res.StatusCode == 401 {
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster") glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
emptyCache := make(map[string]string) t.persister.Persist(t.resetCache)
t.persister.Persist(emptyCache)
} }
return res, nil return res, nil

View File

@ -442,37 +442,61 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.res, nil return t.res, nil
} }
func TestClearingCredentials(t *testing.T) { func Test_cmdTokenSource_roundTrip(t *testing.T) {
accessToken := "fakeToken"
fakeExpiry := time.Now().Add(time.Hour) fakeExpiry := time.Now().Add(time.Hour)
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
cache := map[string]string{ fs := &fakeTokenSource{
"access-token": "fakeToken", token: &oauth2.Token{
"expiry": fakeExpiry.String(), AccessToken: accessToken,
Expiry: fakeExpiry,
},
} }
cts := cachedTokenSource{ cmdCache := map[string]string{
source: nil, "cmd-path": "/path/to/tokensource/cmd",
accessToken: cache["access-token"], "cmd-args": "--output=json",
expiry: fakeExpiry, }
persister: nil, cmdCacheUpdated := map[string]string{
cache: nil, "cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
"access-token": accessToken,
"expiry": fakeExpiryStr,
}
simpleCacheUpdated := map[string]string{
"access-token": accessToken,
"expiry": fakeExpiryStr,
} }
tests := []struct { tests := []struct {
name string name string
res http.Response res http.Response
cache map[string]string baseCache, expectedCache map[string]string
}{ }{
{ {
"Unauthorized", "Unauthorized",
http.Response{StatusCode: 401}, http.Response{StatusCode: 401},
make(map[string]string), make(map[string]string),
make(map[string]string),
},
{
"Unauthorized, nonempty defaultCache",
http.Response{StatusCode: 401},
cmdCache,
cmdCache,
}, },
{ {
"Authorized", "Authorized",
http.Response{StatusCode: 200}, http.Response{StatusCode: 200},
cache, make(map[string]string),
simpleCacheUpdated,
},
{
"Authorized, nonempty defaultCache",
http.Response{StatusCode: 200},
cmdCache,
cmdCacheUpdated,
}, },
} }
@ -480,17 +504,23 @@ func TestClearingCredentials(t *testing.T) {
req := http.Request{Header: http.Header{}} req := http.Request{Header: http.Header{}}
for _, tc := range tests { for _, tc := range tests {
authProvider := gcpAuthProvider{&cts, persister} cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
if err != nil {
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
}
authProvider := gcpAuthProvider{cts, persister}
fakeTransport := MockTransport{&tc.res} fakeTransport := MockTransport{&tc.res}
transport := (authProvider.WrapTransport(&fakeTransport)) transport := (authProvider.WrapTransport(&fakeTransport))
persister.Persist(cache) // call Token to persist/update cache
if _, err := cts.Token(); err != nil {
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
}
transport.RoundTrip(&req) transport.RoundTrip(&req)
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) { if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
t.Errorf("got cache %v, want %v", got, tc.cache) t.Errorf("got cache %v, want %v", got, tc.expectedCache)
} }
} }