portainer/api/database/boltdb/db.go

384 lines
12 KiB
Go

package boltdb
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"os"
"path"
"time"
portainer "github.com/portainer/portainer/api"
dserrors "github.com/portainer/portainer/api/dataservices/errors"
"github.com/rs/zerolog/log"
bolt "go.etcd.io/bbolt"
)
const (
// DatabaseFileName is the file name used for the un-encrypted database.
DatabaseFileName = "portainer.db"
// EncryptedDatabaseFileName is the file name used for the encrypted database.
EncryptedDatabaseFileName = "portainer.edb"
)
var (
// ErrHaveEncryptedAndUnencrypted is returned by NeedsEncryptionMigration when
// both the un-encrypted and the encrypted database file exist in the store path.
ErrHaveEncryptedAndUnencrypted = errors.New("Portainer has detected both an encrypted and un-encrypted database and cannot start. Only one database should exist")
// ErrHaveEncryptedWithNoKey is returned by NeedsEncryptionMigration when the
// encrypted database file exists but no encryption key has been loaded.
ErrHaveEncryptedWithNoKey = errors.New("The portainer database is encrypted, but no secret was loaded")
)
// DbConnection bundles an embedded BoltDB handle with the settings used to
// locate, open and optionally encrypt the Portainer database.
type DbConnection struct {
// Path is the directory that contains the database file.
Path string
// MaxBatchSize is copied onto the bolt handle in Open; UpdateTx only uses
// batching when it is greater than 1 (and MaxBatchDelay is positive).
MaxBatchSize int
// MaxBatchDelay is copied onto the bolt handle in Open; UpdateTx only uses
// batching when it is positive (and MaxBatchSize is greater than 1).
MaxBatchDelay time.Duration
// InitialMmapSize is passed to bolt.Options when the database is opened.
InitialMmapSize int
// EncryptionKey is the loaded encryption key; nil means no key is available.
EncryptionKey []byte
// isEncrypted is toggled through SetEncrypted; while it is false,
// getEncryptionKey reports no key even if EncryptionKey is set.
isEncrypted bool
// DB is the embedded bolt handle; it is nil until Open succeeds.
*bolt.DB
}
// GetDatabaseFileName returns the bare file name of the database:
// EncryptedDatabaseFileName when the store is encrypted, DatabaseFileName otherwise.
func (connection *DbConnection) GetDatabaseFileName() string {
	name := DatabaseFileName
	if connection.IsEncryptedStore() {
		name = EncryptedDatabaseFileName
	}

	return name
}
// GetDatabaseFilePath returns the store path joined with the database file
// name. The encrypted file name is used when the store is encrypted, the
// plain one otherwise.
func (connection *DbConnection) GetDatabaseFilePath() string {
	fileName := DatabaseFileName
	if connection.IsEncryptedStore() {
		fileName = EncryptedDatabaseFileName
	}

	return path.Join(connection.Path, fileName)
}
// GetStorePath returns the configured store path, i.e. the directory the
// database file name is joined onto (it does not include the file name).
func (connection *DbConnection) GetStorePath() string {
	storePath := connection.Path

	return storePath
}
// SetEncrypted records whether the store should be treated as encrypted.
// The flag gates getEncryptionKey: while it is false no key is reported,
// even if EncryptionKey is set.
func (connection *DbConnection) SetEncrypted(flag bool) {
	connection.isEncrypted = flag
}
// IsEncryptedStore reports whether the store is encrypted, which requires
// both the encrypted flag to be set and an encryption key to be loaded.
func (connection *DbConnection) IsEncryptedStore() bool {
	return connection.isEncrypted && connection.EncryptionKey != nil
}
// NeedsEncryptionMigration reports whether an un-encrypted database has to be
// migrated to an encrypted one, based on which database files exist in the
// store path and whether an encryption key is loaded.
//
// As a side effect the connection is flagged as encrypted whenever a key is
// loaded. A file counts as present only when os.Stat succeeds; any stat error
// is treated as "file absent".
//
//	portainer.edb + key            => false
//	portainer.edb + no key         => ErrHaveEncryptedWithNoKey
//	portainer.db  + key            => true (needs migration)
//	portainer.db  + no key         => false
//	no database   + key            => false
//	no database   + no key         => false
//	portainer.db  + portainer.edb  => ErrHaveEncryptedAndUnencrypted
func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
	haveKey := connection.EncryptionKey != nil
	if haveKey {
		// A loaded key always marks the store as encrypted.
		connection.SetEncrypted(true)
	}

	_, plainErr := os.Stat(path.Join(connection.Path, DatabaseFileName))
	_, encryptedErr := os.Stat(path.Join(connection.Path, EncryptedDatabaseFileName))
	havePlain, haveEncrypted := plainErr == nil, encryptedErr == nil

	switch {
	case havePlain && haveEncrypted:
		// Both files exist: ambiguous, refuse to continue.
		return false, ErrHaveEncryptedAndUnencrypted
	case havePlain && haveKey:
		// Plain database plus a key: migration required.
		return true, nil
	case haveEncrypted && !haveKey:
		// Encrypted database without a key cannot be opened.
		return false, ErrHaveEncryptedWithNoKey
	default:
		return false, nil
	}
}
// Open opens the BoltDB file at the database file path (mode 0600, one
// second lock timeout) and stores the resulting handle on the connection.
// The connection's batch settings are copied onto the opened handle.
func (connection *DbConnection) Open() error {
	log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")

	options := bolt.Options{
		Timeout:         1 * time.Second,
		InitialMmapSize: connection.InitialMmapSize,
	}

	handle, err := bolt.Open(connection.GetDatabaseFilePath(), 0600, &options)
	if err != nil {
		return err
	}

	handle.MaxBatchSize = connection.MaxBatchSize
	handle.MaxBatchDelay = connection.MaxBatchDelay
	connection.DB = handle

	return nil
}
// Close closes the BoltDB database.
// When no handle has been set (DB is nil) it returns nil without doing
// anything else. The handle is not reset afterwards, so repeated calls are
// forwarded to the bolt handle's Close again.
func (connection *DbConnection) Close() error {
	log.Info().Msg("closing PortainerDB")

	if connection.DB == nil {
		return nil
	}

	return connection.DB.Close()
}
// txFn adapts a function that works on a portainer.Transaction into one that
// bolt can run: the bolt transaction is wrapped in a DbTransaction bound to
// this connection before fn is invoked.
func (connection *DbConnection) txFn(fn func(portainer.Transaction) error) func(*bolt.Tx) error {
	return func(boltTx *bolt.Tx) error {
		wrapped := &DbTransaction{conn: connection, tx: boltTx}

		return fn(wrapped)
	}
}
// UpdateTx executes the given function inside a read-write transaction.
// The call goes through bolt's Batch when batching is configured
// (MaxBatchDelay > 0 and MaxBatchSize > 1) and through Update otherwise.
func (connection *DbConnection) UpdateTx(fn func(portainer.Transaction) error) error {
	boltFn := connection.txFn(fn)

	batchingEnabled := connection.MaxBatchDelay > 0 && connection.MaxBatchSize > 1
	if !batchingEnabled {
		return connection.Update(boltFn)
	}

	return connection.Batch(boltFn)
}
// ViewTx executes the given function inside a read-only transaction.
func (connection *DbConnection) ViewTx(fn func(portainer.Transaction) error) error {
	readFn := connection.txFn(fn)

	return connection.View(readFn)
}
// BackupTo writes a copy of the database to w.
// The copy is taken inside a read-only (View) transaction; per the original
// note this is a hot backup that does not block other reads and writes.
func (connection *DbConnection) BackupTo(w io.Writer) error {
	dump := func(tx *bolt.Tx) error {
		if _, err := tx.WriteTo(w); err != nil {
			return err
		}

		return nil
	}

	return connection.View(dump)
}
// ExportRaw exports the database file through ExportJSON and writes the
// resulting bytes to filename with mode 0600.
// It fails early when the database file cannot be stat'ed.
// NOTE(review): the meaning of the `true` flag passed to ExportJSON is not
// visible here — confirm against ExportJSON.
func (connection *DbConnection) ExportRaw(filename string) error {
	source := connection.GetDatabaseFilePath()

	_, statErr := os.Stat(source)
	if statErr != nil {
		return fmt.Errorf("stat on %s failed, error: %w", source, statErr)
	}

	exported, err := connection.ExportJSON(source, true)
	if err != nil {
		return err
	}

	return os.WriteFile(filename, exported, 0600)
}
// ConvertToKey returns an 8-byte big endian representation of v.
// It is typically used to turn integer IDs into byte slices that can serve
// as BoltDB keys. v is converted to uint64 first, so negative values wrap.
func (connection *DbConnection) ConvertToKey(v int) []byte {
	var key [8]byte
	binary.BigEndian.PutUint64(key[:], uint64(v))

	return key[:]
}
// keyToString converts a key to a string value suitable for logging.
// An 8-byte key whose big endian value fits in an int32 is rendered as a
// decimal number; every other key is returned as its raw bytes.
func keyToString(b []byte) string {
	if len(b) == 8 {
		if id := binary.BigEndian.Uint64(b); id <= math.MaxInt32 {
			return fmt.Sprintf("%d", id)
		}
	}

	return string(b)
}
// SetServiceName calls the transaction's SetServiceName for bucketName inside
// a read-write transaction.
// NOTE(review): the original comment described this as creating a bucket —
// confirm against DbTransaction.SetServiceName.
func (connection *DbConnection) SetServiceName(bucketName string) error {
	register := func(txn portainer.Transaction) error {
		return txn.SetServiceName(bucketName)
	}

	return connection.UpdateTx(register)
}
// GetObject is a generic function used to retrieve an unmarshalled object
// from a database. It delegates to the transaction's GetObject inside a
// read-only transaction.
func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error {
	fetch := func(txn portainer.Transaction) error {
		return txn.GetObject(bucketName, key, object)
	}

	return connection.ViewTx(fetch)
}
// getEncryptionKey returns the loaded encryption key, or nil while the
// connection is not flagged as encrypted.
func (connection *DbConnection) getEncryptionKey() []byte {
	if connection.isEncrypted {
		return connection.EncryptionKey
	}

	return nil
}
// UpdateObject is a generic function used to update an object inside a
// database. It delegates to the transaction's UpdateObject inside a
// read-write transaction.
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error {
	store := func(txn portainer.Transaction) error {
		return txn.UpdateObject(bucketName, key, object)
	}

	return connection.UpdateTx(store)
}
// UpdateObjectFunc is a generic function used to update an object safely
// without race conditions: the stored value is read, unmarshalled into object,
// mutated by updateFn and written back, all inside one bolt Batch transaction.
//
// It returns an error wrapping dserrors.ErrObjectNotFound when the bucket or
// the key does not exist, and otherwise any unmarshal, marshal or write error.
//
// bolt's Batch may invoke the transaction function more than once, so
// updateFn can run several times and should be idempotent.
func (connection *DbConnection) UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error {
	return connection.Batch(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(bucketName))
		if bucket == nil {
			// tx.Bucket returns nil for a bucket that does not exist; calling
			// Get on it would dereference a nil pointer.
			return fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
		}

		data := bucket.Get(key)
		if data == nil {
			return fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
		}

		if err := connection.UnmarshalObject(data, object); err != nil {
			return err
		}

		// object now holds the stored state; let the caller mutate it.
		updateFn()

		data, err := connection.MarshalObject(object)
		if err != nil {
			return err
		}

		return bucket.Put(key, data)
	})
}
// DeleteObject is a generic function used to delete an object inside a
// database. It delegates to the transaction's DeleteObject inside a
// read-write transaction.
func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error {
	remove := func(txn portainer.Transaction) error {
		return txn.DeleteObject(bucketName, key)
	}

	return connection.UpdateTx(remove)
}
// DeleteAllObjects deletes all objects for which matching() returns (id, ok).
// It delegates to the transaction's DeleteAllObjects inside a read-write
// transaction.
// TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"?
func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error {
	purge := func(txn portainer.Transaction) error {
		return txn.DeleteAllObjects(bucketName, obj, matching)
	}

	return connection.UpdateTx(purge)
}
// GetNextIdentifier is a generic function that returns the specified bucket
// identifier incremented by 1, as reported by the transaction's
// GetNextIdentifier inside a read-write transaction.
//
// The signature cannot carry an error, so a failing transaction is logged
// rather than silently dropped. The identifier captured so far is still
// returned; it stays 0 when the transaction function never ran.
func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
	var identifier int

	err := connection.UpdateTx(func(tx portainer.Transaction) error {
		identifier = tx.GetNextIdentifier(bucketName)

		return nil
	})
	if err != nil {
		log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
	}

	return identifier
}
// CreateObject creates a new object in the bucket, using the next bucket
// sequence id. It delegates to the transaction's CreateObject inside a
// read-write transaction.
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
	create := func(txn portainer.Transaction) error {
		return txn.CreateObject(bucketName, fn)
	}

	return connection.UpdateTx(create)
}
// CreateObjectWithId creates a new object in the bucket, using the specified
// id. It delegates to the transaction's CreateObjectWithId inside a
// read-write transaction.
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
	create := func(txn portainer.Transaction) error {
		return txn.CreateObjectWithId(bucketName, id, obj)
	}

	return connection.UpdateTx(create)
}
// CreateObjectWithStringId creates a new object in the bucket, using the
// specified byte-slice id. It delegates to the transaction's
// CreateObjectWithStringId inside a read-write transaction.
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
	create := func(txn portainer.Transaction) error {
		return txn.CreateObjectWithStringId(bucketName, id, obj)
	}

	return connection.UpdateTx(create)
}
// GetAll delegates to the transaction's GetAll for bucketName inside a
// read-only transaction, forwarding obj and the append callback unchanged.
func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
	collect := func(txn portainer.Transaction) error {
		return txn.GetAll(bucketName, obj, append)
	}

	return connection.ViewTx(collect)
}
// GetAllWithJsoniter delegates to the transaction's GetAllWithJsoniter for
// bucketName inside a read-only transaction, forwarding obj and the append
// callback unchanged.
// TODO: decide which Unmarshal to use, and use one...
func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
	collect := func(txn portainer.Transaction) error {
		return txn.GetAllWithJsoniter(bucketName, obj, append)
	}

	return connection.ViewTx(collect)
}
// GetAllWithKeyPrefix delegates to the transaction's GetAllWithKeyPrefix for
// bucketName inside a read-only transaction, forwarding keyPrefix, obj and
// the append callback unchanged.
func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error {
	collect := func(txn portainer.Transaction) error {
		return txn.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, append)
	}

	return connection.ViewTx(collect)
}
// BackupMetadata returns a copy of the boltdb sequence numbers for all
// buckets, keyed by bucket name. Each sequence is stored as an int.
// The map is returned together with any error from the read-only transaction.
func (connection *DbConnection) BackupMetadata() (map[string]interface{}, error) {
	sequences := make(map[string]interface{})

	collect := func(name []byte, bucket *bolt.Bucket) error {
		sequences[string(name)] = int(bucket.Sequence())

		return nil
	}

	err := connection.View(func(tx *bolt.Tx) error {
		return tx.ForEach(collect)
	})

	return sequences, err
}
// RestoreMetadata restores the boltdb sequence numbers for all buckets in s,
// creating any bucket that does not exist yet.
//
// Values must be float64 (JSON numbers are unmarshalled into interface{} as
// float64, see https://pkg.go.dev/encoding/json#Decoder.Decode); entries of
// any other type are logged and skipped.
//
// Restoration is best effort: a failing bucket does not stop the remaining
// buckets from being processed. The first error encountered is returned,
// wrapped with the bucket name, and later ones are logged, so a failure is
// no longer hidden by a later bucket succeeding.
func (connection *DbConnection) RestoreMetadata(s map[string]interface{}) error {
	var firstErr error

	for bucketName, v := range s {
		id, ok := v.(float64)
		if !ok {
			log.Error().Str("bucket", bucketName).Msg("failed to restore metadata to bucket, skipped")

			continue
		}

		err := connection.Batch(func(tx *bolt.Tx) error {
			bucket, err := tx.CreateBucketIfNotExists([]byte(bucketName))
			if err != nil {
				return err
			}

			return bucket.SetSequence(uint64(id))
		})
		if err == nil {
			continue
		}

		err = fmt.Errorf("restoring sequence of bucket %q: %w", bucketName, err)
		if firstErr == nil {
			firstErr = err

			continue
		}

		// Only one error can be returned; make sure the others are not lost.
		log.Error().Err(err).Str("bucket", bucketName).Msg("failed to restore metadata to bucket")
	}

	return firstErr
}