mirror of https://github.com/k3s-io/k3s
Merge pull request #66314 from jlowdermilk/cmdtokensource-reset
gcp client auth plugin: persist default cache on unauthorizedpull/8/head
commit
c04fe8c27c
|
@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
||||||
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister}
|
var resetCache map[string]string
|
||||||
|
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
|
||||||
|
resetCache = cts.baseCache()
|
||||||
|
} else {
|
||||||
|
resetCache = make(map[string]string)
|
||||||
|
}
|
||||||
|
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *gcpAuthProvider) Login() error { return nil }
|
func (g *gcpAuthProvider) Login() error { return nil }
|
||||||
|
@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
|
||||||
|
func (t *cachedTokenSource) baseCache() map[string]string {
|
||||||
|
t.lk.Lock()
|
||||||
|
defer t.lk.Unlock()
|
||||||
|
ret := map[string]string{}
|
||||||
|
for k, v := range t.cache {
|
||||||
|
ret[k] = v
|
||||||
|
}
|
||||||
|
delete(ret, "access-token")
|
||||||
|
delete(ret, "expiry")
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
type commandTokenSource struct {
|
type commandTokenSource struct {
|
||||||
cmd string
|
cmd string
|
||||||
args []string
|
args []string
|
||||||
|
@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
|
||||||
type conditionalTransport struct {
|
type conditionalTransport struct {
|
||||||
oauthTransport *oauth2.Transport
|
oauthTransport *oauth2.Transport
|
||||||
persister restclient.AuthProviderConfigPersister
|
persister restclient.AuthProviderConfigPersister
|
||||||
|
resetCache map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.RoundTripperWrapper = &conditionalTransport{}
|
var _ net.RoundTripperWrapper = &conditionalTransport{}
|
||||||
|
@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err
|
||||||
|
|
||||||
if res.StatusCode == 401 {
|
if res.StatusCode == 401 {
|
||||||
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
|
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
|
||||||
emptyCache := make(map[string]string)
|
t.persister.Persist(t.resetCache)
|
||||||
t.persister.Persist(emptyCache)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
|
|
|
@ -442,37 +442,61 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return t.res, nil
|
return t.res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearingCredentials(t *testing.T) {
|
func Test_cmdTokenSource_roundTrip(t *testing.T) {
|
||||||
|
|
||||||
|
accessToken := "fakeToken"
|
||||||
fakeExpiry := time.Now().Add(time.Hour)
|
fakeExpiry := time.Now().Add(time.Hour)
|
||||||
|
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
|
||||||
cache := map[string]string{
|
fs := &fakeTokenSource{
|
||||||
"access-token": "fakeToken",
|
token: &oauth2.Token{
|
||||||
"expiry": fakeExpiry.String(),
|
AccessToken: accessToken,
|
||||||
|
Expiry: fakeExpiry,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cts := cachedTokenSource{
|
cmdCache := map[string]string{
|
||||||
source: nil,
|
"cmd-path": "/path/to/tokensource/cmd",
|
||||||
accessToken: cache["access-token"],
|
"cmd-args": "--output=json",
|
||||||
expiry: fakeExpiry,
|
}
|
||||||
persister: nil,
|
cmdCacheUpdated := map[string]string{
|
||||||
cache: nil,
|
"cmd-path": "/path/to/tokensource/cmd",
|
||||||
|
"cmd-args": "--output=json",
|
||||||
|
"access-token": accessToken,
|
||||||
|
"expiry": fakeExpiryStr,
|
||||||
|
}
|
||||||
|
simpleCacheUpdated := map[string]string{
|
||||||
|
"access-token": accessToken,
|
||||||
|
"expiry": fakeExpiryStr,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
res http.Response
|
res http.Response
|
||||||
cache map[string]string
|
baseCache, expectedCache map[string]string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"Unauthorized",
|
"Unauthorized",
|
||||||
http.Response{StatusCode: 401},
|
http.Response{StatusCode: 401},
|
||||||
make(map[string]string),
|
make(map[string]string),
|
||||||
|
make(map[string]string),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Unauthorized, nonempty defaultCache",
|
||||||
|
http.Response{StatusCode: 401},
|
||||||
|
cmdCache,
|
||||||
|
cmdCache,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Authorized",
|
"Authorized",
|
||||||
http.Response{StatusCode: 200},
|
http.Response{StatusCode: 200},
|
||||||
cache,
|
make(map[string]string),
|
||||||
|
simpleCacheUpdated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Authorized, nonempty defaultCache",
|
||||||
|
http.Response{StatusCode: 200},
|
||||||
|
cmdCache,
|
||||||
|
cmdCacheUpdated,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,17 +504,23 @@ func TestClearingCredentials(t *testing.T) {
|
||||||
req := http.Request{Header: http.Header{}}
|
req := http.Request{Header: http.Header{}}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
authProvider := gcpAuthProvider{&cts, persister}
|
cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
|
||||||
|
}
|
||||||
|
authProvider := gcpAuthProvider{cts, persister}
|
||||||
|
|
||||||
fakeTransport := MockTransport{&tc.res}
|
fakeTransport := MockTransport{&tc.res}
|
||||||
|
|
||||||
transport := (authProvider.WrapTransport(&fakeTransport))
|
transport := (authProvider.WrapTransport(&fakeTransport))
|
||||||
persister.Persist(cache)
|
// call Token to persist/update cache
|
||||||
|
if _, err := cts.Token(); err != nil {
|
||||||
|
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
transport.RoundTrip(&req)
|
transport.RoundTrip(&req)
|
||||||
|
|
||||||
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) {
|
if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
|
||||||
t.Errorf("got cache %v, want %v", got, tc.cache)
|
t.Errorf("got cache %v, want %v", got, tc.expectedCache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue