mirror of https://github.com/portainer/portainer
refactor(ouath): use oauth2 library to get token
parent
60040e90d0
commit
46e8f10aea
|
@ -49,17 +49,17 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h
|
||||||
return &httperror.HandlerError{http.StatusForbidden, "Unable to acquire username", portainer.ErrUnauthorized}
|
return &httperror.HandlerError{http.StatusForbidden, "Unable to acquire username", portainer.ErrUnauthorized}
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := handler.UserService.UserByUsername(username)
|
user, err := handler.UserService.UserByUsername(username)
|
||||||
if err != nil && err != portainer.ErrObjectNotFound {
|
if err != nil && err != portainer.ErrObjectNotFound {
|
||||||
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to retrieve a user with the specified username from the database", err}
|
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to retrieve a user with the specified username from the database", err}
|
||||||
}
|
}
|
||||||
|
|
||||||
if u == nil && !settings.OAuthSettings.OAuthAutoCreateUsers {
|
if user == nil && !settings.OAuthSettings.OAuthAutoCreateUsers {
|
||||||
return &httperror.HandlerError{http.StatusForbidden, "Unregistered account", portainer.ErrUnauthorized}
|
return &httperror.HandlerError{http.StatusForbidden, "Unregistered account", portainer.ErrUnauthorized}
|
||||||
}
|
}
|
||||||
|
|
||||||
if u == nil {
|
if user == nil {
|
||||||
user := &portainer.User{
|
user = &portainer.User{
|
||||||
Username: username,
|
Username: username,
|
||||||
Role: portainer.StandardUserRole,
|
Role: portainer.StandardUserRole,
|
||||||
}
|
}
|
||||||
|
@ -69,10 +69,9 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h
|
||||||
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist user inside the database", err}
|
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist user inside the database", err}
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler.writeToken(w, user)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler.writeToken(w, u)
|
return handler.writeToken(w, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||||
|
@ -85,7 +84,7 @@ func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *http
|
||||||
return &httperror.HandlerError{http.StatusForbidden, "OAuth authentication is disabled", err}
|
return &httperror.HandlerError{http.StatusForbidden, "OAuth authentication is disabled", err}
|
||||||
}
|
}
|
||||||
|
|
||||||
url := handler.OAuthService.BuildLoginURL(settings.OAuthSettings)
|
url := handler.OAuthService.BuildLoginURL(&settings.OAuthSettings)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
package oauth
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -26,84 +24,9 @@ type Service struct{}
|
||||||
|
|
||||||
// GetAccessToken takes an access code and exchanges it for an access token from portainer OAuthSettings token endpoint
|
// GetAccessToken takes an access code and exchanges it for an access token from portainer OAuthSettings token endpoint
|
||||||
func (*Service) GetAccessToken(code string, settings *portainer.OAuthSettings) (string, error) {
|
func (*Service) GetAccessToken(code string, settings *portainer.OAuthSettings) (string, error) {
|
||||||
v := url.Values{}
|
config := buildConfig(settings)
|
||||||
v.Set("client_id", settings.ClientID)
|
token, err := config.Exchange(context.Background(), code)
|
||||||
v.Set("client_secret", settings.ClientSecret)
|
return token.AccessToken, err
|
||||||
v.Set("redirect_uri", settings.RedirectURI)
|
|
||||||
v.Set("code", code)
|
|
||||||
v.Set("grant_type", "authorization_code")
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", settings.AccessTokenURI, strings.NewReader(v.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
r, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.StatusCode != http.StatusOK {
|
|
||||||
type ErrorMessage struct {
|
|
||||||
Message string
|
|
||||||
Type string
|
|
||||||
Code int
|
|
||||||
}
|
|
||||||
type ErrorResponse struct {
|
|
||||||
Error ErrorMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var response ErrorResponse
|
|
||||||
if err = json.Unmarshal(body, &response); err != nil {
|
|
||||||
// report also error
|
|
||||||
log.Printf("[Error] - Failed parsing error body: %v", err)
|
|
||||||
return "", errors.New("oauth2: cannot fetch token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", errors.New(response.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
|
||||||
if content == "application/x-www-form-urlencoded" || content == "text/plain" {
|
|
||||||
values, err := url.ParseQuery(string(body))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
token := values.Get("access_token")
|
|
||||||
log.Printf("[DEBUG] - returned body %v", values)
|
|
||||||
|
|
||||||
if token == "" {
|
|
||||||
log.Printf("[DEBUG] - access token returned empty - %v", values)
|
|
||||||
return "", errors.New("oauth2: cannot fetch token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokenJSON struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var tj tokenJSON
|
|
||||||
if err = json.Unmarshal(body, &tj); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
token := tj.AccessToken
|
|
||||||
|
|
||||||
if token == "" {
|
|
||||||
log.Printf("[DEBUG] - access token returned empty - %v with status code", string(body), r.StatusCode)
|
|
||||||
return "", errors.New("oauth2: cannot fetch token")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsername takes a token and retrieves the portainer OAuthSettings user identifier from resource server.
|
// GetUsername takes a token and retrieves the portainer OAuthSettings user identifier from resource server.
|
||||||
|
@ -167,19 +90,22 @@ func (*Service) GetUsername(token string, settings *portainer.OAuthSettings) (st
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildLoginURL creates a login url for the oauth provider
|
// BuildLoginURL creates a login url for the oauth provider
|
||||||
func (*Service) BuildLoginURL(oauthSettings portainer.OAuthSettings) string {
|
func (*Service) BuildLoginURL(oauthSettings *portainer.OAuthSettings) string {
|
||||||
|
oauthConfig := buildConfig(oauthSettings)
|
||||||
|
return oauthConfig.AuthCodeURL("portainer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildConfig(oauthSettings *portainer.OAuthSettings) *oauth2.Config {
|
||||||
endpoint := oauth2.Endpoint{
|
endpoint := oauth2.Endpoint{
|
||||||
AuthURL: oauthSettings.AuthorizationURI,
|
AuthURL: oauthSettings.AuthorizationURI,
|
||||||
TokenURL: oauthSettings.AccessTokenURI,
|
TokenURL: oauthSettings.AccessTokenURI,
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthConfig := &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: oauthSettings.ClientID,
|
ClientID: oauthSettings.ClientID,
|
||||||
ClientSecret: oauthSettings.ClientSecret,
|
ClientSecret: oauthSettings.ClientSecret,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
RedirectURL: oauthSettings.RedirectURI,
|
RedirectURL: oauthSettings.RedirectURI,
|
||||||
Scopes: strings.Split(oauthSettings.Scopes, ","),
|
Scopes: strings.Split(oauthSettings.Scopes, ","),
|
||||||
}
|
}
|
||||||
|
|
||||||
return oauthConfig.AuthCodeURL("portainer")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -766,7 +766,7 @@ type (
|
||||||
OAuthService interface {
|
OAuthService interface {
|
||||||
GetAccessToken(code string, settings *OAuthSettings) (string, error)
|
GetAccessToken(code string, settings *OAuthSettings) (string, error)
|
||||||
GetUsername(token string, settings *OAuthSettings) (string, error)
|
GetUsername(token string, settings *OAuthSettings) (string, error)
|
||||||
BuildLoginURL(oauthSettings OAuthSettings) string
|
BuildLoginURL(oauthSettings *OAuthSettings) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SwarmStackManager represents a service to manage Swarm stacks
|
// SwarmStackManager represents a service to manage Swarm stacks
|
||||||
|
|
Loading…
Reference in New Issue