diff --git a/api/oauth/oauth.go b/api/oauth/oauth.go index 8ff5afeee..59a5b3b36 100644 --- a/api/oauth/oauth.go +++ b/api/oauth/oauth.go @@ -3,16 +3,18 @@ package oauth import ( "context" "encoding/json" - "fmt" "io/ioutil" - "log" "mime" "net/http" "net/url" + "strings" "golang.org/x/oauth2" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" portainer "github.com/portainer/portainer/api" + log "github.com/sirupsen/logrus" ) // Service represents a service used to authenticate users against an authorization server @@ -29,17 +31,39 @@ func NewService() *Service { func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) { token, err := getOAuthToken(code, configuration) if err != nil { - log.Printf("[DEBUG] - Failed retrieving access token: %v", err) + log.Debugf("[internal,oauth] [message: failed retrieving oauth token: %v]", err) return "", err } - username, err := getUsername(token.AccessToken, configuration) + + idToken, err := getIdToken(token) + if err != nil { + log.Debugf("[internal,oauth] [message: failed parsing id_token: %v]", err) + } + + resource, err := getResource(token.AccessToken, configuration) if err != nil { - log.Printf("[DEBUG] - Failed retrieving oauth user name: %v", err) + log.Debugf("[internal,oauth] [message: failed retrieving resource: %v]", err) + return "", err + } + + resource = mergeSecondIntoFirst(idToken, resource) + + username, err := getUsername(resource, configuration) + if err != nil { + log.Debugf("[internal,oauth] [message: failed retrieving username: %v]", err) return "", err } return username, nil } +// mergeSecondIntoFirst merges the overlap map into the base overwriting any existing values. +func mergeSecondIntoFirst(base map[string]interface{}, overlap map[string]interface{}) map[string]interface{} { + for k, v := range overlap { + base[k] = v + } + return base +} + func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) { unescapedCode, err := url.QueryUnescape(code) if err != nil { @@ -55,27 +79,55 @@ func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2 return token, nil } -func getUsername(token string, configuration *portainer.OAuthSettings) (string, error) { +// getIdToken retrieves parsed id_token from the OAuth token response. +// This is necessary for OAuth providers like Azure +// that do not provide information about user groups on the user resource endpoint. +func getIdToken(token *oauth2.Token) (map[string]interface{}, error) { + tokenData := make(map[string]interface{}) + + idToken := token.Extra("id_token") + if idToken == nil { + return tokenData, nil + } + + jwtParser := jwt.Parser{ + SkipClaimsValidation: true, + } + + t, _, err := jwtParser.ParseUnverified(idToken.(string), jwt.MapClaims{}) + if err != nil { + return tokenData, errors.Wrap(err, "failed to parse id_token") + } + + if claims, ok := t.Claims.(jwt.MapClaims); ok { + for k, v := range claims { + tokenData[k] = v + } + } + return tokenData, nil +} + +func getResource(token string, configuration *portainer.OAuthSettings) (map[string]interface{}, error) { req, err := http.NewRequest("GET", configuration.ResourceURI, nil) if err != nil { - return "", err + return nil, err } client := &http.Client{} req.Header.Set("Authorization", "Bearer "+token) resp, err := client.Do(req) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return "", err + return nil, err } if resp.StatusCode != http.StatusOK { - return "", &oauth2.RetrieveError{ + return nil, &oauth2.RetrieveError{ Response: resp, Body: body, } @@ -83,47 +135,32 @@ func getUsername(token string, configuration *portainer.OAuthSettings) (string, content, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) if err != nil { - return "", err + return nil, err } if content == "application/x-www-form-urlencoded" || content == "text/plain" { values, err := url.ParseQuery(string(body)) if err != nil { - return "", err + return nil, err } - username := values.Get(configuration.UserIdentifier) - if username == "" { - return username, &oauth2.RetrieveError{ - Response: resp, - Body: body, + datamap := make(map[string]interface{}) + for k, v := range values { + if len(v) == 0 { + datamap[k] = "" + } else { + datamap[k] = v[0] } } - - return username, nil + return datamap, nil } var datamap map[string]interface{} if err = json.Unmarshal(body, &datamap); err != nil { - return "", err - } - - username, ok := datamap[configuration.UserIdentifier].(string) - if ok && username != "" { - return username, nil - } - - if !ok { - username, ok := datamap[configuration.UserIdentifier].(float64) - if ok && username != 0 { - return fmt.Sprint(int(username)), nil - } + return nil, err } - return "", &oauth2.RetrieveError{ - Response: resp, - Body: body, - } + return datamap, nil } func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config { @@ -137,6 +174,6 @@ func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config { ClientSecret: configuration.ClientSecret, Endpoint: endpoint, RedirectURL: configuration.RedirectURI, - Scopes: []string{configuration.Scopes}, + Scopes: strings.Split(configuration.Scopes, ","), } } diff --git a/api/oauth/oauth_resource.go b/api/oauth/oauth_resource.go new file mode 100644 index 000000000..3c5124070 --- /dev/null +++ b/api/oauth/oauth_resource.go @@ -0,0 +1,24 @@ +package oauth + +import ( + "errors" + "fmt" + + portainer "github.com/portainer/portainer/api" +) + +func getUsername(datamap map[string]interface{}, configuration *portainer.OAuthSettings) (string, error) { + username, ok := datamap[configuration.UserIdentifier].(string) + if ok && username != "" { + return username, nil + } + + if !ok { + username, ok := datamap[configuration.UserIdentifier].(float64) + if ok && username != 0 { + return fmt.Sprint(int(username)), nil + } + } + + return "", errors.New("failed to extract username from oauth resource") +} diff --git a/api/oauth/oauth_resource_test.go b/api/oauth/oauth_resource_test.go new file mode 100644 index 000000000..3051171e5 --- /dev/null +++ b/api/oauth/oauth_resource_test.go @@ -0,0 +1,80 @@ +package oauth + +import ( + "testing" + + portaineree "github.com/portainer/portainer/api" +) + +func Test_getUsername(t *testing.T) { + t.Run("fails for non-matching user identifier", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"name": "john"} + + _, err := getUsername(datamap, oauthSettings) + if err == nil { + t.Errorf("getUsername should fail if user identifier doesn't exist as key in oauth userinfo object") + } + }) + + t.Run("fails if username is empty string", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": ""} + + _, err := getUsername(datamap, oauthSettings) + if err == nil { + t.Errorf("getUsername should fail if username from oauth userinfo object is empty string") + } + }) + + t.Run("fails if username is 0 int", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": 0} + + _, err := getUsername(datamap, oauthSettings) + if err == nil { + t.Errorf("getUsername should fail if username from oauth userinfo object is 0 val int") + } + }) + + t.Run("fails if username is negative int", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": -1} + + _, err := getUsername(datamap, oauthSettings) + if err == nil { + t.Errorf("getUsername should fail if username from oauth userinfo object is -1 (negative) int") + } + }) + + t.Run("succeeds if username is matched and is not empty", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": "john"} + + _, err := getUsername(datamap, oauthSettings) + if err != nil { + t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-empty") + } + }) + + // looks like a bug!? + t.Run("fails if username is matched and is positive int", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": 1} + + _, err := getUsername(datamap, oauthSettings) + if err == nil { + t.Errorf("getUsername should fail if username from oauth userinfo object matched is positive int") + } + }) + + t.Run("succeeds if username is matched and is non-zero (or negative) float", func(t *testing.T) { + oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"} + datamap := map[string]interface{}{"username": 1.1} + + _, err := getUsername(datamap, oauthSettings) + if err != nil { + t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-zero (or negative)") + } + }) +} diff --git a/api/oauth/oauth_test.go b/api/oauth/oauth_test.go new file mode 100644 index 000000000..0fb10e587 --- /dev/null +++ b/api/oauth/oauth_test.go @@ -0,0 +1,145 @@ +package oauth + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/oauth/oauthtest" + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" +) + +func Test_getOAuthToken(t *testing.T) { + validCode := "valid-code" + srv, config := oauthtest.RunOAuthServer(validCode, &portainer.OAuthSettings{}) + defer srv.Close() + + t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) { + code := "" + _, err := getOAuthToken(code, config) + if err == nil { + t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code) + } + }) + + t.Run("getOAuthToken succeeds upon providing valid code", func(t *testing.T) { + code := validCode + token, err := getOAuthToken(code, config) + + if token == nil || err != nil { + t.Errorf("getOAuthToken should successfully return access token upon providing valid code") + } + }) +} + +func Test_getIdToken(t *testing.T) { + verifiedToken := `eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2NTM1NDA3MjksImV4cCI6MTY4NTA3NjcyOSwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoiam9obi5kb2VAZXhhbXBsZS5jb20iLCJHaXZlbk5hbWUiOiJKb2huIiwiU3VybmFtZSI6IkRvZSIsIkdyb3VwcyI6WyJGaXJzdCIsIlNlY29uZCJdfQ.GeU8XCV4Y4p5Vm-i63Aj7UP5zpb_0Zxb7-DjM2_z-s8` + nonVerifiedToken := `eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2NTM1NDA3MjksImV4cCI6MTY4NTA3NjcyOSwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoiam9obi5kb2VAZXhhbXBsZS5jb20iLCJHaXZlbk5hbWUiOiJKb2huIiwiU3VybmFtZSI6IkRvZSIsIkdyb3VwcyI6WyJGaXJzdCIsIlNlY29uZCJdfQ.` + claims := map[string]interface{}{ + "iss": "Online JWT Builder", + "iat": float64(1653540729), + "exp": float64(1685076729), + "aud": "www.example.com", + "sub": "john.doe@example.com", + "GivenName": "John", + "Surname": "Doe", + "Groups": []interface{}{"First", "Second"}, + } + + tests := []struct { + testName string + idToken string + expectedResult map[string]interface{} + expectedError error + }{ + { + testName: "should return claims if token exists and is verified", + idToken: verifiedToken, + expectedResult: claims, + expectedError: nil, + }, + { + testName: "should return claims if token exists but is not verified", + idToken: nonVerifiedToken, + expectedResult: claims, + expectedError: nil, + }, + { + testName: "should return empty map if token does not exist", + idToken: "", + expectedResult: make(map[string]interface{}), + expectedError: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + token := &oauth2.Token{} + if tc.idToken != "" { + token = token.WithExtra(map[string]interface{}{"id_token": tc.idToken}) + } + + result, err := getIdToken(token) + assert.Equal(t, err, tc.expectedError) + assert.Equal(t, result, tc.expectedResult) + }) + } +} + +func Test_getResource(t *testing.T) { + srv, config := oauthtest.RunOAuthServer("", &portainer.OAuthSettings{}) + defer srv.Close() + + t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) { + _, err := getResource("", config) + if err == nil { + t.Errorf("getResource should fail if access token is not provided in auth bearer header") + } + }) + + t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) { + _, err := getResource("incorrect-token", config) + if err == nil { + t.Errorf("getResource should fail if incorrect access token provided in auth bearer header") + } + }) + + t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) { + _, err := getResource(oauthtest.AccessToken, config) + if err != nil { + t.Errorf("getResource should succeed if correct access token provided in auth bearer header") + } + }) +} + +func Test_Authenticate(t *testing.T) { + code := "valid-code" + authService := NewService() + + t.Run("should fail if user identifier does not get matched in resource", func(t *testing.T) { + srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{}) + defer srv.Close() + + _, err := authService.Authenticate(code, config) + if err == nil { + t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided") + } + }) + + t.Run("should succeed if user identifier does get matched in resource", func(t *testing.T) { + config := &portainer.OAuthSettings{UserIdentifier: "username"} + srv, config := oauthtest.RunOAuthServer(code, config) + defer srv.Close() + + username, err := authService.Authenticate(code, config) + if err != nil { + t.Errorf("Authenticate should succeed to extract username from resource if correct UserIdentifier provided; UserIdentifier=%s", config.UserIdentifier) + } + + want := "test-oauth-user" + if username != want { + t.Errorf("Authenticate should return correct username; got=%s, want=%s", username, want) + } + }) + +} diff --git a/api/oauth/oauthtest/oauth_server.go b/api/oauth/oauthtest/oauth_server.go new file mode 100644 index 000000000..62ae0823e --- /dev/null +++ b/api/oauth/oauthtest/oauth_server.go @@ -0,0 +1,96 @@ +package oauthtest + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + + "github.com/gorilla/mux" + portainer "github.com/portainer/portainer/api" +) + +const ( + AccessToken = "test-token" +) + +// OAuthRoutes is an OAuth 2.0 compliant handler +func OAuthRoutes(code string, config *portainer.OAuthSettings) http.Handler { + router := mux.NewRouter() + + router.HandleFunc( + "/authorize", + func(w http.ResponseWriter, req *http.Request) { + location := fmt.Sprintf("%s?code=%s&state=%s", config.RedirectURI, code, "anything") + // w.Header().Set("Location", location) + // w.WriteHeader(http.StatusFound) + http.Redirect(w, req, location, http.StatusFound) + }, + ).Methods(http.MethodGet) + + router.HandleFunc( + "/access_token", + func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if err := req.ParseForm(); err != nil { + fmt.Fprintf(w, "ParseForm() err: %v", err) + return + } + + reqCode := req.FormValue("code") + if reqCode != code { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "token_type": "Bearer", + "expires_in": 86400, + "access_token": AccessToken, + "scope": "groups", + }) + }, + ).Methods(http.MethodPost) + + router.HandleFunc( + "/user", + func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + authHeader := req.Header.Get("Authorization") + splitToken := strings.Split(authHeader, "Bearer ") + if len(splitToken) < 2 || splitToken[1] != AccessToken { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "username": "test-oauth-user", + "groups": "testing", + }) + }, + ).Methods(http.MethodGet) + + return router +} + +// RunOAuthServer is a barebones OAuth 2.0 compliant test server which can be used to test OAuth 2 functionality +func RunOAuthServer(code string, config *portainer.OAuthSettings) (*httptest.Server, *portainer.OAuthSettings) { + srv := httptest.NewUnstartedServer(http.DefaultServeMux) + + addr := srv.Listener.Addr() + + config.AuthorizationURI = fmt.Sprintf("http://%s/authorize", addr) + config.AccessTokenURI = fmt.Sprintf("http://%s/access_token", addr) + config.ResourceURI = fmt.Sprintf("http://%s/user", addr) + config.RedirectURI = fmt.Sprintf("http://%s/", addr) + + srv.Config.Handler = OAuthRoutes(code, config) + srv.Start() + + return srv, config +} diff --git a/app/portainer/oauth/components/oauth-settings/oauth-settings.controller.js b/app/portainer/oauth/components/oauth-settings/oauth-settings.controller.js index 875737f5e..19ac94a50 100644 --- a/app/portainer/oauth/components/oauth-settings/oauth-settings.controller.js +++ b/app/portainer/oauth/components/oauth-settings/oauth-settings.controller.js @@ -32,10 +32,10 @@ export default class OAuthSettingsController { onMicrosoftTenantIDChange() { const tenantID = this.state.microsoftTenantID || MS_TENANT_ID_PLACEHOLDER; - this.settings.AuthorizationURI = `https://login.microsoftonline.com/${tenantID}/oauth2/authorize`; - this.settings.AccessTokenURI = `https://login.microsoftonline.com/${tenantID}/oauth2/token`; - this.settings.ResourceURI = `https://graph.windows.net/${tenantID}/me?api-version=2013-11-08`; - this.settings.LogoutURI = `https://login.microsoftonline.com/${tenantID}/oauth2/logout`; + this.settings.AuthorizationURI = `https://login.microsoftonline.com/${tenantID}/oauth2/v2.0/authorize`; + this.settings.AccessTokenURI = `https://login.microsoftonline.com/${tenantID}/oauth2/v2.0/token`; + this.settings.ResourceURI = `https://graph.microsoft.com/v1.0/me`; + this.settings.LogoutURI = `https://login.microsoftonline.com/${tenantID}/oauth2/v2.0/logout`; } useDefaultProviderConfiguration(providerId) { diff --git a/app/portainer/oauth/components/oauth-settings/providers.js b/app/portainer/oauth/components/oauth-settings/providers.js index ef868758b..ff6ec4491 100644 --- a/app/portainer/oauth/components/oauth-settings/providers.js +++ b/app/portainer/oauth/components/oauth-settings/providers.js @@ -2,12 +2,12 @@ import { baseHref } from '@/portainer/helpers/pathHelper'; export default { microsoft: { - authUrl: 'https://login.microsoftonline.com/TENANT_ID/oauth2/authorize', - accessTokenUrl: 'https://login.microsoftonline.com/TENANT_ID/oauth2/token', - resourceUrl: 'https://graph.windows.net/TENANT_ID/me?api-version=2013-11-08', - logoutUrl: `https://login.microsoftonline.com/TENANT_ID/oauth2/logout`, + authUrl: 'https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/authorize', + accessTokenUrl: 'https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/token', + resourceUrl: 'https://graph.microsoft.com/v1.0/me', + logoutUrl: `https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/logout`, userIdentifier: 'userPrincipalName', - scopes: 'id,email,name', + scopes: 'profile openid', }, google: { authUrl: 'https://accounts.google.com/o/oauth2/auth',