mirror of https://github.com/portainer/portainer
fix(oauth): analyze id_token for Azure [EE-2984] (#7000)
parent
0cd2a4558b
commit
fd4b515350
@ -0,0 +1,24 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
func getUsername(datamap map[string]interface{}, configuration *portainer.OAuthSettings) (string, error) {
|
||||
username, ok := datamap[configuration.UserIdentifier].(string)
|
||||
if ok && username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if !ok {
|
||||
username, ok := datamap[configuration.UserIdentifier].(float64)
|
||||
if ok && username != 0 {
|
||||
return fmt.Sprint(int(username)), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("failed to extract username from oauth resource")
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portaineree "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
func Test_getUsername(t *testing.T) {
|
||||
t.Run("fails for non-matching user identifier", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"name": "john"}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err == nil {
|
||||
t.Errorf("getUsername should fail if user identifier doesn't exist as key in oauth userinfo object")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails if username is empty string", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": ""}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err == nil {
|
||||
t.Errorf("getUsername should fail if username from oauth userinfo object is empty string")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails if username is 0 int", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": 0}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err == nil {
|
||||
t.Errorf("getUsername should fail if username from oauth userinfo object is 0 val int")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails if username is negative int", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": -1}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err == nil {
|
||||
t.Errorf("getUsername should fail if username from oauth userinfo object is -1 (negative) int")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("succeeds if username is matched and is not empty", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": "john"}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err != nil {
|
||||
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-empty")
|
||||
}
|
||||
})
|
||||
|
||||
// looks like a bug!?
|
||||
t.Run("fails if username is matched and is positive int", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": 1}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err == nil {
|
||||
t.Errorf("getUsername should fail if username from oauth userinfo object matched is positive int")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("succeeds if username is matched and is non-zero (or negative) float", func(t *testing.T) {
|
||||
oauthSettings := &portaineree.OAuthSettings{UserIdentifier: "username"}
|
||||
datamap := map[string]interface{}{"username": 1.1}
|
||||
|
||||
_, err := getUsername(datamap, oauthSettings)
|
||||
if err != nil {
|
||||
t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-zero (or negative)")
|
||||
}
|
||||
})
|
||||
}
|
@ -0,0 +1,145 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/oauth/oauthtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func Test_getOAuthToken(t *testing.T) {
|
||||
validCode := "valid-code"
|
||||
srv, config := oauthtest.RunOAuthServer(validCode, &portainer.OAuthSettings{})
|
||||
defer srv.Close()
|
||||
|
||||
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
|
||||
code := ""
|
||||
_, err := getOAuthToken(code, config)
|
||||
if err == nil {
|
||||
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("getOAuthToken succeeds upon providing valid code", func(t *testing.T) {
|
||||
code := validCode
|
||||
token, err := getOAuthToken(code, config)
|
||||
|
||||
if token == nil || err != nil {
|
||||
t.Errorf("getOAuthToken should successfully return access token upon providing valid code")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getIdToken(t *testing.T) {
|
||||
verifiedToken := `eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2NTM1NDA3MjksImV4cCI6MTY4NTA3NjcyOSwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoiam9obi5kb2VAZXhhbXBsZS5jb20iLCJHaXZlbk5hbWUiOiJKb2huIiwiU3VybmFtZSI6IkRvZSIsIkdyb3VwcyI6WyJGaXJzdCIsIlNlY29uZCJdfQ.GeU8XCV4Y4p5Vm-i63Aj7UP5zpb_0Zxb7-DjM2_z-s8`
|
||||
nonVerifiedToken := `eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2NTM1NDA3MjksImV4cCI6MTY4NTA3NjcyOSwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoiam9obi5kb2VAZXhhbXBsZS5jb20iLCJHaXZlbk5hbWUiOiJKb2huIiwiU3VybmFtZSI6IkRvZSIsIkdyb3VwcyI6WyJGaXJzdCIsIlNlY29uZCJdfQ.`
|
||||
claims := map[string]interface{}{
|
||||
"iss": "Online JWT Builder",
|
||||
"iat": float64(1653540729),
|
||||
"exp": float64(1685076729),
|
||||
"aud": "www.example.com",
|
||||
"sub": "john.doe@example.com",
|
||||
"GivenName": "John",
|
||||
"Surname": "Doe",
|
||||
"Groups": []interface{}{"First", "Second"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testName string
|
||||
idToken string
|
||||
expectedResult map[string]interface{}
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
testName: "should return claims if token exists and is verified",
|
||||
idToken: verifiedToken,
|
||||
expectedResult: claims,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
testName: "should return claims if token exists but is not verified",
|
||||
idToken: nonVerifiedToken,
|
||||
expectedResult: claims,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
testName: "should return empty map if token does not exist",
|
||||
idToken: "",
|
||||
expectedResult: make(map[string]interface{}),
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
token := &oauth2.Token{}
|
||||
if tc.idToken != "" {
|
||||
token = token.WithExtra(map[string]interface{}{"id_token": tc.idToken})
|
||||
}
|
||||
|
||||
result, err := getIdToken(token)
|
||||
assert.Equal(t, err, tc.expectedError)
|
||||
assert.Equal(t, result, tc.expectedResult)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getResource(t *testing.T) {
|
||||
srv, config := oauthtest.RunOAuthServer("", &portainer.OAuthSettings{})
|
||||
defer srv.Close()
|
||||
|
||||
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource("", config)
|
||||
if err == nil {
|
||||
t.Errorf("getResource should fail if access token is not provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource("incorrect-token", config)
|
||||
if err == nil {
|
||||
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource(oauthtest.AccessToken, config)
|
||||
if err != nil {
|
||||
t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Authenticate(t *testing.T) {
|
||||
code := "valid-code"
|
||||
authService := NewService()
|
||||
|
||||
t.Run("should fail if user identifier does not get matched in resource", func(t *testing.T) {
|
||||
srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{})
|
||||
defer srv.Close()
|
||||
|
||||
_, err := authService.Authenticate(code, config)
|
||||
if err == nil {
|
||||
t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should succeed if user identifier does get matched in resource", func(t *testing.T) {
|
||||
config := &portainer.OAuthSettings{UserIdentifier: "username"}
|
||||
srv, config := oauthtest.RunOAuthServer(code, config)
|
||||
defer srv.Close()
|
||||
|
||||
username, err := authService.Authenticate(code, config)
|
||||
if err != nil {
|
||||
t.Errorf("Authenticate should succeed to extract username from resource if correct UserIdentifier provided; UserIdentifier=%s", config.UserIdentifier)
|
||||
}
|
||||
|
||||
want := "test-oauth-user"
|
||||
if username != want {
|
||||
t.Errorf("Authenticate should return correct username; got=%s, want=%s", username, want)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package oauthtest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessToken = "test-token"
|
||||
)
|
||||
|
||||
// OAuthRoutes is an OAuth 2.0 compliant handler
|
||||
func OAuthRoutes(code string, config *portainer.OAuthSettings) http.Handler {
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc(
|
||||
"/authorize",
|
||||
func(w http.ResponseWriter, req *http.Request) {
|
||||
location := fmt.Sprintf("%s?code=%s&state=%s", config.RedirectURI, code, "anything")
|
||||
// w.Header().Set("Location", location)
|
||||
// w.WriteHeader(http.StatusFound)
|
||||
http.Redirect(w, req, location, http.StatusFound)
|
||||
},
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
router.HandleFunc(
|
||||
"/access_token",
|
||||
func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := req.ParseForm(); err != nil {
|
||||
fmt.Fprintf(w, "ParseForm() err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
reqCode := req.FormValue("code")
|
||||
if reqCode != code {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"access_token": AccessToken,
|
||||
"scope": "groups",
|
||||
})
|
||||
},
|
||||
).Methods(http.MethodPost)
|
||||
|
||||
router.HandleFunc(
|
||||
"/user",
|
||||
func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
splitToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(splitToken) < 2 || splitToken[1] != AccessToken {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"username": "test-oauth-user",
|
||||
"groups": "testing",
|
||||
})
|
||||
},
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// RunOAuthServer is a barebones OAuth 2.0 compliant test server which can be used to test OAuth 2 functionality
|
||||
func RunOAuthServer(code string, config *portainer.OAuthSettings) (*httptest.Server, *portainer.OAuthSettings) {
|
||||
srv := httptest.NewUnstartedServer(http.DefaultServeMux)
|
||||
|
||||
addr := srv.Listener.Addr()
|
||||
|
||||
config.AuthorizationURI = fmt.Sprintf("http://%s/authorize", addr)
|
||||
config.AccessTokenURI = fmt.Sprintf("http://%s/access_token", addr)
|
||||
config.ResourceURI = fmt.Sprintf("http://%s/user", addr)
|
||||
config.RedirectURI = fmt.Sprintf("http://%s/", addr)
|
||||
|
||||
srv.Config.Handler = OAuthRoutes(code, config)
|
||||
srv.Start()
|
||||
|
||||
return srv, config
|
||||
}
|
Loading…
Reference in new issue