diff --git a/api/dataservices/team/team.go b/api/dataservices/team/team.go index e2c9089f2..cc40feffc 100644 --- a/api/dataservices/team/team.go +++ b/api/dataservices/team/team.go @@ -19,8 +19,7 @@ type Service struct { // NewService creates a new instance of a service. func NewService(connection portainer.Connection) (*Service, error) { - err := connection.SetServiceName(BucketName) - if err != nil { + if err := connection.SetServiceName(BucketName); err != nil { return nil, err } @@ -32,6 +31,16 @@ func NewService(connection portainer.Connection) (*Service, error) { }, nil } +func (service *Service) Tx(tx portainer.Transaction) ServiceTx { + return ServiceTx{ + BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]{ + Bucket: BucketName, + Connection: service.Connection, + Tx: tx, + }, + } +} + // TeamByName returns a team by name. func (service *Service) TeamByName(name string) (*portainer.Team, error) { var t portainer.Team diff --git a/api/dataservices/team/tx.go b/api/dataservices/team/tx.go new file mode 100644 index 000000000..4887aa4b1 --- /dev/null +++ b/api/dataservices/team/tx.go @@ -0,0 +1,48 @@ +package team + +import ( + "errors" + "strings" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + dserrors "github.com/portainer/portainer/api/dataservices/errors" +) + +type ServiceTx struct { + dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID] +} + +// TeamByName returns a team by name. +func (service ServiceTx) TeamByName(name string) (*portainer.Team, error) { + var t portainer.Team + + err := service.Tx.GetAll( + BucketName, + &portainer.Team{}, + dataservices.FirstFn(&t, func(e portainer.Team) bool { + return strings.EqualFold(e.Name, name) + }), + ) + + if errors.Is(err, dataservices.ErrStop) { + return &t, nil + } + + if err == nil { + return nil, dserrors.ErrObjectNotFound + } + + return nil, err +} + +// CreateTeam creates a new Team. +func (service ServiceTx) Create(team *portainer.Team) error { + return service.Tx.CreateObject( + BucketName, + func(id uint64) (int, any) { + team.ID = portainer.TeamID(id) + return int(team.ID), team + }, + ) +} diff --git a/api/datastore/services.go b/api/datastore/services.go index 802989d3d..9670a61be 100644 --- a/api/datastore/services.go +++ b/api/datastore/services.go @@ -402,7 +402,6 @@ type storeExport struct { } func (store *Store) Export(filename string) (err error) { - backup := storeExport{} if c, err := store.CustomTemplate().ReadAll(); err != nil { @@ -606,6 +605,7 @@ func (store *Store) Export(filename string) (err error) { if err != nil { return err } + return os.WriteFile(filename, b, 0600) } diff --git a/api/datastore/services_tx.go b/api/datastore/services_tx.go index 0968d18cc..a5646a550 100644 --- a/api/datastore/services_tx.go +++ b/api/datastore/services_tx.go @@ -80,7 +80,10 @@ func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService { return tx.store.TeamMembershipService.Tx(tx.tx) } -func (tx *StoreTx) Team() dataservices.TeamService { return nil } +func (tx *StoreTx) Team() dataservices.TeamService { + return tx.store.TeamService.Tx(tx.tx) +} + func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil } func (tx *StoreTx) User() dataservices.UserService { diff --git a/api/http/handler/teams/team_create.go b/api/http/handler/teams/team_create.go index 50f4be323..7ab5e8a37 100644 --- a/api/http/handler/teams/team_create.go +++ b/api/http/handler/teams/team_create.go @@ -5,6 +5,7 @@ import ( "net/http" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" @@ -23,6 +24,7 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error { if govalidator.IsNull(payload.Name) { return errors.New("Invalid team name") } + return nil } @@ -43,26 +45,42 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error { // @router /teams [post] func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload teamCreatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } - team, err := handler.DataStore.Team().TeamByName(payload.Name) - if err != nil && !handler.DataStore.IsErrObjectNotFound(err) { - return httperror.InternalServerError("Unable to retrieve teams from the database", err) + var team *portainer.Team + + if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + var err error + team, err = createTeam(tx, payload) + + return err + }); err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.JSON(w, team) +} + +func createTeam(tx dataservices.DataStoreTx, payload teamCreatePayload) (*portainer.Team, error) { + team, err := tx.Team().TeamByName(payload.Name) + if err != nil && !tx.IsErrObjectNotFound(err) { + return nil, httperror.InternalServerError("Unable to retrieve teams from the database", err) } if team != nil { - return httperror.Conflict("A team with the same name already exists", errors.New("Team already exists")) + return nil, httperror.Conflict("A team with the same name already exists", errors.New("Team already exists")) } - team = &portainer.Team{ - Name: payload.Name, - } + team = &portainer.Team{Name: payload.Name} - err = handler.DataStore.Team().Create(team) - if err != nil { - return httperror.InternalServerError("Unable to persist the team inside the database", err) + if err := tx.Team().Create(team); err != nil { + return nil, httperror.InternalServerError("Unable to persist the team inside the database", err) } for _, teamLeader := range payload.TeamLeaders { @@ -72,11 +90,10 @@ func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *http Role: portainer.TeamLeader, } - err = handler.DataStore.TeamMembership().Create(membership) - if err != nil { - return httperror.InternalServerError("Unable to persist team leadership inside the database", err) + if err := tx.TeamMembership().Create(membership); err != nil { + return nil, httperror.InternalServerError("Unable to persist team leadership inside the database", err) } } - return response.JSON(w, team) + return team, nil } diff --git a/api/http/handler/teams/team_create_test.go b/api/http/handler/teams/team_create_test.go new file mode 100644 index 000000000..e56b7de5d --- /dev/null +++ b/api/http/handler/teams/team_create_test.go @@ -0,0 +1,65 @@ +package teams + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/portainer/portainer/api/datastore" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestConcurrentTeamCreation(t *testing.T) { + _, store := datastore.MustNewTestStore(t, true, false) + + h := &Handler{ + DataStore: store, + } + + tcp := teamCreatePayload{ + Name: "portainer", + } + + m, err := json.Marshal(tcp) + require.NoError(t, err) + + errGroup := &errgroup.Group{} + + n := 100 + + for i := 0; i < n; i++ { + errGroup.Go(func() error { + req, err := http.NewRequest(http.MethodPost, "/teams", bytes.NewReader(m)) + if err != nil { + return err + } + + if err := h.teamCreate(httptest.NewRecorder(), req); err != nil { + return err + } + + return nil + }) + } + + err = errGroup.Wait() + require.Error(t, err) + + teams, err := store.Team().ReadAll() + require.NotEmpty(t, teams) + require.NoError(t, err) + + teamCreated := false + for _, team := range teams { + if team.Name == tcp.Name { + require.False(t, teamCreated) + teamCreated = true + } + } + + require.True(t, teamCreated) +}