fix(jwt): upgrade jwt to remove deprecated jwt.StandardClaims [EE-6469] (#10850)

pull/11550/head
Matt Hook 2024-04-23 17:33:36 +12:00 committed by GitHub
parent 2463648161
commit 505a2d5523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 16 deletions

View File

@ -13,6 +13,8 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
const year = time.Hour * 24 * 365
// scope represents JWT scopes that are supported in JWT claims. // scope represents JWT scopes that are supported in JWT claims.
type scope string type scope string
@ -29,7 +31,7 @@ type claims struct {
Role int `json:"role"` Role int `json:"role"`
Scope scope `json:"scope"` Scope scope `json:"scope"`
ForceChangePassword bool `json:"forceChangePassword"` ForceChangePassword bool `json:"forceChangePassword"`
jwt.StandardClaims jwt.RegisteredClaims
} }
var ( var (
@ -98,7 +100,7 @@ func (service *Service) defaultExpireAt() time.Time {
// GenerateToken generates a new JWT token. // GenerateToken generates a new JWT token.
func (service *Service) GenerateToken(data *portainer.TokenData) (string, time.Time, error) { func (service *Service) GenerateToken(data *portainer.TokenData) (string, time.Time, error) {
expiryTime := service.defaultExpireAt() expiryTime := service.defaultExpireAt()
token, err := service.generateSignedToken(data, expiryTime.Unix(), defaultScope) token, err := service.generateSignedToken(data, expiryTime, defaultScope)
return token, expiryTime, err return token, expiryTime, err
} }
@ -121,7 +123,7 @@ func (service *Service) ParseAndVerifyToken(token string) (*portainer.TokenData,
if err != nil { if err != nil {
return nil, errInvalidJWTToken return nil, errInvalidJWTToken
} }
if user.TokenIssueAt > cl.StandardClaims.IssuedAt { if user.TokenIssueAt > cl.RegisteredClaims.ExpiresAt.Unix() {
return nil, errInvalidJWTToken return nil, errInvalidJWTToken
} }
@ -156,7 +158,7 @@ func (service *Service) SetUserSessionDuration(userSessionDuration time.Duration
service.userSessionTimeout = userSessionDuration service.userSessionTimeout = userSessionDuration
} }
func (service *Service) generateSignedToken(data *portainer.TokenData, expiresAt int64, scope scope) (string, error) { func (service *Service) generateSignedToken(data *portainer.TokenData, expiresAt time.Time, scope scope) (string, error) {
secret, found := service.secrets[scope] secret, found := service.secrets[scope]
if !found { if !found {
return "", fmt.Errorf("invalid scope: %v", scope) return "", fmt.Errorf("invalid scope: %v", scope)
@ -170,7 +172,7 @@ func (service *Service) generateSignedToken(data *portainer.TokenData, expiresAt
if settings.IsDockerDesktopExtension { if settings.IsDockerDesktopExtension {
// Set expiration to 99 years for docker desktop extension. // Set expiration to 99 years for docker desktop extension.
log.Info().Msg("detected docker desktop extension mode") log.Info().Msg("detected docker desktop extension mode")
expiresAt = time.Now().Add(time.Hour * 8760 * 99).Unix() expiresAt = time.Now().Add(year * 99)
} }
cl := claims{ cl := claims{
@ -179,10 +181,13 @@ func (service *Service) generateSignedToken(data *portainer.TokenData, expiresAt
Role: int(data.Role), Role: int(data.Role),
Scope: scope, Scope: scope,
ForceChangePassword: data.ForceChangePassword, ForceChangePassword: data.ForceChangePassword,
StandardClaims: jwt.StandardClaims{ }
ExpiresAt: expiresAt,
IssuedAt: time.Now().Unix(), if !expiresAt.IsZero() {
}, cl.RegisteredClaims = jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(time.Now()),
}
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, cl) token := jwt.NewWithClaims(jwt.SigningMethodHS256, cl)

View File

@ -18,9 +18,9 @@ func (service *Service) GenerateTokenForKubeconfig(data *portainer.TokenData) (s
return "", err return "", err
} }
expiryAt := time.Now().Add(expiryDuration).Unix() expiryAt := time.Now().Add(expiryDuration)
if expiryDuration == time.Duration(0) { if expiryDuration == time.Duration(0) {
expiryAt = 0 expiryAt = time.Time{}
} }
return service.generateSignedToken(data, expiryAt, kubeConfigScope) return service.generateSignedToken(data, expiryAt, kubeConfigScope)

View File

@ -43,14 +43,14 @@ func TestService_GenerateTokenForKubeconfig(t *testing.T) {
name string name string
fields fields fields fields
args args args args
wantExpiresAt int64 wantExpiresAt *jwt.NumericDate
wantErr bool wantErr bool
}{ }{
{ {
name: "kubeconfig no expiry", name: "kubeconfig no expiry",
fields: myFields, fields: myFields,
args: myArgs, args: myArgs,
wantExpiresAt: 0, wantExpiresAt: nil,
wantErr: false, wantErr: false,
}, },
} }

View File

@ -20,7 +20,7 @@ func TestGenerateSignedToken(t *testing.T) {
ID: 1, ID: 1,
Role: 1, Role: 1,
} }
expiresAt := time.Now().Add(1 * time.Hour).Unix() expiresAt := time.Now().Add(1 * time.Hour)
generatedToken, err := svc.generateSignedToken(token, expiresAt, defaultScope) generatedToken, err := svc.generateSignedToken(token, expiresAt, defaultScope)
assert.NoError(t, err, "failed to generate a signed token") assert.NoError(t, err, "failed to generate a signed token")
@ -36,7 +36,7 @@ func TestGenerateSignedToken(t *testing.T) {
assert.Equal(t, token.Username, tokenClaims.Username) assert.Equal(t, token.Username, tokenClaims.Username)
assert.Equal(t, int(token.ID), tokenClaims.UserID) assert.Equal(t, int(token.ID), tokenClaims.UserID)
assert.Equal(t, int(token.Role), tokenClaims.Role) assert.Equal(t, int(token.Role), tokenClaims.Role)
assert.Equal(t, expiresAt, tokenClaims.ExpiresAt) assert.Equal(t, jwt.NewNumericDate(expiresAt), tokenClaims.ExpiresAt)
} }
func TestGenerateSignedToken_InvalidScope(t *testing.T) { func TestGenerateSignedToken_InvalidScope(t *testing.T) {
@ -49,7 +49,7 @@ func TestGenerateSignedToken_InvalidScope(t *testing.T) {
ID: 1, ID: 1,
Role: 1, Role: 1,
} }
expiresAt := time.Now().Add(1 * time.Hour).Unix() expiresAt := time.Now().Add(1 * time.Hour)
_, err = svc.generateSignedToken(token, expiresAt, "testing") _, err = svc.generateSignedToken(token, expiresAt, "testing")
assert.Error(t, err) assert.Error(t, err)