diff --git a/inventory/tx.go b/inventory/tx.go index e267047..f53c4fd 100644 --- a/inventory/tx.go +++ b/inventory/tx.go @@ -3,6 +3,7 @@ package inventory import ( "context" "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" "github.com/cloudreve/Cloudreve/v4/pkg/logging" ) @@ -60,6 +61,22 @@ func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, er return c.SetClient(txClient).(T), txWrapper, ctx, nil } +// InheritTx wraps the given inventory client with a transaction. +// If the transaction is already in the context, it will be inherited. +// Otherwise, original client will be returned. +func InheritTx[T TxOperator](ctx context.Context, c T) (T, *Tx) { + var txClient *ent.Client + var txWrapper *Tx + + if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished { + txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited} + txClient = txWrapper.tx.Client() + return c.SetClient(txClient).(T), txWrapper + } + + return c, nil +} + func Rollback(tx *Tx) error { if !tx.inherited { tx.finished = true diff --git a/pkg/filemanager/fs/dbfs/dbfs.go b/pkg/filemanager/fs/dbfs/dbfs.go index e362e23..c060093 100644 --- a/pkg/filemanager/fs/dbfs/dbfs.go +++ b/pkg/filemanager/fs/dbfs/dbfs.go @@ -652,7 +652,8 @@ func (f *DBFS) getPreferredPolicy(ctx context.Context, file *File) (*ent.Storage return nil, fmt.Errorf("owner group not loaded") } - groupPolicy, err := f.storagePolicyClient.GetByGroup(ctx, ownerGroup) + sc, _ := inventory.InheritTx(ctx, f.storagePolicyClient) + groupPolicy, err := sc.GetByGroup(ctx, ownerGroup) if err != nil { return nil, serializer.NewError(serializer.CodeDBError, "Failed to get available storage policies", err) }