mirror of https://github.com/portainer/portainer
fix(teams): fix data-race in teamCreate() BE-11210 (#12195)
parent
80e607ab30
commit
7a176cf284
|
@ -19,8 +19,7 @@ type Service struct {
|
|||
|
||||
// NewService creates a new instance of a service.
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
if err := connection.SetServiceName(BucketName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -32,6 +31,16 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.Connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TeamByName returns a team by name.
|
||||
func (service *Service) TeamByName(name string) (*portainer.Team, error) {
|
||||
var t portainer.Team
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
package team
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]
|
||||
}
|
||||
|
||||
// TeamByName returns a team by name.
|
||||
func (service ServiceTx) TeamByName(name string) (*portainer.Team, error) {
|
||||
var t portainer.Team
|
||||
|
||||
err := service.Tx.GetAll(
|
||||
BucketName,
|
||||
&portainer.Team{},
|
||||
dataservices.FirstFn(&t, func(e portainer.Team) bool {
|
||||
return strings.EqualFold(e.Name, name)
|
||||
}),
|
||||
)
|
||||
|
||||
if errors.Is(err, dataservices.ErrStop) {
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return nil, dserrors.ErrObjectNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// CreateTeam creates a new Team.
|
||||
func (service ServiceTx) Create(team *portainer.Team) error {
|
||||
return service.Tx.CreateObject(
|
||||
BucketName,
|
||||
func(id uint64) (int, any) {
|
||||
team.ID = portainer.TeamID(id)
|
||||
return int(team.ID), team
|
||||
},
|
||||
)
|
||||
}
|
|
@ -389,7 +389,6 @@ type storeExport struct {
|
|||
}
|
||||
|
||||
func (store *Store) Export(filename string) (err error) {
|
||||
|
||||
backup := storeExport{}
|
||||
|
||||
if c, err := store.CustomTemplate().ReadAll(); err != nil {
|
||||
|
@ -593,6 +592,7 @@ func (store *Store) Export(filename string) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, b, 0600)
|
||||
}
|
||||
|
||||
|
|
|
@ -82,7 +82,9 @@ func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService {
|
|||
return tx.store.TeamMembershipService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) Team() dataservices.TeamService { return nil }
|
||||
func (tx *StoreTx) Team() dataservices.TeamService {
|
||||
return tx.store.TeamService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
@ -21,6 +22,7 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
|
|||
if len(payload.Name) == 0 {
|
||||
return errors.New("Invalid team name")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -41,26 +43,42 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
|
|||
// @router /teams [post]
|
||||
func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload teamCreatePayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
team, err := handler.DataStore.Team().TeamByName(payload.Name)
|
||||
if err != nil && !handler.DataStore.IsErrObjectNotFound(err) {
|
||||
return httperror.InternalServerError("Unable to retrieve teams from the database", err)
|
||||
var team *portainer.Team
|
||||
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
team, err = createTeam(tx, payload)
|
||||
|
||||
return err
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.JSON(w, team)
|
||||
}
|
||||
|
||||
func createTeam(tx dataservices.DataStoreTx, payload teamCreatePayload) (*portainer.Team, error) {
|
||||
team, err := tx.Team().TeamByName(payload.Name)
|
||||
if err != nil && !tx.IsErrObjectNotFound(err) {
|
||||
return nil, httperror.InternalServerError("Unable to retrieve teams from the database", err)
|
||||
}
|
||||
if team != nil {
|
||||
return httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
|
||||
return nil, httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
|
||||
}
|
||||
|
||||
team = &portainer.Team{
|
||||
Name: payload.Name,
|
||||
}
|
||||
team = &portainer.Team{Name: payload.Name}
|
||||
|
||||
err = handler.DataStore.Team().Create(team)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist the team inside the database", err)
|
||||
if err := tx.Team().Create(team); err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to persist the team inside the database", err)
|
||||
}
|
||||
|
||||
for _, teamLeader := range payload.TeamLeaders {
|
||||
|
@ -70,11 +88,10 @@ func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *http
|
|||
Role: portainer.TeamLeader,
|
||||
}
|
||||
|
||||
err = handler.DataStore.TeamMembership().Create(membership)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist team leadership inside the database", err)
|
||||
if err := tx.TeamMembership().Create(membership); err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to persist team leadership inside the database", err)
|
||||
}
|
||||
}
|
||||
|
||||
return response.JSON(w, team)
|
||||
return team, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
package teams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestConcurrentTeamCreation(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
h := &Handler{
|
||||
DataStore: store,
|
||||
}
|
||||
|
||||
tcp := teamCreatePayload{
|
||||
Name: "portainer",
|
||||
}
|
||||
|
||||
m, err := json.Marshal(tcp)
|
||||
require.NoError(t, err)
|
||||
|
||||
errGroup := &errgroup.Group{}
|
||||
|
||||
n := 100
|
||||
|
||||
for range n {
|
||||
errGroup.Go(func() error {
|
||||
req, err := http.NewRequest(http.MethodPost, "/teams", bytes.NewReader(m))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.teamCreate(httptest.NewRecorder(), req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = errGroup.Wait()
|
||||
require.Error(t, err)
|
||||
|
||||
teams, err := store.Team().ReadAll()
|
||||
require.NotEmpty(t, teams)
|
||||
require.NoError(t, err)
|
||||
|
||||
teamCreated := false
|
||||
for _, team := range teams {
|
||||
if team.Name == tcp.Name {
|
||||
require.False(t, teamCreated)
|
||||
teamCreated = true
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, teamCreated)
|
||||
}
|
Loading…
Reference in New Issue