k3s/vendor/github.com/k3s-io/kine/pkg/drivers/generic/tx.go

114 lines
2.7 KiB
Go
Raw Normal View History

package generic
import (
"context"
"database/sql"
"github.com/k3s-io/kine/pkg/server"
"github.com/sirupsen/logrus"
)
// explicit interface check
var _ server.Transaction = (*Tx)(nil)
type Tx struct {
x *sql.Tx
d *Generic
}
func (d *Generic) BeginTx(ctx context.Context, opts *sql.TxOptions) (server.Transaction, error) {
logrus.Tracef("TX BEGIN")
x, err := d.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{
x: x,
d: d,
}, nil
}
func (t *Tx) Commit() error {
logrus.Tracef("TX COMMIT")
return t.x.Commit()
}
func (t *Tx) MustCommit() {
if err := t.Commit(); err != nil {
logrus.Fatalf("Transaction commit failed: %v", err)
}
}
func (t *Tx) Rollback() error {
logrus.Tracef("TX ROLLBACK")
return t.x.Rollback()
}
func (t *Tx) MustRollback() {
if err := t.Rollback(); err != nil {
if err != sql.ErrTxDone {
logrus.Fatalf("Transaction rollback failed: %v", err)
}
}
}
func (t *Tx) GetCompactRevision(ctx context.Context) (int64, error) {
var id int64
row := t.queryRow(ctx, compactRevSQL)
err := row.Scan(&id)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
}
func (t *Tx) SetCompactRevision(ctx context.Context, revision int64) error {
logrus.Tracef("TX SETCOMPACTREVISION %v", revision)
_, err := t.execute(ctx, t.d.UpdateCompactSQL, revision)
return err
}
func (t *Tx) Compact(ctx context.Context, revision int64) (int64, error) {
logrus.Tracef("TX COMPACT %v", revision)
res, err := t.execute(ctx, t.d.CompactSQL, revision, revision)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (t *Tx) GetRevision(ctx context.Context, revision int64) (*sql.Rows, error) {
return t.query(ctx, t.d.GetRevisionSQL, revision)
}
func (t *Tx) DeleteRevision(ctx context.Context, revision int64) error {
logrus.Tracef("TX DELETEREVISION %v", revision)
_, err := t.execute(ctx, t.d.DeleteSQL, revision)
return err
}
func (t *Tx) CurrentRevision(ctx context.Context) (int64, error) {
var id int64
row := t.queryRow(ctx, revSQL)
err := row.Scan(&id)
if err == sql.ErrNoRows {
return 0, nil
}
return id, err
}
func (t *Tx) query(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
logrus.Tracef("TX QUERY %v : %s", args, Stripped(sql))
return t.x.QueryContext(ctx, sql, args...)
}
func (t *Tx) queryRow(ctx context.Context, sql string, args ...interface{}) *sql.Row {
logrus.Tracef("TX QUERY ROW %v : %s", args, Stripped(sql))
return t.x.QueryRowContext(ctx, sql, args...)
}
func (t *Tx) execute(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) {
logrus.Tracef("TX EXEC %v : %s", args, Stripped(sql))
return t.x.ExecContext(ctx, sql, args...)
}