fix(teams): fix data-race in teamCreate() BE-11210 (#12196)

pull/12203/head
andres-portainer 3 months ago committed by GitHub
parent 753150e03c
commit 280ca22aeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -19,8 +19,7 @@ type Service struct {
// NewService creates a new instance of a service.
func NewService(connection portainer.Connection) (*Service, error) {
err := connection.SetServiceName(BucketName)
if err != nil {
if err := connection.SetServiceName(BucketName); err != nil {
return nil, err
}
@ -32,6 +31,16 @@ func NewService(connection portainer.Connection) (*Service, error) {
}, nil
}
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]{
Bucket: BucketName,
Connection: service.Connection,
Tx: tx,
},
}
}
// TeamByName returns a team by name.
func (service *Service) TeamByName(name string) (*portainer.Team, error) {
var t portainer.Team

@ -0,0 +1,48 @@
package team
import (
"errors"
"strings"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
dserrors "github.com/portainer/portainer/api/dataservices/errors"
)
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]
}
// TeamByName returns a team by name.
func (service ServiceTx) TeamByName(name string) (*portainer.Team, error) {
var t portainer.Team
err := service.Tx.GetAll(
BucketName,
&portainer.Team{},
dataservices.FirstFn(&t, func(e portainer.Team) bool {
return strings.EqualFold(e.Name, name)
}),
)
if errors.Is(err, dataservices.ErrStop) {
return &t, nil
}
if err == nil {
return nil, dserrors.ErrObjectNotFound
}
return nil, err
}
// CreateTeam creates a new Team.
func (service ServiceTx) Create(team *portainer.Team) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
team.ID = portainer.TeamID(id)
return int(team.ID), team
},
)
}

@ -402,7 +402,6 @@ type storeExport struct {
}
func (store *Store) Export(filename string) (err error) {
backup := storeExport{}
if c, err := store.CustomTemplate().ReadAll(); err != nil {
@ -606,6 +605,7 @@ func (store *Store) Export(filename string) (err error) {
if err != nil {
return err
}
return os.WriteFile(filename, b, 0600)
}

@ -80,7 +80,10 @@ func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService {
return tx.store.TeamMembershipService.Tx(tx.tx)
}
func (tx *StoreTx) Team() dataservices.TeamService { return nil }
func (tx *StoreTx) Team() dataservices.TeamService {
return tx.store.TeamService.Tx(tx.tx)
}
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
func (tx *StoreTx) User() dataservices.UserService {

@ -5,6 +5,7 @@ import (
"net/http"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
@ -23,6 +24,7 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
if govalidator.IsNull(payload.Name) {
return errors.New("Invalid team name")
}
return nil
}
@ -43,26 +45,42 @@ func (payload *teamCreatePayload) Validate(r *http.Request) error {
// @router /teams [post]
func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload teamCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil {
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
team, err := handler.DataStore.Team().TeamByName(payload.Name)
if err != nil && !handler.DataStore.IsErrObjectNotFound(err) {
return httperror.InternalServerError("Unable to retrieve teams from the database", err)
var team *portainer.Team
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
var err error
team, err = createTeam(tx, payload)
return err
}); err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err)
}
return response.JSON(w, team)
}
func createTeam(tx dataservices.DataStoreTx, payload teamCreatePayload) (*portainer.Team, error) {
team, err := tx.Team().TeamByName(payload.Name)
if err != nil && !tx.IsErrObjectNotFound(err) {
return nil, httperror.InternalServerError("Unable to retrieve teams from the database", err)
}
if team != nil {
return httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
return nil, httperror.Conflict("A team with the same name already exists", errors.New("Team already exists"))
}
team = &portainer.Team{
Name: payload.Name,
}
team = &portainer.Team{Name: payload.Name}
err = handler.DataStore.Team().Create(team)
if err != nil {
return httperror.InternalServerError("Unable to persist the team inside the database", err)
if err := tx.Team().Create(team); err != nil {
return nil, httperror.InternalServerError("Unable to persist the team inside the database", err)
}
for _, teamLeader := range payload.TeamLeaders {
@ -72,11 +90,10 @@ func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *http
Role: portainer.TeamLeader,
}
err = handler.DataStore.TeamMembership().Create(membership)
if err != nil {
return httperror.InternalServerError("Unable to persist team leadership inside the database", err)
if err := tx.TeamMembership().Create(membership); err != nil {
return nil, httperror.InternalServerError("Unable to persist team leadership inside the database", err)
}
}
return response.JSON(w, team)
return team, nil
}

@ -0,0 +1,65 @@
package teams
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestConcurrentTeamCreation(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, false)
h := &Handler{
DataStore: store,
}
tcp := teamCreatePayload{
Name: "portainer",
}
m, err := json.Marshal(tcp)
require.NoError(t, err)
errGroup := &errgroup.Group{}
n := 100
for i := 0; i < n; i++ {
errGroup.Go(func() error {
req, err := http.NewRequest(http.MethodPost, "/teams", bytes.NewReader(m))
if err != nil {
return err
}
if err := h.teamCreate(httptest.NewRecorder(), req); err != nil {
return err
}
return nil
})
}
err = errGroup.Wait()
require.Error(t, err)
teams, err := store.Team().ReadAll()
require.NotEmpty(t, teams)
require.NoError(t, err)
teamCreated := false
for _, team := range teams {
if team.Name == tcp.Name {
require.False(t, teamCreated)
teamCreated = true
}
}
require.True(t, teamCreated)
}
Loading…
Cancel
Save