mirror of https://github.com/portainer/portainer
				
				
				
			
		
			
				
	
	
		
			187 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			187 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
package oauth
 | 
						|
 | 
						|
import (
	"context"
	"io"
	"mime"
	"net/http"
	"net/url"
	"strings"
	"time"

	portainer "github.com/portainer/portainer/api"

	"github.com/golang-jwt/jwt/v4"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"
	"github.com/segmentio/encoding/json"
	"golang.org/x/oauth2"
)
 | 
						|
 | 
						|
// Service authenticates users against an external OAuth authorization server.
// It holds no state of its own: every call receives the OAuth settings it needs.
type Service struct{}

// NewService creates a ready-to-use OAuth Service.
func NewService() *Service {
	return new(Service)
}
 | 
						|
 | 
						|
// Authenticate takes an access code and exchanges it for an access token from portainer OAuthSettings token environment(endpoint).
 | 
						|
// On success, it will then return the username and token expiry time associated to authenticated user by fetching this information
 | 
						|
// from the resource server and matching it with the user identifier setting.
 | 
						|
func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) {
 | 
						|
	token, err := getOAuthToken(code, configuration)
 | 
						|
	if err != nil {
 | 
						|
		log.Debug().Err(err).Msg("failed retrieving oauth token")
 | 
						|
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	idToken, err := getIdToken(token)
 | 
						|
	if err != nil {
 | 
						|
		log.Debug().Err(err).Msg("failed parsing id_token")
 | 
						|
	}
 | 
						|
 | 
						|
	resource, err := getResource(token.AccessToken, configuration)
 | 
						|
	if err != nil {
 | 
						|
		log.Debug().Err(err).Msg("failed retrieving resource")
 | 
						|
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	resource = mergeSecondIntoFirst(idToken, resource)
 | 
						|
 | 
						|
	username, err := getUsername(resource, configuration)
 | 
						|
	if err != nil {
 | 
						|
		log.Debug().Err(err).Msg("failed retrieving username")
 | 
						|
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return username, nil
 | 
						|
}
 | 
						|
 | 
						|
// mergeSecondIntoFirst copies every entry of overlap into base and returns the
// result. Values for keys present in both maps are overwritten.
//
// base is modified in place and returned. When base is nil, a new map is
// allocated instead, because writing into a nil map would panic. A nil
// overlap leaves base unchanged.
func mergeSecondIntoFirst(base map[string]interface{}, overlap map[string]interface{}) map[string]interface{} {
	if base == nil {
		base = make(map[string]interface{}, len(overlap))
	}

	for k, v := range overlap {
		base[k] = v
	}

	return base
}
 | 
						|
 | 
						|
func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) {
 | 
						|
	unescapedCode, err := url.QueryUnescape(code)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	config := buildConfig(configuration)
 | 
						|
	token, err := config.Exchange(context.Background(), unescapedCode)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return token, nil
 | 
						|
}
 | 
						|
 | 
						|
// getIdToken retrieves parsed id_token from the OAuth token response.
 | 
						|
// This is necessary for OAuth providers like Azure
 | 
						|
// that do not provide information about user groups on the user resource endpoint.
 | 
						|
func getIdToken(token *oauth2.Token) (map[string]interface{}, error) {
 | 
						|
	tokenData := make(map[string]interface{})
 | 
						|
 | 
						|
	idToken := token.Extra("id_token")
 | 
						|
	if idToken == nil {
 | 
						|
		return tokenData, nil
 | 
						|
	}
 | 
						|
 | 
						|
	jwtParser := jwt.Parser{
 | 
						|
		SkipClaimsValidation: true,
 | 
						|
	}
 | 
						|
 | 
						|
	t, _, err := jwtParser.ParseUnverified(idToken.(string), jwt.MapClaims{})
 | 
						|
	if err != nil {
 | 
						|
		return tokenData, errors.Wrap(err, "failed to parse id_token")
 | 
						|
	}
 | 
						|
 | 
						|
	if claims, ok := t.Claims.(jwt.MapClaims); ok {
 | 
						|
		for k, v := range claims {
 | 
						|
			tokenData[k] = v
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return tokenData, nil
 | 
						|
}
 | 
						|
 | 
						|
func getResource(token string, configuration *portainer.OAuthSettings) (map[string]interface{}, error) {
 | 
						|
	req, err := http.NewRequest("GET", configuration.ResourceURI, nil)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	client := &http.Client{}
 | 
						|
	req.Header.Set("Authorization", "Bearer "+token)
 | 
						|
 | 
						|
	resp, err := client.Do(req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	defer resp.Body.Close()
 | 
						|
 | 
						|
	body, err := io.ReadAll(resp.Body)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if resp.StatusCode != http.StatusOK {
 | 
						|
		return nil, &oauth2.RetrieveError{
 | 
						|
			Response: resp,
 | 
						|
			Body:     body,
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	content, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if content == "application/x-www-form-urlencoded" || content == "text/plain" {
 | 
						|
		values, err := url.ParseQuery(string(body))
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		datamap := make(map[string]interface{})
 | 
						|
		for k, v := range values {
 | 
						|
			if len(v) == 0 {
 | 
						|
				datamap[k] = ""
 | 
						|
			} else {
 | 
						|
				datamap[k] = v[0]
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return datamap, nil
 | 
						|
	}
 | 
						|
 | 
						|
	var datamap map[string]interface{}
 | 
						|
	if err = json.Unmarshal(body, &datamap); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return datamap, nil
 | 
						|
}
 | 
						|
 | 
						|
func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
 | 
						|
	endpoint := oauth2.Endpoint{
 | 
						|
		AuthURL:  configuration.AuthorizationURI,
 | 
						|
		TokenURL: configuration.AccessTokenURI,
 | 
						|
	}
 | 
						|
 | 
						|
	return &oauth2.Config{
 | 
						|
		ClientID:     configuration.ClientID,
 | 
						|
		ClientSecret: configuration.ClientSecret,
 | 
						|
		Endpoint:     endpoint,
 | 
						|
		RedirectURL:  configuration.RedirectURI,
 | 
						|
		Scopes:       strings.Split(configuration.Scopes, ","),
 | 
						|
	}
 | 
						|
}
 |