package boltdb import ( "encoding/binary" "errors" "fmt" "io" "math" "os" "path" "time" portainer "github.com/portainer/portainer/api" dserrors "github.com/portainer/portainer/api/dataservices/errors" "github.com/rs/zerolog/log" bolt "go.etcd.io/bbolt" ) const ( DatabaseFileName = "portainer.db" EncryptedDatabaseFileName = "portainer.edb" ) var ( ErrHaveEncryptedAndUnencrypted = errors.New("Portainer has detected both an encrypted and un-encrypted database and cannot start. Only one database should exist") ErrHaveEncryptedWithNoKey = errors.New("The portainer database is encrypted, but no secret was loaded") ) type DbConnection struct { Path string MaxBatchSize int MaxBatchDelay time.Duration InitialMmapSize int EncryptionKey []byte isEncrypted bool *bolt.DB } // GetDatabaseFileName get the database filename func (connection *DbConnection) GetDatabaseFileName() string { if connection.IsEncryptedStore() { return EncryptedDatabaseFileName } return DatabaseFileName } // GetDataseFilePath get the path + filename for the database file func (connection *DbConnection) GetDatabaseFilePath() string { if connection.IsEncryptedStore() { return path.Join(connection.Path, EncryptedDatabaseFileName) } return path.Join(connection.Path, DatabaseFileName) } // GetStorePath get the filename and path for the database file func (connection *DbConnection) GetStorePath() string { return connection.Path } func (connection *DbConnection) SetEncrypted(flag bool) { connection.isEncrypted = flag } // Return true if the database is encrypted func (connection *DbConnection) IsEncryptedStore() bool { return connection.getEncryptionKey() != nil } // NeedsEncryptionMigration returns true if database encryption is enabled and // we have an un-encrypted DB that requires migration to an encrypted DB func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) { // Cases: Note, we need to check both portainer.db and portainer.edb // to determine if it's a new store. We only need to differentiate between cases 2,3 and 5 // 1) portainer.edb + key => False // 2) portainer.edb + no key => ERROR Fatal! // 3) portainer.db + key => True (needs migration) // 4) portainer.db + no key => False // 5) NoDB (new) + key => False // 6) NoDB (new) + no key => False // 7) portainer.db & portainer.edb => ERROR Fatal! // If we have a loaded encryption key, always set encrypted if connection.EncryptionKey != nil { connection.SetEncrypted(true) } // Check for portainer.db dbFile := path.Join(connection.Path, DatabaseFileName) _, err := os.Stat(dbFile) haveDbFile := err == nil // Check for portainer.edb edbFile := path.Join(connection.Path, EncryptedDatabaseFileName) _, err = os.Stat(edbFile) haveEdbFile := err == nil if haveDbFile && haveEdbFile { // 7 - encrypted and unencrypted db? return false, ErrHaveEncryptedAndUnencrypted } if haveDbFile && connection.EncryptionKey != nil { // 3 - needs migration return true, nil } if haveEdbFile && connection.EncryptionKey == nil { // 2 - encrypted db, but no key? return false, ErrHaveEncryptedWithNoKey } // 1, 4, 5, 6 return false, nil } // Open opens and initializes the BoltDB database. func (connection *DbConnection) Open() error { log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB") // Now we open the db databasePath := connection.GetDatabaseFilePath() db, err := bolt.Open(databasePath, 0600, &bolt.Options{ Timeout: 1 * time.Second, InitialMmapSize: connection.InitialMmapSize, }) if err != nil { return err } db.MaxBatchSize = connection.MaxBatchSize db.MaxBatchDelay = connection.MaxBatchDelay connection.DB = db return nil } // Close closes the BoltDB database. // Safe to being called multiple times. func (connection *DbConnection) Close() error { if connection.DB != nil { return connection.DB.Close() } return nil } func (connection *DbConnection) txFn(fn func(portainer.Transaction) error) func(*bolt.Tx) error { return func(tx *bolt.Tx) error { return fn(&DbTransaction{conn: connection, tx: tx}) } } // UpdateTx executes the given function inside a read-write transaction func (connection *DbConnection) UpdateTx(fn func(portainer.Transaction) error) error { if connection.MaxBatchDelay > 0 && connection.MaxBatchSize > 1 { return connection.Batch(connection.txFn(fn)) } return connection.Update(connection.txFn(fn)) } // ViewTx executes the given function inside a read-only transaction func (connection *DbConnection) ViewTx(fn func(portainer.Transaction) error) error { return connection.View(connection.txFn(fn)) } // BackupTo backs up db to a provided writer. // It does hot backup and doesn't block other database reads and writes func (connection *DbConnection) BackupTo(w io.Writer) error { return connection.View(func(tx *bolt.Tx) error { _, err := tx.WriteTo(w) return err }) } func (connection *DbConnection) ExportRaw(filename string) error { databasePath := connection.GetDatabaseFilePath() if _, err := os.Stat(databasePath); err != nil { return fmt.Errorf("stat on %s failed, error: %w", databasePath, err) } b, err := connection.ExportJSON(databasePath, true) if err != nil { return err } return os.WriteFile(filename, b, 0600) } // ConvertToKey returns an 8-byte big endian representation of v. // This function is typically used for encoding integer IDs to byte slices // so that they can be used as BoltDB keys. func (connection *DbConnection) ConvertToKey(v int) []byte { b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(v)) return b } // keyToString Converts a key to a string value suitable for logging func keyToString(b []byte) string { if len(b) != 8 { return string(b) } v := binary.BigEndian.Uint64(b) if v <= math.MaxInt32 { return fmt.Sprintf("%d", v) } return string(b) } // CreateBucket is a generic function used to create a bucket inside a database. func (connection *DbConnection) SetServiceName(bucketName string) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.SetServiceName(bucketName) }) } // GetObject is a generic function used to retrieve an unmarshalled object from a database. func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error { return connection.ViewTx(func(tx portainer.Transaction) error { return tx.GetObject(bucketName, key, object) }) } func (connection *DbConnection) getEncryptionKey() []byte { if !connection.isEncrypted { return nil } return connection.EncryptionKey } // UpdateObject is a generic function used to update an object inside a database. func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.UpdateObject(bucketName, key, object) }) } // UpdateObjectFunc is a generic function used to update an object safely without race conditions. func (connection *DbConnection) UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error { return connection.Batch(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(bucketName)) data := bucket.Get(key) if data == nil { return fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key)) } err := connection.UnmarshalObjectWithJsoniter(data, object) if err != nil { return err } updateFn() data, err = connection.MarshalObject(object) if err != nil { return err } return bucket.Put(key, data) }) } // DeleteObject is a generic function used to delete an object inside a database. func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.DeleteObject(bucketName, key) }) } // DeleteAllObjects delete all objects where matching() returns (id, ok). // TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"? func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.DeleteAllObjects(bucketName, obj, matching) }) } // GetNextIdentifier is a generic function that returns the specified bucket identifier incremented by 1. func (connection *DbConnection) GetNextIdentifier(bucketName string) int { var identifier int _ = connection.UpdateTx(func(tx portainer.Transaction) error { identifier = tx.GetNextIdentifier(bucketName) return nil }) return identifier } // CreateObject creates a new object in the bucket, using the next bucket sequence id func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.CreateObject(bucketName, fn) }) } // CreateObjectWithId creates a new object in the bucket, using the specified id func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.CreateObjectWithId(bucketName, id, obj) }) } // CreateObjectWithStringId creates a new object in the bucket, using the specified id func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error { return connection.UpdateTx(func(tx portainer.Transaction) error { return tx.CreateObjectWithStringId(bucketName, id, obj) }) } func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error { return connection.ViewTx(func(tx portainer.Transaction) error { return tx.GetAll(bucketName, obj, append) }) } // TODO: decide which Unmarshal to use, and use one... func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error { return connection.ViewTx(func(tx portainer.Transaction) error { return tx.GetAllWithJsoniter(bucketName, obj, append) }) } func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error { return connection.ViewTx(func(tx portainer.Transaction) error { return tx.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, append) }) } // BackupMetadata will return a copy of the boltdb sequence numbers for all buckets. func (connection *DbConnection) BackupMetadata() (map[string]interface{}, error) { buckets := map[string]interface{}{} err := connection.View(func(tx *bolt.Tx) error { err := tx.ForEach(func(name []byte, bucket *bolt.Bucket) error { bucketName := string(name) seqId := bucket.Sequence() buckets[bucketName] = int(seqId) return nil }) return err }) return buckets, err } // RestoreMetadata will restore the boltdb sequence numbers for all buckets. func (connection *DbConnection) RestoreMetadata(s map[string]interface{}) error { var err error for bucketName, v := range s { id, ok := v.(float64) // JSON ints are unmarshalled to interface as float64. See: https://pkg.go.dev/encoding/json#Decoder.Decode if !ok { log.Error().Str("bucket", bucketName).Msg("failed to restore metadata to bucket, skipped") continue } err = connection.Batch(func(tx *bolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists([]byte(bucketName)) if err != nil { return err } return bucket.SetSequence(uint64(id)) }) } return err }