mirror of https://github.com/portainer/portainer
fix(teams): fix data-race in teamCreate() BE-11210 (#12195)
parent
80e607ab30
commit
7a176cf284
|
@ -19,8 +19,7 @@ type Service struct {
|
||||||
|
|
||||||
// NewService creates a new instance of a service.
|
// NewService creates a new instance of a service.
|
||||||
func NewService(connection portainer.Connection) (*Service, error) {
|
func NewService(connection portainer.Connection) (*Service, error) {
|
||||||
err := connection.SetServiceName(BucketName)
|
if err := connection.SetServiceName(BucketName); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,6 +31,16 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||||
|
return ServiceTx{
|
||||||
|
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]{
|
||||||
|
Bucket: BucketName,
|
||||||
|
Connection: service.Connection,
|
||||||
|
Tx: tx,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TeamByName returns a team by name.
|
// TeamByName returns a team by name.
|
||||||
func (service *Service) TeamByName(name string) (*portainer.Team, error) {
|
func (service *Service) TeamByName(name string) (*portainer.Team, error) {
|
||||||
var t portainer.Team
|
var t portainer.Team
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
package team
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
"github.com/portainer/portainer/api/dataservices"
|
||||||
|
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServiceTx struct {
|
||||||
|
dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// TeamByName returns a team by name.
|
||||||
|
func (service ServiceTx) TeamByName(name string) (*portainer.Team, error) {
|
||||||
|
var t portainer.Team
|
||||||
|
|
||||||
|
err := service.Tx.GetAll(
|
||||||
|
BucketName,
|
||||||
|
&portainer.Team{},
|
||||||
|
dataservices.FirstFn(&t, func(e portainer.Team) bool {
|
||||||
|
return strings.EqualFold(e.Name, name)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors.Is(err, dataservices.ErrStop) {
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return nil, dserrors.ErrObjectNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTeam creates a new Team.
|
||||||
|
func (service ServiceTx) Create(team *portainer.Team) error {
|
||||||
|
return service.Tx.CreateObject(
|
||||||
|
BucketName,
|
||||||
|
func(id uint64) (int, any) {
|
||||||
|
team.ID = portainer.TeamID(id)
|
||||||
|
return int(team.ID), team
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
|
@ -389,7 +389,6 @@ type storeExport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Store) Export(filename string) (err error) {
|
func (store *Store) Export(filename string) (err error) {
|
||||||
|
|
||||||
backup := storeExport{}
|
backup := storeExport{}
|
||||||
|
|
||||||
if c, err := store.CustomTemplate().ReadAll(); err != nil {
|
if c, err := store.CustomTemplate().ReadAll(); err != nil {
|
||||||
|
@ -593,6 +592,7 @@ func (store *Store) Export(filename string) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.WriteFile(filename, b, 0600)
|
return os.WriteFile(filename, b, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,9 @@ func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService {
|
||||||
return tx.store.TeamMembershipService.Tx(tx.tx)
|
return tx.store.TeamMembershipService.Tx(tx.tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *StoreTx) Team() dataservices.TeamService { return nil }
|
func (tx *StoreTx) Team() dataservices.TeamService {
|
||||||
|
return tx.store.TeamService.Tx(tx.tx)
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
|
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
"github.com/portainer/portainer/api/dataservices"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
@ -21,6 +22,7 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
|
||||||
if len(payload.Name) == 0 {
|
if len(payload.Name) == 0 {
|
||||||
return errors.New("Invalid team name")
|
return errors.New("Invalid team name")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,26 +43,42 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
|
||||||
// @router /teams [post]
|
// @router /teams [post]
|
||||||
func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||||
var payload teamCreatePayload
|
var payload teamCreatePayload
|
||||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||||
if err != nil {
|
|
||||||
return httperror.BadRequest("Invalid request payload", err)
|
return httperror.BadRequest("Invalid request payload", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
team, err := handler.DataStore.Team().TeamByName(payload.Name)
|
var team *portainer.Team
|
||||||
if err != nil && !handler.DataStore.IsErrObjectNotFound(err) {
|
|
||||||
return httperror.InternalServerError("Unable to retrieve teams from the database", err)
|
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||||
|
var err error
|
||||||
|
team, err = createTeam(tx, payload)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}); err != nil {
|
||||||
|
var httpErr *httperror.HandlerError
|
||||||
|
if errors.As(err, &httpErr) {
|
||||||
|
return httpErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return httperror.InternalServerError("Unexpected error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.JSON(w, team)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTeam(tx dataservices.DataStoreTx, payload teamCreatePayload) (*portainer.Team, error) {
|
||||||
|
team, err := tx.Team().TeamByName(payload.Name)
|
||||||
|
if err != nil && !tx.IsErrObjectNotFound(err) {
|
||||||
|
return nil, httperror.InternalServerError("Unable to retrieve teams from the database", err)
|
||||||
}
|
}
|
||||||
if team != nil {
|
if team != nil {
|
||||||
return httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
|
return nil, httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
|
||||||
}
|
}
|
||||||
|
|
||||||
team = &portainer.Team{
|
team = &portainer.Team{Name: payload.Name}
|
||||||
Name: payload.Name,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = handler.DataStore.Team().Create(team)
|
if err := tx.Team().Create(team); err != nil {
|
||||||
if err != nil {
|
return nil, httperror.InternalServerError("Unable to persist the team inside the database", err)
|
||||||
return httperror.InternalServerError("Unable to persist the team inside the database", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, teamLeader := range payload.TeamLeaders {
|
for _, teamLeader := range payload.TeamLeaders {
|
||||||
|
@ -70,11 +88,10 @@ func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *http
|
||||||
Role: portainer.TeamLeader,
|
Role: portainer.TeamLeader,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = handler.DataStore.TeamMembership().Create(membership)
|
if err := tx.TeamMembership().Create(membership); err != nil {
|
||||||
if err != nil {
|
return nil, httperror.InternalServerError("Unable to persist team leadership inside the database", err)
|
||||||
return httperror.InternalServerError("Unable to persist team leadership inside the database", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return response.JSON(w, team)
|
return team, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
package teams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/portainer/portainer/api/datastore"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConcurrentTeamCreation(t *testing.T) {
|
||||||
|
_, store := datastore.MustNewTestStore(t, true, false)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
DataStore: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp := teamCreatePayload{
|
||||||
|
Name: "portainer",
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := json.Marshal(tcp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errGroup := &errgroup.Group{}
|
||||||
|
|
||||||
|
n := 100
|
||||||
|
|
||||||
|
for range n {
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "/teams", bytes.NewReader(m))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.teamCreate(httptest.NewRecorder(), req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errGroup.Wait()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
teams, err := store.Team().ReadAll()
|
||||||
|
require.NotEmpty(t, teams)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
teamCreated := false
|
||||||
|
for _, team := range teams {
|
||||||
|
if team.Name == tcp.Name {
|
||||||
|
require.False(t, teamCreated)
|
||||||
|
teamCreated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, teamCreated)
|
||||||
|
}
|
Loading…
Reference in New Issue