refactor(ouath): use oauth2 library to get token

pull/2749/head
Chaim Lev Ari 2019-01-18 10:56:16 +02:00
parent 60040e90d0
commit 46e8f10aea
3 changed files with 18 additions and 93 deletions

View File

@ -49,17 +49,17 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h
return &httperror.HandlerError{http.StatusForbidden, "Unable to acquire username", portainer.ErrUnauthorized} return &httperror.HandlerError{http.StatusForbidden, "Unable to acquire username", portainer.ErrUnauthorized}
} }
u, err := handler.UserService.UserByUsername(username) user, err := handler.UserService.UserByUsername(username)
if err != nil && err != portainer.ErrObjectNotFound { if err != nil && err != portainer.ErrObjectNotFound {
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to retrieve a user with the specified username from the database", err} return &httperror.HandlerError{http.StatusInternalServerError, "Unable to retrieve a user with the specified username from the database", err}
} }
if u == nil && !settings.OAuthSettings.OAuthAutoCreateUsers { if user == nil && !settings.OAuthSettings.OAuthAutoCreateUsers {
return &httperror.HandlerError{http.StatusForbidden, "Unregistered account", portainer.ErrUnauthorized} return &httperror.HandlerError{http.StatusForbidden, "Unregistered account", portainer.ErrUnauthorized}
} }
if u == nil { if user == nil {
user := &portainer.User{ user = &portainer.User{
Username: username, Username: username,
Role: portainer.StandardUserRole, Role: portainer.StandardUserRole,
} }
@ -69,10 +69,9 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist user inside the database", err} return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist user inside the database", err}
} }
return handler.writeToken(w, user)
} }
return handler.writeToken(w, u) return handler.writeToken(w, user)
} }
func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
@ -85,7 +84,7 @@ func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *http
return &httperror.HandlerError{http.StatusForbidden, "OAuth authentication is disabled", err} return &httperror.HandlerError{http.StatusForbidden, "OAuth authentication is disabled", err}
} }
url := handler.OAuthService.BuildLoginURL(settings.OAuthSettings) url := handler.OAuthService.BuildLoginURL(&settings.OAuthSettings)
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return nil return nil
} }

View File

@ -1,12 +1,10 @@
package oauth package oauth
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log"
"mime" "mime"
"net/http" "net/http"
"net/url" "net/url"
@ -26,84 +24,9 @@ type Service struct{}
// GetAccessToken takes an access code and exchanges it for an access token from portainer OAuthSettings token endpoint // GetAccessToken takes an access code and exchanges it for an access token from portainer OAuthSettings token endpoint
func (*Service) GetAccessToken(code string, settings *portainer.OAuthSettings) (string, error) { func (*Service) GetAccessToken(code string, settings *portainer.OAuthSettings) (string, error) {
v := url.Values{} config := buildConfig(settings)
v.Set("client_id", settings.ClientID) token, err := config.Exchange(context.Background(), code)
v.Set("client_secret", settings.ClientSecret) return token.AccessToken, err
v.Set("redirect_uri", settings.RedirectURI)
v.Set("code", code)
v.Set("grant_type", "authorization_code")
req, err := http.NewRequest("POST", settings.AccessTokenURI, strings.NewReader(v.Encode()))
if err != nil {
return "", err
}
client := &http.Client{}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r, err := client.Do(req)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if r.StatusCode != http.StatusOK {
type ErrorMessage struct {
Message string
Type string
Code int
}
type ErrorResponse struct {
Error ErrorMessage
}
var response ErrorResponse
if err = json.Unmarshal(body, &response); err != nil {
// report also error
log.Printf("[Error] - Failed parsing error body: %v", err)
return "", errors.New("oauth2: cannot fetch token")
}
return "", errors.New(response.Error.Message)
}
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
if content == "application/x-www-form-urlencoded" || content == "text/plain" {
values, err := url.ParseQuery(string(body))
if err != nil {
return "", err
}
token := values.Get("access_token")
log.Printf("[DEBUG] - returned body %v", values)
if token == "" {
log.Printf("[DEBUG] - access token returned empty - %v", values)
return "", errors.New("oauth2: cannot fetch token")
}
return token, nil
}
type tokenJSON struct {
AccessToken string `json:"access_token"`
}
var tj tokenJSON
if err = json.Unmarshal(body, &tj); err != nil {
return "", err
}
token := tj.AccessToken
if token == "" {
log.Printf("[DEBUG] - access token returned empty - %v with status code", string(body), r.StatusCode)
return "", errors.New("oauth2: cannot fetch token")
}
return token, nil
} }
// GetUsername takes a token and retrieves the portainer OAuthSettings user identifier from resource server. // GetUsername takes a token and retrieves the portainer OAuthSettings user identifier from resource server.
@ -167,19 +90,22 @@ func (*Service) GetUsername(token string, settings *portainer.OAuthSettings) (st
} }
// BuildLoginURL creates a login url for the oauth provider // BuildLoginURL creates a login url for the oauth provider
func (*Service) BuildLoginURL(oauthSettings portainer.OAuthSettings) string { func (*Service) BuildLoginURL(oauthSettings *portainer.OAuthSettings) string {
oauthConfig := buildConfig(oauthSettings)
return oauthConfig.AuthCodeURL("portainer")
}
func buildConfig(oauthSettings *portainer.OAuthSettings) *oauth2.Config {
endpoint := oauth2.Endpoint{ endpoint := oauth2.Endpoint{
AuthURL: oauthSettings.AuthorizationURI, AuthURL: oauthSettings.AuthorizationURI,
TokenURL: oauthSettings.AccessTokenURI, TokenURL: oauthSettings.AccessTokenURI,
} }
oauthConfig := &oauth2.Config{ return &oauth2.Config{
ClientID: oauthSettings.ClientID, ClientID: oauthSettings.ClientID,
ClientSecret: oauthSettings.ClientSecret, ClientSecret: oauthSettings.ClientSecret,
Endpoint: endpoint, Endpoint: endpoint,
RedirectURL: oauthSettings.RedirectURI, RedirectURL: oauthSettings.RedirectURI,
Scopes: strings.Split(oauthSettings.Scopes, ","), Scopes: strings.Split(oauthSettings.Scopes, ","),
} }
return oauthConfig.AuthCodeURL("portainer")
} }

View File

@ -766,7 +766,7 @@ type (
OAuthService interface { OAuthService interface {
GetAccessToken(code string, settings *OAuthSettings) (string, error) GetAccessToken(code string, settings *OAuthSettings) (string, error)
GetUsername(token string, settings *OAuthSettings) (string, error) GetUsername(token string, settings *OAuthSettings) (string, error)
BuildLoginURL(oauthSettings OAuthSettings) string BuildLoginURL(oauthSettings *OAuthSettings) string
} }
// SwarmStackManager represents a service to manage Swarm stacks // SwarmStackManager represents a service to manage Swarm stacks