From 92f338e0cd0de66a21bb856945b46e4b56cab911 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Thu, 5 Sep 2024 22:28:04 -0300 Subject: [PATCH] fix(users): fix data-race in userCreate() BE-11209 (#12193) --- api/http/handler/users/user_create.go | 50 ++++++++++---- api/http/handler/users/user_create_test.go | 77 ++++++++++++++++++++++ 2 files changed, 113 insertions(+), 14 deletions(-) create mode 100644 api/http/handler/users/user_create_test.go diff --git a/api/http/handler/users/user_create.go b/api/http/handler/users/user_create.go index 627e00d15..a5afb1b8b 100644 --- a/api/http/handler/users/user_create.go +++ b/api/http/handler/users/user_create.go @@ -6,6 +6,7 @@ import ( "strings" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" @@ -53,12 +54,33 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http return httperror.BadRequest("Invalid request payload", err) } - user, err := handler.DataStore.User().UserByUsername(payload.Username) - if err != nil && !handler.DataStore.IsErrObjectNotFound(err) { - return httperror.InternalServerError("Unable to retrieve users from the database", err) + var user *portainer.User + + if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + var err error + user, err = handler.createUser(tx, payload) + + return err + }); err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) } + + return response.JSON(w, user) +} + +func (handler *Handler) createUser(tx dataservices.DataStoreTx, payload userCreatePayload) (*portainer.User, error) { + user, err := tx.User().UserByUsername(payload.Username) + if err != nil && !tx.IsErrObjectNotFound(err) { + return nil, httperror.InternalServerError("Unable to retrieve users from the database", err) + } + if user != nil { - return httperror.Conflict("Another user with the same username already exists", errUserAlreadyExists) + return nil, httperror.Conflict("Another user with the same username already exists", errUserAlreadyExists) } user = &portainer.User{ @@ -66,33 +88,33 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http Role: portainer.UserRole(payload.Role), } - settings, err := handler.DataStore.Settings().Settings() + settings, err := tx.Settings().Settings() if err != nil { - return httperror.InternalServerError("Unable to retrieve settings from the database", err) + return nil, httperror.InternalServerError("Unable to retrieve settings from the database", err) } - // when ldap/oauth is on, can only add users without password + // When LDAP/OAuth is on, can only add users without password if (settings.AuthenticationMethod == portainer.AuthenticationLDAP || settings.AuthenticationMethod == portainer.AuthenticationOAuth) && payload.Password != "" { - errMsg := "A user with password can not be created when authentication method is Oauth or LDAP" - return httperror.BadRequest(errMsg, errors.New(errMsg)) + errMsg := "a user with password can not be created when authentication method is Oauth or LDAP" + return nil, httperror.BadRequest(errMsg, errors.New(errMsg)) } if settings.AuthenticationMethod == portainer.AuthenticationInternal { if !handler.passwordStrengthChecker.Check(payload.Password) { - return httperror.BadRequest("Password does not meet the requirements", nil) + return nil, httperror.BadRequest("Password does not meet the requirements", nil) } user.Password, err = handler.CryptoService.Hash(payload.Password) if err != nil { - return httperror.InternalServerError("Unable to hash user password", errCryptoHashFailure) + return nil, httperror.InternalServerError("Unable to hash user password", errCryptoHashFailure) } } - if err := handler.DataStore.User().Create(user); err != nil { - return httperror.InternalServerError("Unable to persist user inside the database", err) + if err := tx.User().Create(user); err != nil { + return nil, httperror.InternalServerError("Unable to persist user inside the database", err) } hideFields(user) - return response.JSON(w, user) + return user, nil } diff --git a/api/http/handler/users/user_create_test.go b/api/http/handler/users/user_create_test.go new file mode 100644 index 000000000..98eb933bf --- /dev/null +++ b/api/http/handler/users/user_create_test.go @@ -0,0 +1,77 @@ +package users + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/api/datastore" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type mockPasswordStrengthChecker struct{} + +func (m *mockPasswordStrengthChecker) Check(string) bool { + return true +} + +func TestConcurrentUserCreation(t *testing.T) { + _, store := datastore.MustNewTestStore(t, true, false) + + h := &Handler{ + passwordStrengthChecker: &mockPasswordStrengthChecker{}, + CryptoService: &crypto.Service{}, + DataStore: store, + } + + ucp := userCreatePayload{ + Username: "portainer", + Password: "password", + Role: int(portainer.AdministratorRole), + } + + m, err := json.Marshal(ucp) + require.NoError(t, err) + + errGroup := &errgroup.Group{} + + n := 100 + + for range n { + errGroup.Go(func() error { + req, err := http.NewRequest(http.MethodPost, "/users", bytes.NewReader(m)) + if err != nil { + return err + } + + if err := h.userCreate(httptest.NewRecorder(), req); err != nil { + return err + } + + return nil + }) + } + + err = errGroup.Wait() + require.Error(t, err) + + users, err := store.User().ReadAll() + require.NotEmpty(t, users) + require.NoError(t, err) + + userCreated := false + for _, u := range users { + if u.Username == ucp.Username { + require.False(t, userCreated) + userCreated = true + } + } + + require.True(t, userCreated) +}