mirror of https://github.com/cloudreve/Cloudreve
				
				
				
			
		
			
				
	
	
		
			102 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			102 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Go
		
	
	
package inventory
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/ent"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
 | 
						|
)
 | 
						|
 | 
						|
type TxOperator interface {
 | 
						|
	SetClient(newClient *ent.Client) TxOperator
 | 
						|
	GetClient() *ent.Client
 | 
						|
}
 | 
						|
 | 
						|
type (
 | 
						|
	Tx struct {
 | 
						|
		tx          *ent.Tx
 | 
						|
		parent      *Tx
 | 
						|
		inherited   bool
 | 
						|
		finished    bool
 | 
						|
		storageDiff StorageDiff
 | 
						|
	}
 | 
						|
 | 
						|
	// TxCtx is the context key for inherited transaction
 | 
						|
	TxCtx struct{}
 | 
						|
)
 | 
						|
 | 
						|
// AppendStorageDiff appends the given storage diff to the transaction.
 | 
						|
func (t *Tx) AppendStorageDiff(diff StorageDiff) {
 | 
						|
	root := t
 | 
						|
	for root.inherited {
 | 
						|
		root = root.parent
 | 
						|
	}
 | 
						|
 | 
						|
	if root.storageDiff == nil {
 | 
						|
		root.storageDiff = diff
 | 
						|
	} else {
 | 
						|
		root.storageDiff.Merge(diff)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// WithTx wraps the given inventory client with a transaction.
 | 
						|
func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) {
 | 
						|
	var txClient *ent.Client
 | 
						|
	var txWrapper *Tx
 | 
						|
 | 
						|
	if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished {
 | 
						|
		txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited}
 | 
						|
	} else {
 | 
						|
		tx, err := c.GetClient().Tx(ctx)
 | 
						|
		if err != nil {
 | 
						|
			return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		txWrapper = &Tx{inherited: false, tx: tx}
 | 
						|
		ctx = context.WithValue(ctx, TxCtx{}, txWrapper)
 | 
						|
	}
 | 
						|
 | 
						|
	txClient = txWrapper.tx.Client()
 | 
						|
	return c.SetClient(txClient).(T), txWrapper, ctx, nil
 | 
						|
}
 | 
						|
 | 
						|
func Rollback(tx *Tx) error {
 | 
						|
	if !tx.inherited {
 | 
						|
		tx.finished = true
 | 
						|
		return tx.tx.Rollback()
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func commit(tx *Tx) (bool, error) {
 | 
						|
	if !tx.inherited {
 | 
						|
		tx.finished = true
 | 
						|
		return true, tx.tx.Commit()
 | 
						|
	}
 | 
						|
	return false, nil
 | 
						|
}
 | 
						|
 | 
						|
func Commit(tx *Tx) error {
 | 
						|
	_, err := commit(tx)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// CommitWithStorageDiff commits the transaction and applies the storage diff, only if the transaction is not inherited.
 | 
						|
func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error {
 | 
						|
	commited, err := commit(tx)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if !commited {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	if err := uc.ApplyStorageDiff(ctx, tx.storageDiff); err != nil {
 | 
						|
		l.Error("Failed to apply storage diff", "error", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 |