fix(edgegroups): fix data-race in edgeGroupCreate EE-4435 (#8477)

pull/8498/head
andres-portainer 2023-02-14 15:18:07 -03:00 committed by GitHub
parent e66dea44e3
commit f081631808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 800 additions and 317 deletions

View File

@ -4,10 +4,35 @@ import (
"io" "io"
) )
type ReadTransaction interface {
GetObject(bucketName string, key []byte, object interface{}) error
GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error
}
type Transaction interface {
ReadTransaction
SetServiceName(bucketName string) error
UpdateObject(bucketName string, key []byte, object interface{}) error
DeleteObject(bucketName string, key []byte) error
CreateObject(bucketName string, fn func(uint64) (int, interface{})) error
CreateObjectWithId(bucketName string, id int, obj interface{}) error
CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error
DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error
GetNextIdentifier(bucketName string) int
}
type Connection interface { type Connection interface {
Transaction
Open() error Open() error
Close() error Close() error
UpdateTx(fn func(Transaction) error) error
ViewTx(fn func(Transaction) error) error
// write the db contents to filename as json (the schema needs defining) // write the db contents to filename as json (the schema needs defining)
ExportRaw(filename string) error ExportRaw(filename string) error
@ -21,20 +46,9 @@ type Connection interface {
NeedsEncryptionMigration() (bool, error) NeedsEncryptionMigration() (bool, error)
SetEncrypted(encrypted bool) SetEncrypted(encrypted bool)
SetServiceName(bucketName string) error
GetObject(bucketName string, key []byte, object interface{}) error
UpdateObject(bucketName string, key []byte, object interface{}) error
UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error
DeleteObject(bucketName string, key []byte) error
DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error
GetNextIdentifier(bucketName string) int
CreateObject(bucketName string, fn func(uint64) (int, interface{})) error
CreateObjectWithId(bucketName string, id int, obj interface{}) error
CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error
GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
ConvertToKey(v int) []byte
BackupMetadata() (map[string]interface{}, error) BackupMetadata() (map[string]interface{}, error)
RestoreMetadata(s map[string]interface{}) error RestoreMetadata(s map[string]interface{}) error
UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error
ConvertToKey(v int) []byte
} }

View File

@ -1,7 +1,6 @@
package boltdb package boltdb
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -10,6 +9,7 @@ import (
"path" "path"
"time" "time"
portainer "github.com/portainer/portainer/api"
dserrors "github.com/portainer/portainer/api/dataservices/errors" dserrors "github.com/portainer/portainer/api/dataservices/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -132,9 +132,11 @@ func (connection *DbConnection) Open() error {
if err != nil { if err != nil {
return err return err
} }
db.MaxBatchSize = connection.MaxBatchSize db.MaxBatchSize = connection.MaxBatchSize
db.MaxBatchDelay = connection.MaxBatchDelay db.MaxBatchDelay = connection.MaxBatchDelay
connection.DB = db connection.DB = db
return nil return nil
} }
@ -144,9 +146,30 @@ func (connection *DbConnection) Close() error {
if connection.DB != nil { if connection.DB != nil {
return connection.DB.Close() return connection.DB.Close()
} }
return nil return nil
} }
func (connection *DbConnection) txFn(fn func(portainer.Transaction) error) func(*bolt.Tx) error {
return func(tx *bolt.Tx) error {
return fn(&DbTransaction{conn: connection, tx: tx})
}
}
// UpdateTx executes the given function inside a read-write transaction
func (connection *DbConnection) UpdateTx(fn func(portainer.Transaction) error) error {
if connection.MaxBatchDelay > 0 && connection.MaxBatchSize > 1 {
return connection.Batch(connection.txFn(fn))
}
return connection.Update(connection.txFn(fn))
}
// ViewTx executes the given function inside a read-only transaction
func (connection *DbConnection) ViewTx(fn func(portainer.Transaction) error) error {
return connection.View(connection.txFn(fn))
}
// BackupTo backs up db to a provided writer. // BackupTo backs up db to a provided writer.
// It does hot backup and doesn't block other database reads and writes // It does hot backup and doesn't block other database reads and writes
func (connection *DbConnection) BackupTo(w io.Writer) error { func (connection *DbConnection) BackupTo(w io.Writer) error {
@ -180,34 +203,16 @@ func (connection *DbConnection) ConvertToKey(v int) []byte {
// CreateBucket is a generic function used to create a bucket inside a database. // CreateBucket is a generic function used to create a bucket inside a database.
func (connection *DbConnection) SetServiceName(bucketName string) error { func (connection *DbConnection) SetServiceName(bucketName string) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
_, err := tx.CreateBucketIfNotExists([]byte(bucketName)) return tx.SetServiceName(bucketName)
return err
}) })
} }
// GetObject is a generic function used to retrieve an unmarshalled object from a database. // GetObject is a generic function used to retrieve an unmarshalled object from a database.
func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error { func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error {
var data []byte return connection.ViewTx(func(tx portainer.Transaction) error {
return tx.GetObject(bucketName, key, object)
err := connection.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(bucketName))
value := bucket.Get(key)
if value == nil {
return dserrors.ErrObjectNotFound
}
data = make([]byte, len(value))
copy(data, value)
return nil
}) })
if err != nil {
return err
}
return connection.UnmarshalObjectWithJsoniter(data, object)
} }
func (connection *DbConnection) getEncryptionKey() []byte { func (connection *DbConnection) getEncryptionKey() []byte {
@ -220,14 +225,8 @@ func (connection *DbConnection) getEncryptionKey() []byte {
// UpdateObject is a generic function used to update an object inside a database. // UpdateObject is a generic function used to update an object inside a database.
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error { func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error {
data, err := connection.MarshalObject(object) return connection.UpdateTx(func(tx portainer.Transaction) error {
if err != nil { return tx.UpdateObject(bucketName, key, object)
return err
}
return connection.Batch(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(bucketName))
return bucket.Put(key, data)
}) })
} }
@ -259,34 +258,16 @@ func (connection *DbConnection) UpdateObjectFunc(bucketName string, key []byte,
// DeleteObject is a generic function used to delete an object inside a database. // DeleteObject is a generic function used to delete an object inside a database.
func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error { func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.DeleteObject(bucketName, key)
return bucket.Delete(key)
}) })
} }
// DeleteAllObjects delete all objects where matching() returns (id, ok). // DeleteAllObjects delete all objects where matching() returns (id, ok).
// TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"? // TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"?
func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error { func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.DeleteAllObjects(bucketName, obj, matching)
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := connection.UnmarshalObject(v, &obj)
if err != nil {
return err
}
if id, ok := matching(obj); ok {
err := bucket.Delete(connection.ConvertToKey(id))
if err != nil {
return err
}
}
}
return nil
}) })
} }
@ -294,13 +275,8 @@ func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interfac
func (connection *DbConnection) GetNextIdentifier(bucketName string) int { func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
var identifier int var identifier int
connection.Batch(func(tx *bolt.Tx) error { _ = connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) identifier = tx.GetNextIdentifier(bucketName)
id, err := bucket.NextSequence()
if err != nil {
return err
}
identifier = int(id)
return nil return nil
}) })
@ -309,108 +285,41 @@ func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
// CreateObject creates a new object in the bucket, using the next bucket sequence id // CreateObject creates a new object in the bucket, using the next bucket sequence id
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error { func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.CreateObject(bucketName, fn)
seqId, _ := bucket.NextSequence()
id, obj := fn(seqId)
data, err := connection.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(connection.ConvertToKey(int(id)), data)
}) })
} }
// CreateObjectWithId creates a new object in the bucket, using the specified id // CreateObjectWithId creates a new object in the bucket, using the specified id
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error { func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.CreateObjectWithId(bucketName, id, obj)
data, err := connection.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(connection.ConvertToKey(id), data)
}) })
} }
// CreateObjectWithStringId creates a new object in the bucket, using the specified id // CreateObjectWithStringId creates a new object in the bucket, using the specified id
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error { func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
return connection.Batch(func(tx *bolt.Tx) error { return connection.UpdateTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.CreateObjectWithStringId(bucketName, id, obj)
data, err := connection.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(id, data)
}) })
} }
func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error { func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
err := connection.View(func(tx *bolt.Tx) error { return connection.ViewTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.GetAll(bucketName, obj, append)
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := connection.UnmarshalObject(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}) })
return err
} }
// TODO: decide which Unmarshal to use, and use one... // TODO: decide which Unmarshal to use, and use one...
func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error { func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
err := connection.View(func(tx *bolt.Tx) error { return connection.ViewTx(func(tx portainer.Transaction) error {
bucket := tx.Bucket([]byte(bucketName)) return tx.GetAllWithJsoniter(bucketName, obj, append)
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := connection.UnmarshalObjectWithJsoniter(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}) })
return err
} }
func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error { func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error {
return connection.View(func(tx *bolt.Tx) error { return connection.ViewTx(func(tx portainer.Transaction) error {
cursor := tx.Bucket([]byte(bucketName)).Cursor() return tx.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, append)
for k, v := cursor.Seek(keyPrefix); k != nil && bytes.HasPrefix(k, keyPrefix); k, v = cursor.Next() {
err := connection.UnmarshalObjectWithJsoniter(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}) })
} }

172
api/database/boltdb/tx.go Normal file
View File

@ -0,0 +1,172 @@
package boltdb
import (
"bytes"
dserrors "github.com/portainer/portainer/api/dataservices/errors"
"github.com/rs/zerolog/log"
bolt "go.etcd.io/bbolt"
)
type DbTransaction struct {
conn *DbConnection
tx *bolt.Tx
}
func (tx *DbTransaction) SetServiceName(bucketName string) error {
_, err := tx.tx.CreateBucketIfNotExists([]byte(bucketName))
return err
}
func (tx *DbTransaction) GetObject(bucketName string, key []byte, object interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
value := bucket.Get(key)
if value == nil {
return dserrors.ErrObjectNotFound
}
data := make([]byte, len(value))
copy(data, value)
return tx.conn.UnmarshalObjectWithJsoniter(data, object)
}
func (tx *DbTransaction) UpdateObject(bucketName string, key []byte, object interface{}) error {
data, err := tx.conn.MarshalObject(object)
if err != nil {
return err
}
bucket := tx.tx.Bucket([]byte(bucketName))
return bucket.Put(key, data)
}
func (tx *DbTransaction) DeleteObject(bucketName string, key []byte) error {
bucket := tx.tx.Bucket([]byte(bucketName))
return bucket.Delete(key)
}
func (tx *DbTransaction) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error {
bucket := tx.tx.Bucket([]byte(bucketName))
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := tx.conn.UnmarshalObject(v, &obj)
if err != nil {
return err
}
if id, ok := matching(obj); ok {
err := bucket.Delete(tx.conn.ConvertToKey(id))
if err != nil {
return err
}
}
}
return nil
}
func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
bucket := tx.tx.Bucket([]byte(bucketName))
id, err := bucket.NextSequence()
if err != nil {
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifer")
return 0
}
return int(id)
}
func (tx *DbTransaction) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
bucket := tx.tx.Bucket([]byte(bucketName))
seqId, _ := bucket.NextSequence()
id, obj := fn(seqId)
data, err := tx.conn.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(tx.conn.ConvertToKey(int(id)), data)
}
func (tx *DbTransaction) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
data, err := tx.conn.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(tx.conn.ConvertToKey(id), data)
}
func (tx *DbTransaction) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
data, err := tx.conn.MarshalObject(obj)
if err != nil {
return err
}
return bucket.Put(id, data)
}
func (tx *DbTransaction) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
bucket := tx.tx.Bucket([]byte(bucketName))
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := tx.conn.UnmarshalObject(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}
func (tx *DbTransaction) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
bucket := tx.tx.Bucket([]byte(bucketName))
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
err := tx.conn.UnmarshalObjectWithJsoniter(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}
func (tx *DbTransaction) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error {
cursor := tx.tx.Bucket([]byte(bucketName)).Cursor()
for k, v := cursor.Seek(keyPrefix); k != nil && bytes.HasPrefix(k, keyPrefix); k, v = cursor.Next() {
err := tx.conn.UnmarshalObjectWithJsoniter(v, obj)
if err != nil {
return err
}
obj, err = append(obj)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,126 @@
package boltdb
import (
"errors"
"testing"
portainer "github.com/portainer/portainer/api"
dserrors "github.com/portainer/portainer/api/dataservices/errors"
)
const testBucketName = "test-bucket"
const testId = 1234
type testStruct struct {
Key string
Value string
}
func TestTxs(t *testing.T) {
conn := DbConnection{
Path: t.TempDir(),
}
err := conn.Open()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
// Error propagation
err = conn.UpdateTx(func(tx portainer.Transaction) error {
return errors.New("this is an error")
})
if err == nil {
t.Fatal("an error was expected, got nil instead")
}
// Create an object
newObj := testStruct{
Key: "key",
Value: "value",
}
err = conn.UpdateTx(func(tx portainer.Transaction) error {
err = tx.SetServiceName(testBucketName)
if err != nil {
return err
}
return tx.CreateObjectWithId(testBucketName, testId, newObj)
})
if err != nil {
t.Fatal(err)
}
obj := testStruct{}
err = conn.ViewTx(func(tx portainer.Transaction) error {
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
})
if err != nil {
t.Fatal(err)
}
if obj.Key != newObj.Key || obj.Value != newObj.Value {
t.Fatalf("expected %s:%s, got %s:%s instead", newObj.Key, newObj.Value, obj.Key, obj.Value)
}
// Update an object
updatedObj := testStruct{
Key: "updated-key",
Value: "updated-value",
}
err = conn.UpdateTx(func(tx portainer.Transaction) error {
return tx.UpdateObject(testBucketName, conn.ConvertToKey(testId), &updatedObj)
})
err = conn.ViewTx(func(tx portainer.Transaction) error {
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
})
if err != nil {
t.Fatal(err)
}
if obj.Key != updatedObj.Key || obj.Value != updatedObj.Value {
t.Fatalf("expected %s:%s, got %s:%s instead", updatedObj.Key, updatedObj.Value, obj.Key, obj.Value)
}
// Delete an object
err = conn.UpdateTx(func(tx portainer.Transaction) error {
return tx.DeleteObject(testBucketName, conn.ConvertToKey(testId))
})
if err != nil {
t.Fatal(err)
}
err = conn.ViewTx(func(tx portainer.Transaction) error {
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
})
if err != dserrors.ErrObjectNotFound {
t.Fatal(err)
}
// Get next identifier
err = conn.UpdateTx(func(tx portainer.Transaction) error {
id1 := tx.GetNextIdentifier(testBucketName)
id2 := tx.GetNextIdentifier(testBucketName)
if id1+1 != id2 {
return errors.New("unexpected identifier sequence")
}
return nil
})
if err != nil {
t.Fatal(err)
}
// Try to write in a read transaction
err = conn.ViewTx(func(tx portainer.Transaction) error {
return tx.CreateObjectWithId(testBucketName, testId, newObj)
})
if err == nil {
t.Fatal("an error was expected, got nil instead")
}
}

View File

@ -1,11 +1,7 @@
package edgegroup package edgegroup
import ( import (
"fmt"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/rs/zerolog/log"
) )
// BucketName represents the name of the bucket where this service stores data. // BucketName represents the name of the bucket where this service stores data.
@ -32,47 +28,46 @@ func NewService(connection portainer.Connection) (*Service, error) {
}, nil }, nil
} }
// EdgeGroups return an array containing all the Edge groups. func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
service: service,
tx: tx,
}
}
// EdgeGroups return a slice containing all the Edge groups.
func (service *Service) EdgeGroups() ([]portainer.EdgeGroup, error) { func (service *Service) EdgeGroups() ([]portainer.EdgeGroup, error) {
var groups = make([]portainer.EdgeGroup, 0) var groups []portainer.EdgeGroup
var err error
err := service.connection.GetAllWithJsoniter( err = service.connection.ViewTx(func(tx portainer.Transaction) error {
BucketName, groups, err = service.Tx(tx).EdgeGroups()
&portainer.EdgeGroup{}, return err
func(obj interface{}) (interface{}, error) { })
group, ok := obj.(*portainer.EdgeGroup)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to EdgeGroup object")
return nil, fmt.Errorf("Failed to convert to EdgeGroup object: %s", obj)
}
groups = append(groups, *group)
return &portainer.EdgeGroup{}, nil
})
return groups, err return groups, err
} }
// EdgeGroup returns an Edge group by ID. // EdgeGroup returns an Edge group by ID.
func (service *Service) EdgeGroup(ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) { func (service *Service) EdgeGroup(ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) {
var group portainer.EdgeGroup var group *portainer.EdgeGroup
identifier := service.connection.ConvertToKey(int(ID)) var err error
err := service.connection.GetObject(BucketName, identifier, &group) err = service.connection.ViewTx(func(tx portainer.Transaction) error {
if err != nil { group, err = service.Tx(tx).EdgeGroup(ID)
return nil, err return err
} })
return &group, nil return group, err
} }
// Deprecated: Use UpdateEdgeGroupFunc instead. // UpdateEdgeGroup updates an edge group.
func (service *Service) UpdateEdgeGroup(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error { func (service *Service) UpdateEdgeGroup(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error {
identifier := service.connection.ConvertToKey(int(ID)) identifier := service.connection.ConvertToKey(int(ID))
return service.connection.UpdateObject(BucketName, identifier, group) return service.connection.UpdateObject(BucketName, identifier, group)
} }
// UpdateEdgeGroupFunc updates an edge group inside a transaction avoiding data races. // Deprecated: UpdateEdgeGroupFunc updates an edge group inside a transaction avoiding data races.
func (service *Service) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFunc func(edgeGroup *portainer.EdgeGroup)) error { func (service *Service) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFunc func(edgeGroup *portainer.EdgeGroup)) error {
id := service.connection.ConvertToKey(int(ID)) id := service.connection.ConvertToKey(int(ID))
edgeGroup := &portainer.EdgeGroup{} edgeGroup := &portainer.EdgeGroup{}
@ -84,17 +79,14 @@ func (service *Service) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFunc
// DeleteEdgeGroup deletes an Edge group. // DeleteEdgeGroup deletes an Edge group.
func (service *Service) DeleteEdgeGroup(ID portainer.EdgeGroupID) error { func (service *Service) DeleteEdgeGroup(ID portainer.EdgeGroupID) error {
identifier := service.connection.ConvertToKey(int(ID)) return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.connection.DeleteObject(BucketName, identifier) return service.Tx(tx).DeleteEdgeGroup(ID)
})
} }
// CreateEdgeGroup assign an ID to a new Edge group and saves it. // CreateEdgeGroup assign an ID to a new Edge group and saves it.
func (service *Service) Create(group *portainer.EdgeGroup) error { func (service *Service) Create(group *portainer.EdgeGroup) error {
return service.connection.CreateObject( return service.connection.UpdateTx(func(tx portainer.Transaction) error {
BucketName, return service.Tx(tx).Create(group)
func(id uint64) (int, interface{}) { })
group.ID = portainer.EdgeGroupID(id)
return int(group.ID), group
},
)
} }

View File

@ -0,0 +1,80 @@
package edgegroup
import (
"errors"
"fmt"
portainer "github.com/portainer/portainer/api"
"github.com/rs/zerolog/log"
)
type ServiceTx struct {
service *Service
tx portainer.Transaction
}
func (service ServiceTx) BucketName() string {
return BucketName
}
// EdgeGroups return a slice containing all the Edge groups.
func (service ServiceTx) EdgeGroups() ([]portainer.EdgeGroup, error) {
var groups = make([]portainer.EdgeGroup, 0)
err := service.tx.GetAllWithJsoniter(
BucketName,
&portainer.EdgeGroup{},
func(obj interface{}) (interface{}, error) {
group, ok := obj.(*portainer.EdgeGroup)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to EdgeGroup object")
return nil, fmt.Errorf("Failed to convert to EdgeGroup object: %s", obj)
}
groups = append(groups, *group)
return &portainer.EdgeGroup{}, nil
})
return groups, err
}
// EdgeGroup returns an Edge group by ID.
func (service ServiceTx) EdgeGroup(ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) {
var group portainer.EdgeGroup
identifier := service.service.connection.ConvertToKey(int(ID))
err := service.tx.GetObject(BucketName, identifier, &group)
if err != nil {
return nil, err
}
return &group, nil
}
// UpdateEdgeGroup updates an edge group.
func (service ServiceTx) UpdateEdgeGroup(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error {
identifier := service.service.connection.ConvertToKey(int(ID))
return service.tx.UpdateObject(BucketName, identifier, group)
}
// UpdateEdgeGroupFunc is a no-op inside a transaction.
func (service ServiceTx) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFunc func(edgeGroup *portainer.EdgeGroup)) error {
return errors.New("cannot be called inside a transaction")
}
// DeleteEdgeGroup deletes an Edge group.
func (service ServiceTx) DeleteEdgeGroup(ID portainer.EdgeGroupID) error {
identifier := service.service.connection.ConvertToKey(int(ID))
return service.tx.DeleteObject(BucketName, identifier)
}
func (service ServiceTx) Create(group *portainer.EdgeGroup) error {
return service.tx.CreateObject(
BucketName,
func(id uint64) (int, interface{}) {
group.ID = portainer.EdgeGroupID(id)
return int(group.ID), group
},
)
}

View File

@ -1,20 +1,14 @@
package endpoint package endpoint
import ( import (
"fmt"
"sync" "sync"
"time" "time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/rs/zerolog/log"
) )
const ( // BucketName represents the name of the bucket where this service stores data.
// BucketName represents the name of the bucket where this service stores data. const BucketName = "endpoints"
BucketName = "endpoints"
)
// Service represents a service for managing environment(endpoint) data. // Service represents a service for managing environment(endpoint) data.
type Service struct { type Service struct {
@ -56,84 +50,54 @@ func NewService(connection portainer.Connection) (*Service, error) {
return s, nil return s, nil
} }
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
service: service,
tx: tx,
}
}
// Endpoint returns an environment(endpoint) by ID. // Endpoint returns an environment(endpoint) by ID.
func (service *Service) Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error) { func (service *Service) Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error) {
var endpoint portainer.Endpoint var endpoint *portainer.Endpoint
identifier := service.connection.ConvertToKey(int(ID)) var err error
err := service.connection.GetObject(BucketName, identifier, &endpoint) err = service.connection.ViewTx(func(tx portainer.Transaction) error {
endpoint, err = service.Tx(tx).Endpoint(ID)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint.LastCheckInDate, _ = service.Heartbeat(ID) endpoint.LastCheckInDate, _ = service.Heartbeat(ID)
return &endpoint, nil return endpoint, nil
} }
// UpdateEndpoint updates an environment(endpoint). // UpdateEndpoint updates an environment(endpoint).
func (service *Service) UpdateEndpoint(ID portainer.EndpointID, endpoint *portainer.Endpoint) error { func (service *Service) UpdateEndpoint(ID portainer.EndpointID, endpoint *portainer.Endpoint) error {
identifier := service.connection.ConvertToKey(int(ID)) return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).UpdateEndpoint(ID, endpoint)
err := service.connection.UpdateObject(BucketName, identifier, endpoint) })
if err != nil {
return err
}
service.mu.Lock()
if len(endpoint.EdgeID) > 0 {
service.idxEdgeID[endpoint.EdgeID] = ID
}
service.heartbeats.Store(ID, endpoint.LastCheckInDate)
service.mu.Unlock()
cache.Del(endpoint.ID)
return nil
} }
// DeleteEndpoint deletes an environment(endpoint). // DeleteEndpoint deletes an environment(endpoint).
func (service *Service) DeleteEndpoint(ID portainer.EndpointID) error { func (service *Service) DeleteEndpoint(ID portainer.EndpointID) error {
identifier := service.connection.ConvertToKey(int(ID)) return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).DeleteEndpoint(ID)
err := service.connection.DeleteObject(BucketName, identifier) })
if err != nil {
return err
}
service.mu.Lock()
for edgeID, endpointID := range service.idxEdgeID {
if endpointID == ID {
delete(service.idxEdgeID, edgeID)
break
}
}
service.heartbeats.Delete(ID)
service.mu.Unlock()
cache.Del(ID)
return nil
} }
// Endpoints return an array containing all the environments(endpoints). // Endpoints return an array containing all the environments(endpoints).
func (service *Service) Endpoints() ([]portainer.Endpoint, error) { func (service *Service) Endpoints() ([]portainer.Endpoint, error) {
var endpoints = make([]portainer.Endpoint, 0) var endpoints []portainer.Endpoint
var err error
err := service.connection.GetAllWithJsoniter( err = service.connection.ViewTx(func(tx portainer.Transaction) error {
BucketName, endpoints, err = service.Tx(tx).Endpoints()
&portainer.Endpoint{}, return err
func(obj interface{}) (interface{}, error) { })
endpoint, ok := obj.(*portainer.Endpoint)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to Endpoint object")
return nil, fmt.Errorf("failed to convert to Endpoint object: %s", obj)
}
endpoints = append(endpoints, *endpoint)
return &portainer.Endpoint{}, nil
})
if err != nil { if err != nil {
return endpoints, err return endpoints, err
@ -170,22 +134,20 @@ func (service *Service) UpdateHeartbeat(endpointID portainer.EndpointID) {
// CreateEndpoint assign an ID to a new environment(endpoint) and saves it. // CreateEndpoint assign an ID to a new environment(endpoint) and saves it.
func (service *Service) Create(endpoint *portainer.Endpoint) error { func (service *Service) Create(endpoint *portainer.Endpoint) error {
err := service.connection.CreateObjectWithId(BucketName, int(endpoint.ID), endpoint) return service.connection.UpdateTx(func(tx portainer.Transaction) error {
if err != nil { return service.Tx(tx).Create(endpoint)
return err })
}
service.mu.Lock()
if len(endpoint.EdgeID) > 0 {
service.idxEdgeID[endpoint.EdgeID] = endpoint.ID
}
service.heartbeats.Store(endpoint.ID, endpoint.LastCheckInDate)
service.mu.Unlock()
return nil
} }
// GetNextIdentifier returns the next identifier for an environment(endpoint). // GetNextIdentifier returns the next identifier for an environment(endpoint).
func (service *Service) GetNextIdentifier() int { func (service *Service) GetNextIdentifier() int {
return service.connection.GetNextIdentifier(BucketName) var identifier int
service.connection.UpdateTx(func(tx portainer.Transaction) error {
identifier = service.Tx(tx).GetNextIdentifier()
return nil
})
return identifier
} }

View File

@ -0,0 +1,137 @@
package endpoint
import (
"fmt"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/rs/zerolog/log"
)
type ServiceTx struct {
service *Service
tx portainer.Transaction
}
func (service ServiceTx) BucketName() string {
return BucketName
}
// Endpoint returns an environment(endpoint) by ID.
func (service ServiceTx) Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error) {
var endpoint portainer.Endpoint
identifier := service.service.connection.ConvertToKey(int(ID))
err := service.tx.GetObject(BucketName, identifier, &endpoint)
if err != nil {
return nil, err
}
return &endpoint, nil
}
// UpdateEndpoint updates an environment(endpoint).
func (service ServiceTx) UpdateEndpoint(ID portainer.EndpointID, endpoint *portainer.Endpoint) error {
identifier := service.service.connection.ConvertToKey(int(ID))
err := service.tx.UpdateObject(BucketName, identifier, endpoint)
if err != nil {
return err
}
service.service.mu.Lock()
if len(endpoint.EdgeID) > 0 {
service.service.idxEdgeID[endpoint.EdgeID] = ID
}
service.service.heartbeats.Store(ID, endpoint.LastCheckInDate)
service.service.mu.Unlock()
cache.Del(endpoint.ID)
return nil
}
// DeleteEndpoint deletes an environment(endpoint).
func (service ServiceTx) DeleteEndpoint(ID portainer.EndpointID) error {
identifier := service.service.connection.ConvertToKey(int(ID))
err := service.tx.DeleteObject(BucketName, identifier)
if err != nil {
return err
}
service.service.mu.Lock()
for edgeID, endpointID := range service.service.idxEdgeID {
if endpointID == ID {
delete(service.service.idxEdgeID, edgeID)
break
}
}
service.service.heartbeats.Delete(ID)
service.service.mu.Unlock()
cache.Del(ID)
return nil
}
// Endpoints return an array containing all the environments(endpoints).
func (service ServiceTx) Endpoints() ([]portainer.Endpoint, error) {
var endpoints = make([]portainer.Endpoint, 0)
err := service.tx.GetAllWithJsoniter(
BucketName,
&portainer.Endpoint{},
func(obj interface{}) (interface{}, error) {
endpoint, ok := obj.(*portainer.Endpoint)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to Endpoint object")
return nil, fmt.Errorf("failed to convert to Endpoint object: %s", obj)
}
endpoints = append(endpoints, *endpoint)
return &portainer.Endpoint{}, nil
})
return endpoints, err
}
func (service ServiceTx) EndpointIDByEdgeID(edgeID string) (portainer.EndpointID, bool) {
log.Error().Str("func", "EndpointIDByEdgeID").Msg("cannot be called inside a transaction")
return 0, false
}
func (service ServiceTx) Heartbeat(endpointID portainer.EndpointID) (int64, bool) {
log.Error().Str("func", "Heartbeat").Msg("cannot be called inside a transaction")
return 0, false
}
func (service ServiceTx) UpdateHeartbeat(endpointID portainer.EndpointID) {
log.Error().Str("func", "UpdateHeartbeat").Msg("cannot be called inside a transaction")
}
// CreateEndpoint assign an ID to a new environment(endpoint) and saves it.
func (service ServiceTx) Create(endpoint *portainer.Endpoint) error {
err := service.tx.CreateObjectWithId(BucketName, int(endpoint.ID), endpoint)
if err != nil {
return err
}
service.service.mu.Lock()
if len(endpoint.EdgeID) > 0 {
service.service.idxEdgeID[endpoint.EdgeID] = endpoint.ID
}
service.service.heartbeats.Store(endpoint.ID, endpoint.LastCheckInDate)
service.service.mu.Unlock()
return nil
}
// GetNextIdentifier returns the next identifier for an environment(endpoint).
func (service ServiceTx) GetNextIdentifier() int {
return service.tx.GetNextIdentifier(BucketName)
}

View File

@ -13,16 +13,7 @@ import (
) )
type ( type (
// DataStore defines the interface to manage the data DataStoreTx interface {
DataStore interface {
Open() (newStore bool, err error)
Init() error
Close() error
MigrateData() error
Rollback(force bool) error
CheckCurrentEdition() error
BackupTo(w io.Writer) error
Export(filename string) (err error)
IsErrObjectNotFound(err error) bool IsErrObjectNotFound(err error) bool
CustomTemplate() CustomTemplateService CustomTemplate() CustomTemplateService
EdgeGroup() EdgeGroupService EdgeGroup() EdgeGroupService
@ -50,6 +41,22 @@ type (
Webhook() WebhookService Webhook() WebhookService
} }
// DataStore defines the interface to manage the data
DataStore interface {
Open() (newStore bool, err error)
Init() error
Close() error
UpdateTx(func(DataStoreTx) error) error
ViewTx(func(DataStoreTx) error) error
MigrateData() error
Rollback(force bool) error
CheckCurrentEdition() error
BackupTo(w io.Writer) error
Export(filename string) (err error)
DataStoreTx
}
// CustomTemplateService represents a service to manage custom templates // CustomTemplateService represents a service to manage custom templates
CustomTemplateService interface { CustomTemplateService interface {
GetNextIdentifier() int GetNextIdentifier() int

View File

@ -8,6 +8,7 @@ import (
"time" "time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
portainerErrors "github.com/portainer/portainer/api/dataservices/errors" portainerErrors "github.com/portainer/portainer/api/dataservices/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -61,6 +62,24 @@ func (store *Store) Close() error {
return store.connection.Close() return store.connection.Close()
} }
func (store *Store) UpdateTx(fn func(dataservices.DataStoreTx) error) error {
return store.connection.UpdateTx(func(tx portainer.Transaction) error {
return fn(&StoreTx{
store: store,
tx: tx,
})
})
}
func (store *Store) ViewTx(fn func(dataservices.DataStoreTx) error) error {
return store.connection.ViewTx(func(tx portainer.Transaction) error {
return fn(&StoreTx{
store: store,
tx: tx,
})
})
}
// BackupTo backs up db to a provided writer. // BackupTo backs up db to a provided writer.
// It does hot backup and doesn't block other database reads and writes // It does hot backup and doesn't block other database reads and writes
func (store *Store) BackupTo(w io.Writer) error { func (store *Store) BackupTo(w io.Writer) error {

View File

@ -0,0 +1,48 @@
package datastore
import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
)
type StoreTx struct {
store *Store
tx portainer.Transaction
}
func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
return tx.store.IsErrObjectNotFound(err)
}
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil }
func (tx *StoreTx) EdgeGroup() dataservices.EdgeGroupService {
return tx.store.EdgeGroupService.Tx(tx.tx)
}
func (tx *StoreTx) EdgeJob() dataservices.EdgeJobService { return nil }
func (tx *StoreTx) EdgeStack() dataservices.EdgeStackService { return nil }
func (tx *StoreTx) Endpoint() dataservices.EndpointService {
return tx.store.EndpointService.Tx(tx.tx)
}
func (tx *StoreTx) EndpointGroup() dataservices.EndpointGroupService { return nil }
func (tx *StoreTx) EndpointRelation() dataservices.EndpointRelationService { return nil }
func (tx *StoreTx) FDOProfile() dataservices.FDOProfileService { return nil }
func (tx *StoreTx) HelmUserRepository() dataservices.HelmUserRepositoryService { return nil }
func (tx *StoreTx) Registry() dataservices.RegistryService { return nil }
func (tx *StoreTx) ResourceControl() dataservices.ResourceControlService { return nil }
func (tx *StoreTx) Role() dataservices.RoleService { return nil }
func (tx *StoreTx) APIKeyRepository() dataservices.APIKeyRepository { return nil }
func (tx *StoreTx) Settings() dataservices.SettingsService { return nil }
func (tx *StoreTx) Snapshot() dataservices.SnapshotService { return nil }
func (tx *StoreTx) SSLSettings() dataservices.SSLSettingsService { return nil }
func (tx *StoreTx) Stack() dataservices.StackService { return nil }
func (tx *StoreTx) Tag() dataservices.TagService { return nil }
func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService { return nil }
func (tx *StoreTx) Team() dataservices.TeamService { return nil }
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
func (tx *StoreTx) User() dataservices.UserService { return nil }
func (tx *StoreTx) Version() dataservices.VersionService { return nil }
func (tx *StoreTx) Webhook() dataservices.WebhookService { return nil }

View File

@ -38,7 +38,7 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/portainer/docker-compose-wrapper v0.0.0-20221215210951-2c30d1b17a27 github.com/portainer/docker-compose-wrapper v0.0.0-20221215210951-2c30d1b17a27
github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a
github.com/portainer/libhttp v0.0.0-20221121135534-76f46e09c9a9 github.com/portainer/libhttp v0.0.0-20230206214615-dabd58de9f44
github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4 github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4
github.com/portainer/portainer/pkg/libhelm v0.0.0-20221201012749-4fee35924724 github.com/portainer/portainer/pkg/libhelm v0.0.0-20221201012749-4fee35924724
github.com/portainer/portainer/third_party/digest v0.0.0-20221201002639-8fd0efa34f73 github.com/portainer/portainer/third_party/digest v0.0.0-20221201002639-8fd0efa34f73

View File

@ -348,8 +348,8 @@ github.com/portainer/docker-compose-wrapper v0.0.0-20221215210951-2c30d1b17a27 h
github.com/portainer/docker-compose-wrapper v0.0.0-20221215210951-2c30d1b17a27/go.mod h1:03UmPLyjiPUexGJuW20mQXvmsoSpeErvMlItJGtq/Ww= github.com/portainer/docker-compose-wrapper v0.0.0-20221215210951-2c30d1b17a27/go.mod h1:03UmPLyjiPUexGJuW20mQXvmsoSpeErvMlItJGtq/Ww=
github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a h1:B0z3skIMT+OwVNJPQhKp52X+9OWW6A9n5UWig3lHBJk= github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a h1:B0z3skIMT+OwVNJPQhKp52X+9OWW6A9n5UWig3lHBJk=
github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a/go.mod h1:n54EEIq+MM0NNtqLeCby8ljL+l275VpolXO0ibHegLE= github.com/portainer/libcrypto v0.0.0-20220506221303-1f4fb3b30f9a/go.mod h1:n54EEIq+MM0NNtqLeCby8ljL+l275VpolXO0ibHegLE=
github.com/portainer/libhttp v0.0.0-20221121135534-76f46e09c9a9 h1:L7o0L+1qq+LzKjzgRB6bDIh5ZrZ5A1oSS+WgWzDgJIo= github.com/portainer/libhttp v0.0.0-20230206214615-dabd58de9f44 h1:4LYprPd3TsYjHk7CaTmCov1ceG6VKJsL40fJIWiRxpw=
github.com/portainer/libhttp v0.0.0-20221121135534-76f46e09c9a9/go.mod h1:H49JLiywwLt2rrJVroafEWy8fIs0i7mThAThK40sbb8= github.com/portainer/libhttp v0.0.0-20230206214615-dabd58de9f44/go.mod h1:H49JLiywwLt2rrJVroafEWy8fIs0i7mThAThK40sbb8=
github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4 h1:gnXwaF0GnFUIlynRq994WFOtqOULTKZks4aSWuonlhA= github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4 h1:gnXwaF0GnFUIlynRq994WFOtqOULTKZks4aSWuonlhA=
github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4/go.mod h1:T37rFZMg+PhRhT9n/z9cLSj9khJSdwHj3/Ac5PZQgKI= github.com/portainer/portainer/pkg/featureflags v0.0.0-20230209201943-d73622ed9cd4/go.mod h1:T37rFZMg+PhRhT9n/z9cLSj9khJSdwHj3/Ac5PZQgKI=
github.com/portainer/portainer/pkg/libhelm v0.0.0-20221201012749-4fee35924724 h1:FZrRVMpxXdUV+p5VSCAy9Uz7RzAeEJr2ytlctvMrsHY= github.com/portainer/portainer/pkg/libhelm v0.0.0-20221201012749-4fee35924724 h1:FZrRVMpxXdUV+p5VSCAy9Uz7RzAeEJr2ytlctvMrsHY=

View File

@ -8,6 +8,7 @@ import (
"github.com/portainer/libhttp/request" "github.com/portainer/libhttp/request"
"github.com/portainer/libhttp/response" "github.com/portainer/libhttp/response"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
"github.com/asaskevich/govalidator" "github.com/asaskevich/govalidator"
@ -54,45 +55,58 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
edgeGroups, err := handler.DataStore.EdgeGroup().EdgeGroups() var edgeGroup *portainer.EdgeGroup
if err != nil {
return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err)
}
for _, edgeGroup := range edgeGroups { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
if edgeGroup.Name == payload.Name { edgeGroups, err := tx.EdgeGroup().EdgeGroups()
return httperror.BadRequest("Edge group name must be unique", errors.New("edge group name must be unique")) if err != nil {
return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err)
} }
}
edgeGroup := &portainer.EdgeGroup{ for _, edgeGroup := range edgeGroups {
Name: payload.Name, if edgeGroup.Name == payload.Name {
Dynamic: payload.Dynamic, return httperror.BadRequest("Edge group name must be unique", errors.New("edge group name must be unique"))
TagIDs: []portainer.TagID{},
Endpoints: []portainer.EndpointID{},
PartialMatch: payload.PartialMatch,
}
if edgeGroup.Dynamic {
edgeGroup.TagIDs = payload.TagIDs
} else {
endpointIDs := []portainer.EndpointID{}
for _, endpointID := range payload.Endpoints {
endpoint, err := handler.DataStore.Endpoint().Endpoint(endpointID)
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
}
if endpointutils.IsEdgeEndpoint(endpoint) {
endpointIDs = append(endpointIDs, endpoint.ID)
} }
} }
edgeGroup.Endpoints = endpointIDs
}
err = handler.DataStore.EdgeGroup().Create(edgeGroup) edgeGroup = &portainer.EdgeGroup{
Name: payload.Name,
Dynamic: payload.Dynamic,
TagIDs: []portainer.TagID{},
Endpoints: []portainer.EndpointID{},
PartialMatch: payload.PartialMatch,
}
if edgeGroup.Dynamic {
edgeGroup.TagIDs = payload.TagIDs
} else {
endpointIDs := []portainer.EndpointID{}
for _, endpointID := range payload.Endpoints {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
}
if endpointutils.IsEdgeEndpoint(endpoint) {
endpointIDs = append(endpointIDs, endpoint.ID)
}
}
edgeGroup.Endpoints = endpointIDs
}
err = tx.EdgeGroup().Create(edgeGroup)
if err != nil {
return httperror.InternalServerError("Unable to persist the Edge group inside the database", err)
}
return nil
})
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to persist the Edge group inside the database", err) if httpErr, ok := err.(*httperror.HandlerError); ok {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err)
} }
return response.JSON(w, edgeGroup) return response.JSON(w, edgeGroup)

View File

@ -36,10 +36,13 @@ type testDatastore struct {
webhook dataservices.WebhookService webhook dataservices.WebhookService
} }
func (d *testDatastore) BackupTo(io.Writer) error { return nil } func (d *testDatastore) BackupTo(io.Writer) error { return nil }
func (d *testDatastore) Open() (bool, error) { return false, nil } func (d *testDatastore) Open() (bool, error) { return false, nil }
func (d *testDatastore) Init() error { return nil } func (d *testDatastore) Init() error { return nil }
func (d *testDatastore) Close() error { return nil } func (d *testDatastore) Close() error { return nil }
func (d *testDatastore) UpdateTx(func(dataservices.DataStoreTx) error) error { return nil }
func (d *testDatastore) ViewTx(func(dataservices.DataStoreTx) error) error { return nil }
func (d *testDatastore) CheckCurrentEdition() error { return nil } func (d *testDatastore) CheckCurrentEdition() error { return nil }
func (d *testDatastore) MigrateData() error { return nil } func (d *testDatastore) MigrateData() error { return nil }
func (d *testDatastore) Rollback(force bool) error { return nil } func (d *testDatastore) Rollback(force bool) error { return nil }