mirror of https://github.com/portainer/portainer
fix(oauth): add a timeout to getOAuthToken() BE-11283 (#63)
parent
e528cff615
commit
966fca950b
|
@ -3,10 +3,12 @@ package oauth
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
"maps"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
|
||||||
|
@ -29,28 +31,28 @@ func NewService() *Service {
|
||||||
// On success, it will then return the username and token expiry time associated to authenticated user by fetching this information
|
// On success, it will then return the username and token expiry time associated to authenticated user by fetching this information
|
||||||
// from the resource server and matching it with the user identifier setting.
|
// from the resource server and matching it with the user identifier setting.
|
||||||
func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) {
|
func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) {
|
||||||
token, err := getOAuthToken(code, configuration)
|
token, err := GetOAuthToken(code, configuration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed retrieving oauth token")
|
log.Error().Err(err).Msg("failed retrieving oauth token")
|
||||||
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := getIdToken(token)
|
idToken, err := GetIdToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed parsing id_token")
|
log.Error().Err(err).Msg("failed parsing id_token")
|
||||||
}
|
}
|
||||||
|
|
||||||
resource, err := getResource(token.AccessToken, configuration)
|
resource, err := GetResource(token.AccessToken, configuration.ResourceURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed retrieving resource")
|
log.Error().Err(err).Msg("failed retrieving resource")
|
||||||
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
resource = mergeSecondIntoFirst(idToken, resource)
|
maps.Copy(idToken, resource)
|
||||||
|
|
||||||
username, err := getUsername(resource, configuration)
|
username, err := GetUsername(resource, configuration.UserIdentifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed retrieving username")
|
log.Error().Err(err).Msg("failed retrieving username")
|
||||||
|
|
||||||
|
@ -60,34 +62,24 @@ func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings
|
||||||
return username, nil
|
return username, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeSecondIntoFirst merges the overlap map into the base overwriting any existing values.
|
func GetOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) {
|
||||||
func mergeSecondIntoFirst(base map[string]any, overlap map[string]any) map[string]any {
|
|
||||||
for k, v := range overlap {
|
|
||||||
base[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return base
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) {
|
|
||||||
unescapedCode, err := url.QueryUnescape(code)
|
unescapedCode, err := url.QueryUnescape(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config := buildConfig(configuration)
|
config := buildConfig(configuration)
|
||||||
token, err := config.Exchange(context.Background(), unescapedCode)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return config.Exchange(ctx, unescapedCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getIdToken retrieves parsed id_token from the OAuth token response.
|
// GetIdToken retrieves parsed id_token from the OAuth token response.
|
||||||
// This is necessary for OAuth providers like Azure
|
// This is necessary for OAuth providers like Azure
|
||||||
// that do not provide information about user groups on the user resource endpoint.
|
// that do not provide information about user groups on the user resource endpoint.
|
||||||
func getIdToken(token *oauth2.Token) (map[string]any, error) {
|
func GetIdToken(token *oauth2.Token) (map[string]any, error) {
|
||||||
tokenData := make(map[string]any)
|
tokenData := make(map[string]any)
|
||||||
|
|
||||||
idToken := token.Extra("id_token")
|
idToken := token.Extra("id_token")
|
||||||
|
@ -113,8 +105,8 @@ func getIdToken(token *oauth2.Token) (map[string]any, error) {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getResource(token string, configuration *portainer.OAuthSettings) (map[string]any, error) {
|
func GetResource(token string, resourceURI string) (map[string]any, error) {
|
||||||
req, err := http.NewRequest("GET", configuration.ResourceURI, nil)
|
req, err := http.NewRequest(http.MethodGet, resourceURI, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -159,6 +151,7 @@ func getResource(token string, configuration *portainer.OAuthSettings) (map[stri
|
||||||
datamap[k] = v[0]
|
datamap[k] = v[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return datamap, nil
|
return datamap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,18 +163,16 @@ func getResource(token string, configuration *portainer.OAuthSettings) (map[stri
|
||||||
return datamap, nil
|
return datamap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
|
func buildConfig(config *portainer.OAuthSettings) *oauth2.Config {
|
||||||
endpoint := oauth2.Endpoint{
|
|
||||||
AuthURL: configuration.AuthorizationURI,
|
|
||||||
TokenURL: configuration.AccessTokenURI,
|
|
||||||
AuthStyle: configuration.AuthStyle,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: configuration.ClientID,
|
ClientID: config.ClientID,
|
||||||
ClientSecret: configuration.ClientSecret,
|
ClientSecret: config.ClientSecret,
|
||||||
Endpoint: endpoint,
|
RedirectURL: config.RedirectURI,
|
||||||
RedirectURL: configuration.RedirectURI,
|
Scopes: strings.Split(config.Scopes, ","),
|
||||||
Scopes: strings.Split(configuration.Scopes, ","),
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: config.AuthorizationURI,
|
||||||
|
TokenURL: config.AccessTokenURI,
|
||||||
|
AuthStyle: config.AuthStyle,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,18 +3,16 @@ package oauth
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func getUsername(datamap map[string]any, configuration *portainer.OAuthSettings) (string, error) {
|
func GetUsername(datamap map[string]any, userIdentifier string) (string, error) {
|
||||||
username, ok := datamap[configuration.UserIdentifier].(string)
|
username, ok := datamap[userIdentifier].(string)
|
||||||
if ok && username != "" {
|
if ok && username != "" {
|
||||||
return username, nil
|
return username, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
username, ok := datamap[configuration.UserIdentifier].(float64)
|
username, ok := datamap[userIdentifier].(float64)
|
||||||
if ok && username != 0 {
|
if ok && username != 0 {
|
||||||
return strconv.Itoa(int(username)), nil
|
return strconv.Itoa(int(username)), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"name": "john"}
|
datamap := map[string]any{"name": "john"}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil {
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getUsername should fail if user identifier doesn't exist as key in oauth userinfo object")
|
t.Errorf("getUsername should fail if user identifier doesn't exist as key in oauth userinfo object")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -21,8 +20,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": ""}
|
datamap := map[string]any{"username": ""}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil {
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getUsername should fail if username from oauth userinfo object is empty string")
|
t.Errorf("getUsername should fail if username from oauth userinfo object is empty string")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -31,8 +29,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": 0}
|
datamap := map[string]any{"username": 0}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil {
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getUsername should fail if username from oauth userinfo object is 0 val int")
|
t.Errorf("getUsername should fail if username from oauth userinfo object is 0 val int")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -41,8 +38,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": -1}
|
datamap := map[string]any{"username": -1}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil {
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getUsername should fail if username from oauth userinfo object is -1 (negative) int")
|
t.Errorf("getUsername should fail if username from oauth userinfo object is -1 (negative) int")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -51,8 +47,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": "john"}
|
datamap := map[string]any{"username": "john"}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-empty")
|
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-empty")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -62,8 +57,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": 1}
|
datamap := map[string]any{"username": 1}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil {
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getUsername should fail if username from oauth userinfo object matched is positive int")
|
t.Errorf("getUsername should fail if username from oauth userinfo object matched is positive int")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -72,8 +66,7 @@ func Test_getUsername(t *testing.T) {
|
||||||
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||||
datamap := map[string]any{"username": 1.1}
|
datamap := map[string]any{"username": 1.1}
|
||||||
|
|
||||||
_, err := getUsername(datamap, oauthSettings)
|
if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-zero (or negative)")
|
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-zero (or negative)")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/oauth/oauthtest"
|
"github.com/portainer/portainer/api/oauth/oauthtest"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
@ -16,14 +17,14 @@ func Test_getOAuthToken(t *testing.T) {
|
||||||
|
|
||||||
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
|
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
|
||||||
code := ""
|
code := ""
|
||||||
if _, err := getOAuthToken(code, config); err == nil {
|
if _, err := GetOAuthToken(code, config); err == nil {
|
||||||
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
|
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("getOAuthToken succeeds upon providing valid code", func(t *testing.T) {
|
t.Run("getOAuthToken succeeds upon providing valid code", func(t *testing.T) {
|
||||||
code := validCode
|
code := validCode
|
||||||
token, err := getOAuthToken(code, config)
|
token, err := GetOAuthToken(code, config)
|
||||||
|
|
||||||
if token == nil || err != nil {
|
if token == nil || err != nil {
|
||||||
t.Errorf("getOAuthToken should successfully return access token upon providing valid code")
|
t.Errorf("getOAuthToken should successfully return access token upon providing valid code")
|
||||||
|
@ -78,7 +79,7 @@ func Test_getIdToken(t *testing.T) {
|
||||||
token = token.WithExtra(map[string]any{"id_token": tc.idToken})
|
token = token.WithExtra(map[string]any{"id_token": tc.idToken})
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := getIdToken(token)
|
result, err := GetIdToken(token)
|
||||||
assert.Equal(t, err, tc.expectedError)
|
assert.Equal(t, err, tc.expectedError)
|
||||||
assert.Equal(t, result, tc.expectedResult)
|
assert.Equal(t, result, tc.expectedResult)
|
||||||
})
|
})
|
||||||
|
@ -90,19 +91,19 @@ func Test_getResource(t *testing.T) {
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
|
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
|
||||||
if _, err := getResource("", config); err == nil {
|
if _, err := GetResource("", config.ResourceURI); err == nil {
|
||||||
t.Errorf("getResource should fail if access token is not provided in auth bearer header")
|
t.Errorf("getResource should fail if access token is not provided in auth bearer header")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
|
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
|
||||||
if _, err := getResource("incorrect-token", config); err == nil {
|
if _, err := GetResource("incorrect-token", config.ResourceURI); err == nil {
|
||||||
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
|
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
|
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
|
||||||
if _, err := getResource(oauthtest.AccessToken, config); err != nil {
|
if _, err := GetResource(oauthtest.AccessToken, config.ResourceURI); err != nil {
|
||||||
t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
|
t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue