diff --git a/api/exec/kubernetes_deploy.go b/api/exec/kubernetes_deploy.go index acadef038..ff51a8bbe 100644 --- a/api/exec/kubernetes_deploy.go +++ b/api/exec/kubernetes_deploy.go @@ -49,7 +49,7 @@ func (deployer *KubernetesDeployer) getToken(userID portainer.UserID, endpoint * return "", err } - tokenCache := deployer.kubernetesTokenCacheManager.GetOrCreateTokenCache(int(endpoint.ID)) + tokenCache := deployer.kubernetesTokenCacheManager.GetOrCreateTokenCache(endpoint.ID) tokenManager, err := kubernetes.NewTokenManager(kubeCLI, deployer.dataStore, tokenCache, setLocalAdminToken) if err != nil { diff --git a/api/http/handler/auth/logout.go b/api/http/handler/auth/logout.go index ec3e0b9ac..afadc05c0 100644 --- a/api/http/handler/auth/logout.go +++ b/api/http/handler/auth/logout.go @@ -23,7 +23,7 @@ func (handler *Handler) logout(w http.ResponseWriter, r *http.Request) *httperro return httperror.InternalServerError("Unable to retrieve user details from authentication token", err) } - handler.KubernetesTokenCacheManager.RemoveUserFromCache(int(tokenData.ID)) + handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID) return response.Empty(w) } diff --git a/api/http/handler/websocket/pod.go b/api/http/handler/websocket/pod.go index 45fb85271..8520a5a3f 100644 --- a/api/http/handler/websocket/pod.go +++ b/api/http/handler/websocket/pod.go @@ -170,7 +170,7 @@ func (handler *Handler) getToken(request *http.Request, endpoint *portainer.Endp return "", false, err } - tokenCache := handler.kubernetesTokenCacheManager.GetOrCreateTokenCache(int(endpoint.ID)) + tokenCache := handler.kubernetesTokenCacheManager.GetOrCreateTokenCache(endpoint.ID) tokenManager, err := kubernetes.NewTokenManager(kubecli, handler.DataStore, tokenCache, setLocalAdminToken) if err != nil { diff --git a/api/http/proxy/factory/kubernetes.go b/api/http/proxy/factory/kubernetes.go index f56a8f857..e676b9ea4 100644 --- a/api/http/proxy/factory/kubernetes.go +++ b/api/http/proxy/factory/kubernetes.go @@ -33,7 +33,7 @@ func (factory *ProxyFactory) newKubernetesLocalProxy(endpoint *portainer.Endpoin return nil, err } - tokenCache := factory.kubernetesTokenCacheManager.CreateTokenCache(int(endpoint.ID)) + tokenCache := factory.kubernetesTokenCacheManager.GetOrCreateTokenCache(endpoint.ID) tokenManager, err := kubernetes.NewTokenManager(kubecli, factory.dataStore, tokenCache, true) if err != nil { return nil, err @@ -64,7 +64,7 @@ func (factory *ProxyFactory) newKubernetesEdgeHTTPProxy(endpoint *portainer.Endp return nil, err } - tokenCache := factory.kubernetesTokenCacheManager.CreateTokenCache(int(endpoint.ID)) + tokenCache := factory.kubernetesTokenCacheManager.GetOrCreateTokenCache(endpoint.ID) tokenManager, err := kubernetes.NewTokenManager(kubecli, factory.dataStore, tokenCache, false) if err != nil { return nil, err @@ -96,7 +96,7 @@ func (factory *ProxyFactory) newKubernetesAgentHTTPSProxy(endpoint *portainer.En return nil, err } - tokenCache := factory.kubernetesTokenCacheManager.CreateTokenCache(int(endpoint.ID)) + tokenCache := factory.kubernetesTokenCacheManager.GetOrCreateTokenCache(endpoint.ID) tokenManager, err := kubernetes.NewTokenManager(kubecli, factory.dataStore, tokenCache, false) if err != nil { return nil, err diff --git a/api/http/proxy/factory/kubernetes/token.go b/api/http/proxy/factory/kubernetes/token.go index f2d620630..e794f4189 100644 --- a/api/http/proxy/factory/kubernetes/token.go +++ b/api/http/proxy/factory/kubernetes/token.go @@ -43,18 +43,15 @@ func (manager *tokenManager) GetAdminServiceAccountToken() string { return manager.adminToken } +// GetUserServiceAccountToken setup a user's service account if it does not exist, then retrieve its token func (manager *tokenManager) GetUserServiceAccountToken(userID int, endpointID portainer.EndpointID) (string, error) { - manager.tokenCache.mutex.Lock() - defer manager.tokenCache.mutex.Unlock() - - token, ok := manager.tokenCache.getToken(userID) - if !ok { + tokenFunc := func() (string, error) { memberships, err := manager.dataStore.TeamMembership().TeamMembershipsByUserID(portainer.UserID(userID)) if err != nil { return "", err } - teamIds := make([]int, 0) + teamIds := make([]int, 0, len(memberships)) for _, membership := range memberships { teamIds = append(teamIds, int(membership.TeamID)) } @@ -70,14 +67,8 @@ func (manager *tokenManager) GetUserServiceAccountToken(userID int, endpointID p return "", err } - serviceAccountToken, err := manager.kubecli.GetServiceAccountBearerToken(userID) - if err != nil { - return "", err - } - - manager.tokenCache.addToken(userID, serviceAccountToken) - token = serviceAccountToken + return manager.kubecli.GetServiceAccountBearerToken(userID) } - return token, nil + return manager.tokenCache.getOrAddToken(portainer.UserID(userID), tokenFunc) } diff --git a/api/http/proxy/factory/kubernetes/token_cache.go b/api/http/proxy/factory/kubernetes/token_cache.go index fd701000d..f00d5befe 100644 --- a/api/http/proxy/factory/kubernetes/token_cache.go +++ b/api/http/proxy/factory/kubernetes/token_cache.go @@ -1,84 +1,78 @@ package kubernetes import ( - "strconv" "sync" - cmap "github.com/orcaman/concurrent-map" + portainer "github.com/portainer/portainer/api" ) -type ( - // TokenCacheManager represents a service used to manage multiple tokenCache objects. - TokenCacheManager struct { - tokenCaches cmap.ConcurrentMap - } +// TokenCacheManager represents a service used to manage multiple tokenCache objects. +type TokenCacheManager struct { + tokenCaches map[portainer.EndpointID]*tokenCache + mu sync.Mutex +} - tokenCache struct { - userTokenCache cmap.ConcurrentMap - mutex sync.Mutex - } -) +type tokenCache struct { + userTokenCache map[portainer.UserID]string + mu sync.Mutex +} // NewTokenCacheManager returns a pointer to a new instance of TokenCacheManager func NewTokenCacheManager() *TokenCacheManager { return &TokenCacheManager{ - tokenCaches: cmap.New(), + tokenCaches: make(map[portainer.EndpointID]*tokenCache), } } -// CreateTokenCache will create a new tokenCache object, associate it to the manager map of caches -// and return a pointer to that tokenCache instance. -func (manager *TokenCacheManager) CreateTokenCache(endpointID int) *tokenCache { - tokenCache := newTokenCache() - - key := strconv.Itoa(endpointID) - manager.tokenCaches.Set(key, tokenCache) - - return tokenCache -} - // GetOrCreateTokenCache will get the tokenCache from the manager map of caches if it exists, // otherwise it will create a new tokenCache object, associate it to the manager map of caches // and return a pointer to that tokenCache instance. -func (manager *TokenCacheManager) GetOrCreateTokenCache(endpointID int) *tokenCache { - key := strconv.Itoa(endpointID) - if epCache, ok := manager.tokenCaches.Get(key); ok { - return epCache.(*tokenCache) +func (manager *TokenCacheManager) GetOrCreateTokenCache(endpointID portainer.EndpointID) *tokenCache { + manager.mu.Lock() + defer manager.mu.Unlock() + + if tc, ok := manager.tokenCaches[endpointID]; ok { + return tc } - return manager.CreateTokenCache(endpointID) + tc := &tokenCache{ + userTokenCache: make(map[portainer.UserID]string), + } + + manager.tokenCaches[endpointID] = tc + + return tc } // RemoveUserFromCache will ensure that the specific userID is removed from all registered caches. -func (manager *TokenCacheManager) RemoveUserFromCache(userID int) { - for cache := range manager.tokenCaches.IterBuffered() { - cache.Val.(*tokenCache).removeToken(userID) +func (manager *TokenCacheManager) RemoveUserFromCache(userID portainer.UserID) { + manager.mu.Lock() + for _, tc := range manager.tokenCaches { + tc.removeToken(userID) } + manager.mu.Unlock() } -func newTokenCache() *tokenCache { - return &tokenCache{ - userTokenCache: cmap.New(), - mutex: sync.Mutex{}, - } -} +func (cache *tokenCache) getOrAddToken(userID portainer.UserID, tokenGetFunc func() (string, error)) (string, error) { + cache.mu.Lock() + defer cache.mu.Unlock() -func (cache *tokenCache) getToken(userID int) (string, bool) { - key := strconv.Itoa(userID) - token, ok := cache.userTokenCache.Get(key) - if ok { - return token.(string), true + if tok, ok := cache.userTokenCache[userID]; ok { + return tok, nil } - return "", false + tok, err := tokenGetFunc() + if err != nil { + return "", err + } + + cache.userTokenCache[userID] = tok + + return tok, nil } -func (cache *tokenCache) addToken(userID int, token string) { - key := strconv.Itoa(userID) - cache.userTokenCache.Set(key, token) -} - -func (cache *tokenCache) removeToken(userID int) { - key := strconv.Itoa(userID) - cache.userTokenCache.Remove(key) +func (cache *tokenCache) removeToken(userID portainer.UserID) { + cache.mu.Lock() + delete(cache.userTokenCache, userID) + cache.mu.Unlock() } diff --git a/api/http/proxy/factory/kubernetes/token_cache_test.go b/api/http/proxy/factory/kubernetes/token_cache_test.go new file mode 100644 index 000000000..525576d64 --- /dev/null +++ b/api/http/proxy/factory/kubernetes/token_cache_test.go @@ -0,0 +1,102 @@ +package kubernetes + +import ( + "errors" + "testing" + + portainer "github.com/portainer/portainer/api" +) + +func noTokFunc() (string, error) { + return "", errors.New("no token found") +} + +func stringTok(tok string) func() (string, error) { + return func() (string, error) { + return tok, nil + } +} + +func failFunc(t *testing.T) func() (string, error) { + return func() (string, error) { + t.FailNow() + return noTokFunc() + } +} + +func TestTokenCacheDataRace(t *testing.T) { + ch := make(chan struct{}) + + for i := 0; i < 1000; i++ { + var tokenCache1, tokenCache2 *tokenCache + + mgr := NewTokenCacheManager() + + go func() { + tokenCache1 = mgr.GetOrCreateTokenCache(1) + ch <- struct{}{} + }() + + go func() { + tokenCache2 = mgr.GetOrCreateTokenCache(1) + ch <- struct{}{} + }() + + <-ch + <-ch + + if tokenCache1 != tokenCache2 { + t.FailNow() + } + } +} + +func TestTokenCache(t *testing.T) { + mgr := NewTokenCacheManager() + tc1 := mgr.GetOrCreateTokenCache(1) + tc2 := mgr.GetOrCreateTokenCache(2) + tc3 := mgr.GetOrCreateTokenCache(3) + + uid := portainer.UserID(2) + tokenString1 := "token-string-1" + tokenString2 := "token-string-2" + + tok, err := tc1.getOrAddToken(uid, stringTok(tokenString1)) + if err != nil || tok != tokenString1 { + t.FailNow() + } + + tok, err = tc1.getOrAddToken(uid, failFunc(t)) + if err != nil || tok != tokenString1 { + t.FailNow() + } + + tok, err = tc2.getOrAddToken(uid, stringTok(tokenString2)) + if err != nil || tok != tokenString2 { + t.FailNow() + } + + _, err = tc3.getOrAddToken(uid, noTokFunc) + if err == nil { + t.FailNow() + } + + // Remove one user from all the caches + + mgr.RemoveUserFromCache(uid) + + _, err = tc1.getOrAddToken(uid, noTokFunc) + if err == nil { + t.FailNow() + } + + _, err = tc2.getOrAddToken(uid, noTokFunc) + if err == nil { + t.FailNow() + } + + _, err = tc3.getOrAddToken(uid, noTokFunc) + if err == nil { + t.FailNow() + } +}