mirror of https://github.com/portainer/portainer
fix(users): fix data-race in userCreate() BE-11209 (#12194)
parent
280ca22aeb
commit
2d5c834590
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
"github.com/portainer/portainer/api/dataservices"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
@ -54,12 +55,33 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http
|
||||||
return httperror.BadRequest("Invalid request payload", err)
|
return httperror.BadRequest("Invalid request payload", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := handler.DataStore.User().UserByUsername(payload.Username)
|
var user *portainer.User
|
||||||
if err != nil && !handler.DataStore.IsErrObjectNotFound(err) {
|
|
||||||
return httperror.InternalServerError("Unable to retrieve users from the database", err)
|
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||||
|
var err error
|
||||||
|
user, err = handler.createUser(tx, payload)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}); err != nil {
|
||||||
|
var httpErr *httperror.HandlerError
|
||||||
|
if errors.As(err, &httpErr) {
|
||||||
|
return httpErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return httperror.InternalServerError("Unexpected error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return response.JSON(w, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler *Handler) createUser(tx dataservices.DataStoreTx, payload userCreatePayload) (*portainer.User, error) {
|
||||||
|
user, err := tx.User().UserByUsername(payload.Username)
|
||||||
|
if err != nil && !tx.IsErrObjectNotFound(err) {
|
||||||
|
return nil, httperror.InternalServerError("Unable to retrieve users from the database", err)
|
||||||
|
}
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return httperror.Conflict("Another user with the same username already exists", errUserAlreadyExists)
|
return nil, httperror.Conflict("Another user with the same username already exists", errUserAlreadyExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
user = &portainer.User{
|
user = &portainer.User{
|
||||||
|
@ -67,33 +89,33 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http
|
||||||
Role: portainer.UserRole(payload.Role),
|
Role: portainer.UserRole(payload.Role),
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := handler.DataStore.Settings().Settings()
|
settings, err := tx.Settings().Settings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperror.InternalServerError("Unable to retrieve settings from the database", err)
|
return nil, httperror.InternalServerError("Unable to retrieve settings from the database", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// when ldap/oauth is on, can only add users without password
|
// When LDAP/OAuth is on, can only add users without password
|
||||||
if (settings.AuthenticationMethod == portainer.AuthenticationLDAP || settings.AuthenticationMethod == portainer.AuthenticationOAuth) && payload.Password != "" {
|
if (settings.AuthenticationMethod == portainer.AuthenticationLDAP || settings.AuthenticationMethod == portainer.AuthenticationOAuth) && payload.Password != "" {
|
||||||
errMsg := "A user with password can not be created when authentication method is Oauth or LDAP"
|
errMsg := "a user with password can not be created when authentication method is Oauth or LDAP"
|
||||||
return httperror.BadRequest(errMsg, errors.New(errMsg))
|
return nil, httperror.BadRequest(errMsg, errors.New(errMsg))
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.AuthenticationMethod == portainer.AuthenticationInternal {
|
if settings.AuthenticationMethod == portainer.AuthenticationInternal {
|
||||||
if !handler.passwordStrengthChecker.Check(payload.Password) {
|
if !handler.passwordStrengthChecker.Check(payload.Password) {
|
||||||
return httperror.BadRequest("Password does not meet the requirements", nil)
|
return nil, httperror.BadRequest("Password does not meet the requirements", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
user.Password, err = handler.CryptoService.Hash(payload.Password)
|
user.Password, err = handler.CryptoService.Hash(payload.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperror.InternalServerError("Unable to hash user password", errCryptoHashFailure)
|
return nil, httperror.InternalServerError("Unable to hash user password", errCryptoHashFailure)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = handler.DataStore.User().Create(user)
|
if err := tx.User().Create(user); err != nil {
|
||||||
if err != nil {
|
return nil, httperror.InternalServerError("Unable to persist user inside the database", err)
|
||||||
return httperror.InternalServerError("Unable to persist user inside the database", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hideFields(user)
|
hideFields(user)
|
||||||
return response.JSON(w, user)
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
"github.com/portainer/portainer/api/crypto"
|
||||||
|
"github.com/portainer/portainer/api/datastore"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockPasswordStrengthChecker struct{}
|
||||||
|
|
||||||
|
func (m *mockPasswordStrengthChecker) Check(string) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentUserCreation(t *testing.T) {
|
||||||
|
_, store := datastore.MustNewTestStore(t, true, false)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
passwordStrengthChecker: &mockPasswordStrengthChecker{},
|
||||||
|
CryptoService: &crypto.Service{},
|
||||||
|
DataStore: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
ucp := userCreatePayload{
|
||||||
|
Username: "portainer",
|
||||||
|
Password: "password",
|
||||||
|
Role: int(portainer.AdministratorRole),
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := json.Marshal(ucp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errGroup := &errgroup.Group{}
|
||||||
|
|
||||||
|
n := 100
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "/users", bytes.NewReader(m))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userCreate(httptest.NewRecorder(), req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errGroup.Wait()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
users, err := store.User().ReadAll()
|
||||||
|
require.NotEmpty(t, users)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userCreated := false
|
||||||
|
for _, u := range users {
|
||||||
|
if u.Username == ucp.Username {
|
||||||
|
require.False(t, userCreated)
|
||||||
|
userCreated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, userCreated)
|
||||||
|
}
|
Loading…
Reference in New Issue