// Source: k3s/vendor/github.com/rancher/kine/pkg/drivers/generic/generic.go
// (repository listing metadata: 363 lines, 9.0 KiB, Go)

package generic
import (
"context"
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/Rican7/retry/backoff"
"github.com/Rican7/retry/strategy"
"github.com/sirupsen/logrus"
)
// SQL fragments shared by the statements that Open assembles. All of them use
// "?" as the placeholder; q rewrites the placeholders for the target database.
var (
// columns is the column list selected from the kine table (aliased kv).
// The id column is exposed as "theid".
columns = "kv.id as theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
// revSQL selects the highest id currently present in the kine table.
revSQL = `
SELECT rkv.id
FROM kine rkv
ORDER BY rkv.id
DESC LIMIT 1`
// compactRevSQL selects prev_revision from the newest row named
// 'compact_rev_key'.
compactRevSQL = `
SELECT crkv.prev_revision
FROM kine crkv
WHERE crkv.name = 'compact_rev_key'
ORDER BY crkv.id DESC LIMIT 1`
// idOfKey is an extra filter for the inner query of listSQL. It takes three
// placeholders (upper id bound, key name, upper id bound) and keeps only
// rows whose id is at most the bound and greater than the id of the newest
// row with the given name at or below that bound.
idOfKey = `
AND mkv.id <= ? AND mkv.id > (
SELECT ikv.id
FROM kine ikv
WHERE
ikv.name = ? AND
ikv.id <= ?
ORDER BY ikv.id DESC LIMIT 1)`
// listSQL is a format template: revSQL, compactRevSQL and columns are
// substituted here, while the escaped %%s leaves a single %s slot that a
// later fmt.Sprintf fills with an optional extra filter on the inner query.
// The query returns, for every name matching the LIKE pattern, the row with
// the greatest id, ordered by id; the final placeholder is a boolean that
// lets rows with deleted != 0 through.
listSQL = fmt.Sprintf(`SELECT (%s), (%s), %s
FROM kine kv
JOIN (
SELECT MAX(mkv.id) as id
FROM kine mkv
WHERE
mkv.name LIKE ?
%%s
GROUP BY mkv.name) maxkv
ON maxkv.id = kv.id
WHERE
(kv.deleted = 0 OR ?)
ORDER BY kv.id ASC
`, revSQL, compactRevSQL, columns)
)
// strippedWhitespace matches runs of tabs and spaces. It is compiled once at
// package scope so that Stripped.String does not recompile the pattern on
// every call.
var strippedWhitespace = regexp.MustCompile("[\t ]+")

// Stripped is a string (in this file, SQL text) that formats itself on a
// single line, which keeps log output compact.
type Stripped string

// String returns s with every newline removed and each run of tabs and spaces
// collapsed into a single space.
func (s Stripped) String() string {
	str := strings.ReplaceAll(string(s), "\n", "")
	return strippedWhitespace.ReplaceAllString(str, " ")
}
// ErrRetry reports whether a statement that failed with the given error should
// be attempted again. Generic.execute consults it after each failed exec.
type ErrRetry func(error) bool
// TranslateErr maps an error into the error that should be returned to the
// caller instead. Generic.Insert applies it to any non-nil error it returns.
type TranslateErr func(error) error
// Generic holds a database handle together with the SQL statements, already
// rewritten for the target database's placeholder syntax, that its methods
// run against the kine table.
type Generic struct {
// The embedded mutex is held by execute for the duration of a statement
// (including retries) when LockWrites is true.
sync.Mutex
// LockWrites makes execute serialize statements through the embedded mutex.
LockWrites bool
// LastInsertID selects how Insert obtains the new row id: when true it runs
// InsertLastInsertIDSQL and reads the result's LastInsertId; when false it
// runs InsertSQL and scans the returned id.
LastInsertID bool
// DB is the handle every query and exec goes through.
DB *sql.DB
// GetCurrentSQL is run by ListCurrent.
GetCurrentSQL string
// GetRevisionSQL is run by GetRevision.
GetRevisionSQL string
// NOTE(review): RevisionSQL is neither set by Open nor read by any method
// visible in this file — confirm whether it is still needed.
RevisionSQL string
// ListRevisionStartSQL is run by List when no start key is given.
ListRevisionStartSQL string
// GetRevisionAfterSQL is run by List when a start key is given.
GetRevisionAfterSQL string
// CountSQL is run by Count.
CountSQL string
// AfterSQL is run by After.
AfterSQL string
// DeleteSQL is run by DeleteRevision.
DeleteSQL string
// UpdateCompactSQL is run by SetCompactRevision.
UpdateCompactSQL string
// InsertSQL is run by Insert when LastInsertID is false.
InsertSQL string
// FillSQL is run by Fill.
FillSQL string
// InsertLastInsertIDSQL is run by Insert when LastInsertID is true.
InsertLastInsertIDSQL string
// Retry, when non-nil, decides whether execute retries a failed statement.
Retry ErrRetry
// TranslateErr, when non-nil, rewrites non-nil errors returned by Insert.
TranslateErr TranslateErr
}
// q rewrites the "?" placeholders in sql for a database that uses a different
// placeholder syntax. Each "?" becomes param, followed by its 1-based position
// when numbered is true (for example "$1", "$2"). When param is already "?"
// and numbering is off, sql is returned untouched.
func q(sql, param string, numbered bool) string {
	if !numbered && param == "?" {
		return sql
	}

	placeholder := regexp.MustCompile(`\?`)
	if !numbered {
		// Literal replacement: param must never be interpreted as a template.
		return placeholder.ReplaceAllLiteralString(sql, param)
	}

	position := 0
	return placeholder.ReplaceAllStringFunc(sql, func(string) string {
		position++
		return param + strconv.Itoa(position)
	})
}
// Migrate copies data from the legacy key_value table into kine. It does
// nothing unless key_value can be counted and is non-empty, and kine can be
// counted and is empty. For every name in key_value the row with the greatest
// id is inserted into kine with created = 1, zero revisions, and a lease of 15
// when the old row had a positive ttl (0 otherwise). A failed insert is logged
// and not returned.
func (d *Generic) Migrate(ctx context.Context) {
	var count int

	// Each row is scanned right after it is queried. A *sql.Row that is never
	// scanned keeps its underlying connection checked out, so issuing both
	// queries up front and returning after the first Scan would leave the
	// second row's connection held.
	if err := d.queryRow(ctx, "SELECT COUNT(*) FROM key_value").Scan(&count); err != nil || count == 0 {
		return
	}

	if err := d.queryRow(ctx, "SELECT COUNT(*) FROM kine").Scan(&count); err != nil || count != 0 {
		return
	}

	logrus.Infof("Migrating content from old table")
	_, err := d.execute(ctx,
		`INSERT INTO kine(deleted, create_revision, prev_revision, name, value, created, lease)
SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END
FROM key_value kv
WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`)
	if err != nil {
		logrus.Errorf("Migration failed: %v", err)
	}
}
// openAndTest opens a database handle for the given driver and data source and
// pings it three times in a row. If opening fails the error is returned; if
// any ping fails the handle is closed and the ping error is returned.
func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
	conn, openErr := sql.Open(driverName, dataSourceName)
	if openErr != nil {
		return nil, openErr
	}

	const pings = 3
	for attempt := 0; attempt < pings; attempt++ {
		pingErr := conn.Ping()
		if pingErr == nil {
			continue
		}
		conn.Close()
		return nil, pingErr
	}

	return conn, nil
}
// Open connects to the database and returns a Generic whose SQL statements
// have been rewritten for the database's placeholder syntax.
//
// The connection is attempted up to 300 times, one second apart, logging each
// failure. If ctx is cancelled while waiting, ctx.Err() is returned. If every
// attempt fails, the last error is returned together with a nil *Generic.
//
// paramCharacter is the placeholder that replaces "?" in each statement; when
// numbered is true the 1-based position is appended to it (see q).
func Open(ctx context.Context, driverName, dataSourceName string, paramCharacter string, numbered bool) (*Generic, error) {
	var (
		db  *sql.DB
		err error
	)

	for i := 0; i < 300; i++ {
		db, err = openAndTest(driverName, dataSourceName)
		if err == nil {
			break
		}

		logrus.Errorf("failed to ping connection: %v", err)
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(time.Second):
		}
	}
	if err != nil {
		// Every attempt failed: do not hand back a Generic wrapping a nil DB.
		return nil, err
	}

	return &Generic{
		DB: db,

		GetRevisionSQL: q(fmt.Sprintf(`
SELECT
0, 0, %s
FROM kine kv
WHERE kv.id = ?`, columns), paramCharacter, numbered),

		GetCurrentSQL:        q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered),
		ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
		GetRevisionAfterSQL:  q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered),

		CountSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered),

		AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
FROM kine kv
WHERE
kv.name LIKE ? AND
kv.id > ?
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns), paramCharacter, numbered),

		DeleteSQL: q(`
DELETE FROM kine
WHERE id = ?`, paramCharacter, numbered),

		UpdateCompactSQL: q(`
UPDATE kine
SET prev_revision = ?
WHERE name = 'compact_rev_key'`, paramCharacter, numbered),

		InsertLastInsertIDSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),

		InsertSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, paramCharacter, numbered),

		FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),
	}, nil
}
// query logs the statement and its arguments at trace level and then runs it
// with QueryContext on the underlying database handle.
func (d *Generic) query(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
	oneLine := Stripped(sql)
	logrus.Tracef("QUERY %v : %s", args, oneLine)

	rows, err := d.DB.QueryContext(ctx, sql, args...)
	return rows, err
}
// queryRow logs the statement and its arguments at trace level and then runs
// it with QueryRowContext on the underlying database handle.
func (d *Generic) queryRow(ctx context.Context, sql string, args ...interface{}) *sql.Row {
	oneLine := Stripped(sql)
	logrus.Tracef("QUERY ROW %v : %s", args, oneLine)

	row := d.DB.QueryRowContext(ctx, sql, args...)
	return row
}
// execute runs a statement with ExecContext, retrying while d.Retry reports
// the error as retryable, for at most 20 attempts. Between attempts it waits
// with a linear backoff of 100ms times the attempt index. When d.LockWrites is
// set, the embedded mutex is held for the whole call, retries included.
//
// It returns the result and error of the last attempt. The first three
// attempts are logged at trace level and later ones at debug level.
func (d *Generic) execute(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) {
	if d.LockWrites {
		d.Lock()
		defer d.Unlock()
	}

	// The factor was previously written as 100 + time.Millisecond, which adds
	// 100ns to 1ms rather than expressing 100ms.
	wait := strategy.Backoff(backoff.Linear(100 * time.Millisecond))
	for i := uint(0); i < 20; i++ {
		if i > 2 {
			logrus.Debugf("EXEC (try: %d) %v : %s", i, args, Stripped(sql))
		} else {
			logrus.Tracef("EXEC (try: %d) %v : %s", i, args, Stripped(sql))
		}
		result, err = d.DB.ExecContext(ctx, sql, args...)
		if err != nil && d.Retry != nil && d.Retry(err) {
			wait(i)
			continue
		}
		return result, err
	}
	// Attempts exhausted: report the outcome of the final try.
	return result, err
}
// GetCompactRevision runs compactRevSQL and returns the prev_revision value it
// yields. When the query returns no row, it reports 0 with a nil error.
func (d *Generic) GetCompactRevision(ctx context.Context) (int64, error) {
	var compacted int64

	switch err := d.queryRow(ctx, compactRevSQL).Scan(&compacted); err {
	case sql.ErrNoRows:
		return 0, nil
	default:
		return compacted, err
	}
}
// SetCompactRevision runs UpdateCompactSQL with revision as its only argument,
// storing it as prev_revision on the 'compact_rev_key' row(s).
func (d *Generic) SetCompactRevision(ctx context.Context, revision int64) error {
	if _, err := d.execute(ctx, d.UpdateCompactSQL, revision); err != nil {
		return err
	}
	return nil
}
// GetRevision runs GetRevisionSQL for the row whose id equals revision and
// returns the resulting rows.
func (d *Generic) GetRevision(ctx context.Context, revision int64) (*sql.Rows, error) {
	rows, err := d.query(ctx, d.GetRevisionSQL, revision)
	return rows, err
}
// DeleteRevision runs DeleteSQL to remove the kine row whose id equals
// revision.
func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error {
	var err error
	_, err = d.execute(ctx, d.DeleteSQL, revision)
	return err
}
// ListCurrent runs GetCurrentSQL with prefix as the LIKE pattern and
// includeDeleted as the flag that lets deleted rows through. A positive limit
// is appended to the statement as a LIMIT clause.
func (d *Generic) ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) {
	if limit <= 0 {
		return d.query(ctx, d.GetCurrentSQL, prefix, includeDeleted)
	}

	limited := fmt.Sprintf("%s LIMIT %d", d.GetCurrentSQL, limit)
	return d.query(ctx, limited, prefix, includeDeleted)
}
// List runs one of two statements depending on startKey. With an empty
// startKey it runs ListRevisionStartSQL with (prefix, revision,
// includeDeleted); otherwise it runs GetRevisionAfterSQL with (prefix,
// revision, startKey, revision, includeDeleted). A positive limit is appended
// to the chosen statement as a LIMIT clause.
func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) {
	stmt := d.GetRevisionAfterSQL
	params := []interface{}{prefix, revision, startKey, revision, includeDeleted}
	if startKey == "" {
		stmt = d.ListRevisionStartSQL
		params = []interface{}{prefix, revision, includeDeleted}
	}

	if limit > 0 {
		stmt = fmt.Sprintf("%s LIMIT %d", stmt, limit)
	}
	return d.query(ctx, stmt, params...)
}
// Count runs CountSQL with prefix as the LIKE pattern and deleted rows
// excluded. It returns the first scanned column (the current revision, 0 when
// NULL), the second scanned column (the row count), and any scan error.
func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) {
	var (
		current sql.NullInt64
		total   int64
	)

	err := d.queryRow(ctx, d.CountSQL, prefix, false).Scan(&current, &total)
	return current.Int64, total, err
}
// CurrentRevision runs revSQL and returns the highest id in the kine table.
// When the query returns no row, it reports 0 with a nil error.
func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) {
	var latest int64

	scanErr := d.queryRow(ctx, revSQL).Scan(&latest)
	if scanErr != sql.ErrNoRows {
		return latest, scanErr
	}
	return 0, nil
}
// After runs AfterSQL, returning rows whose name matches the prefix pattern
// and whose id is greater than rev, ordered by id. A positive limit is
// appended to the statement as a LIMIT clause.
func (d *Generic) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) {
	if limit <= 0 {
		return d.query(ctx, d.AfterSQL, prefix, rev)
	}

	limited := fmt.Sprintf("%s LIMIT %d", d.AfterSQL, limit)
	return d.query(ctx, limited, prefix, rev)
}
// Fill runs FillSQL to insert a placeholder row with id = revision, named
// "gap-<revision>", with created = 0, deleted = 1, zero revisions and lease,
// and NULL value and old_value.
func (d *Generic) Fill(ctx context.Context, revision int64) error {
	gapName := fmt.Sprintf("gap-%d", revision)
	if _, err := d.execute(ctx, d.FillSQL, revision, gapName, 0, 1, 0, 0, 0, nil, nil); err != nil {
		return err
	}
	return nil
}
// IsFill reports whether key starts with "gap-", the prefix Fill gives the
// rows it inserts.
func (d *Generic) IsFill(key string) bool {
	const fillPrefix = "gap-"
	return strings.HasPrefix(key, fillPrefix)
}
// Insert adds a row to the kine table and returns its id. The create and
// delete flags are stored as 1 or 0. When d.LastInsertID is true the row is
// written with InsertLastInsertIDSQL and the id is read from the exec result;
// otherwise InsertSQL is run and the id it returns is scanned. If
// d.TranslateErr is set, any non-nil error is passed through it before being
// returned.
func (d *Generic) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (id int64, err error) {
	if d.TranslateErr != nil {
		// Runs after the named result is set, so it covers every return below.
		defer func() {
			if err == nil {
				return
			}
			err = d.TranslateErr(err)
		}()
	}

	asInt := func(flag bool) int {
		if flag {
			return 1
		}
		return 0
	}
	createdFlag, deletedFlag := asInt(create), asInt(delete)

	if !d.LastInsertID {
		err = d.queryRow(ctx, d.InsertSQL, key, createdFlag, deletedFlag, createRevision, previousRevision, ttl, value, prevValue).Scan(&id)
		return id, err
	}

	res, execErr := d.execute(ctx, d.InsertLastInsertIDSQL, key, createdFlag, deletedFlag, createRevision, previousRevision, ttl, value, prevValue)
	if execErr != nil {
		return 0, execErr
	}
	return res.LastInsertId()
}