k3s/vendor/github.com/k3s-io/kine/pkg/drivers/generic/generic.go

414 lines
11 KiB
Go

package generic
import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/Rican7/retry/backoff"
"github.com/Rican7/retry/strategy"
"github.com/k3s-io/kine/pkg/server"
"github.com/sirupsen/logrus"
)
const (
// defaultMaxIdleConns is the idle-connection count applied by
// configureConnectionPooling when ConnectionPoolConfig.MaxIdle is zero.
defaultMaxIdleConns = 2 // copied from database/sql
)
// explicit interface check
// Compile-time assertion that *Generic implements server.Dialect.
var _ server.Dialect = (*Generic)(nil)
// SQL fragments shared by the statements that Open assembles.
var (
// columns is the select list used by the row-returning queries; kv.id is
// exposed under the alias "theid".
columns = "kv.id AS theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
// revSQL selects the highest id in the kine table. Being an aggregate, it
// is also embedded as a scalar subquery in other statements.
revSQL = `
SELECT MAX(rkv.id) AS id
FROM kine AS rkv`
// compactRevSQL selects the highest prev_revision stored on rows named
// 'compact_rev_key'.
compactRevSQL = `
SELECT MAX(crkv.prev_revision) AS prev_revision
FROM kine AS crkv
WHERE crkv.name = 'compact_rev_key'`
// idOfKey is an extra filter spliced into listSQL's inner query. It takes
// three parameters in order: an upper bound on mkv.id, a key name, and an
// upper bound on the id of that key's rows.
idOfKey = `
AND
mkv.id <= ? AND
mkv.id > (
SELECT MAX(ikv.id) AS id
FROM kine AS ikv
WHERE
ikv.name = ? AND
ikv.id <= ?)`
// listSQL is a format template: revSQL, compactRevSQL and columns are
// substituted here, while the escaped `%%s` survives as a single `%s` verb
// that callers fill with an additional filter on mkv (possibly empty).
// Its parameters are the name LIKE pattern, any parameters of the extra
// filter, and finally a flag that lets deleted rows through.
listSQL = fmt.Sprintf(`
SELECT (%s), (%s), %s
FROM kine AS kv
JOIN (
SELECT MAX(mkv.id) AS id
FROM kine AS mkv
WHERE
mkv.name LIKE ?
%%s
GROUP BY mkv.name) maxkv
ON maxkv.id = kv.id
WHERE
(kv.deleted = 0 OR ?)
ORDER BY kv.id ASC
`, revSQL, compactRevSQL, columns)
)
type Stripped string
func (s Stripped) String() string {
str := strings.ReplaceAll(string(s), "\n", "")
return regexp.MustCompile("[\t ]+").ReplaceAllString(str, " ")
}
// ErrRetry reports whether a failed statement should be attempted again; it
// is consulted by execute after each failing ExecContext call.
type ErrRetry func(error) bool
// TranslateErr maps a driver error to another error; Insert applies it to any
// non-nil error it is about to return.
type TranslateErr func(error) error
// ConnectionPoolConfig holds the database/sql pool settings that
// configureConnectionPooling applies to a *sql.DB.
type ConnectionPoolConfig struct {
MaxIdle int // zero means defaultMaxIdleConns; negative means 0
MaxOpen int // <= 0 means unlimited
MaxLifetime time.Duration // maximum amount of time a connection may be reused
}
// Generic is a SQL-backed implementation of server.Dialect. It bundles a
// database handle with the statements used by its methods. Open fills most of
// the statement fields; the ones it leaves empty are noted below.
type Generic struct {
// Embedded mutex; execute holds it for the duration of a write when
// LockWrites is set.
sync.Mutex
// LockWrites makes execute serialize statements through the embedded mutex.
LockWrites bool
// LastInsertID selects how Insert obtains the new row id: when true it runs
// InsertLastInsertIDSQL and reads Result.LastInsertId, otherwise it runs
// InsertSQL and scans the returned id.
LastInsertID bool
// DB is the connection pool all statements run against.
DB *sql.DB
// Statement used by ListCurrent.
GetCurrentSQL string
// Statement used by GetRevision.
GetRevisionSQL string
// NOTE(review): not set by Open and not referenced by the methods visible
// in this file — confirm whether it is still used elsewhere.
RevisionSQL string
// Statement used by List when no start key is given.
ListRevisionStartSQL string
// Statement used by List when a start key is given.
GetRevisionAfterSQL string
// Statement used by Count.
CountSQL string
// Statement used by After.
AfterSQL string
// Statement used by DeleteRevision.
DeleteSQL string
// Statement used by Compact, executed with the revision passed twice.
// Not set by Open; presumably supplied by the concrete driver.
CompactSQL string
// Statement used by SetCompactRevision.
UpdateCompactSQL string
// Insert statement that returns the new id (used when LastInsertID is false).
InsertSQL string
// Statement used by Fill; unlike the insert statements it sets id explicitly.
FillSQL string
// Insert statement used when LastInsertID is true.
InsertLastInsertIDSQL string
// Statement used by GetSize; GetSize returns an error while it is empty.
// Not set by Open.
GetSizeSQL string
// Retry, when non-nil, decides whether execute retries a failed statement.
Retry ErrRetry
// TranslateErr, when non-nil, is applied to errors returned by Insert.
TranslateErr TranslateErr
}
// q rewrites the `?` placeholders of sql into the placeholder syntax of the
// target driver. Every `?` becomes param; when numbered is true a 1-based
// position is appended to each replacement (param "$" yields $1, $2, ...).
// The statement is returned untouched when the driver already uses plain,
// unnumbered `?` markers.
func q(sql, param string, numbered bool) string {
	if !numbered && param == "?" {
		return sql
	}

	placeholder := regexp.MustCompile(`\?`)
	if !numbered {
		// Literal replacement: param is inserted verbatim, with no `$`
		// expansion.
		return placeholder.ReplaceAllLiteralString(sql, param)
	}

	position := 0
	return placeholder.ReplaceAllStringFunc(sql, func(string) string {
		position++
		return param + strconv.Itoa(position)
	})
}
// Migrate copies the latest row of every key from the legacy key_value table
// into kine. It does nothing when key_value cannot be counted or is empty, or
// when kine cannot be counted or already holds rows. A failure of the copy
// itself is logged rather than returned.
func (d *Generic) Migrate(ctx context.Context) {
	var count int

	// Each row is scanned as soon as it is obtained. A *sql.Row only releases
	// its underlying rows (and connection) inside Scan, so a row that is
	// created but skipped by an early return would stay open.
	if err := d.queryRow(ctx, "SELECT COUNT(*) FROM key_value").Scan(&count); err != nil || count == 0 {
		return
	}

	if err := d.queryRow(ctx, "SELECT COUNT(*) FROM kine").Scan(&count); err != nil || count != 0 {
		return
	}

	logrus.Infof("Migrating content from old table")
	_, err := d.execute(ctx,
		`INSERT INTO kine(deleted, create_revision, prev_revision, name, value, created, lease)
SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END
FROM key_value kv
WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`)
	if err != nil {
		logrus.Errorf("Migration failed: %v", err)
	}
}
// configureConnectionPooling applies connPoolConfig to db and logs the
// effective settings under driverName. MaxIdle is normalized first: a
// negative value becomes 0 and zero becomes defaultMaxIdleConns. MaxOpen and
// MaxLifetime are passed through unchanged.
func configureConnectionPooling(connPoolConfig ConnectionPoolConfig, db *sql.DB, driverName string) {
	// behavior copied from database/sql - zero means defaultMaxIdleConns; negative means 0
	maxIdle := connPoolConfig.MaxIdle
	switch {
	case maxIdle < 0:
		maxIdle = 0
	case maxIdle == 0:
		maxIdle = defaultMaxIdleConns
	}

	logrus.Infof("Configuring %s database connection pooling: maxIdleConns=%d, maxOpenConns=%d, connMaxLifetime=%s", driverName, maxIdle, connPoolConfig.MaxOpen, connPoolConfig.MaxLifetime)

	db.SetMaxIdleConns(maxIdle)
	db.SetMaxOpenConns(connPoolConfig.MaxOpen)
	db.SetConnMaxLifetime(connPoolConfig.MaxLifetime)
}
// openAndTest opens a database handle and pings it three times in a row. The
// handle is returned only if every ping succeeds; on the first failed ping it
// is closed and the ping error is returned.
func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
	handle, openErr := sql.Open(driverName, dataSourceName)
	if openErr != nil {
		return nil, openErr
	}

	for remaining := 3; remaining > 0; remaining-- {
		pingErr := handle.Ping()
		if pingErr == nil {
			continue
		}
		handle.Close()
		return nil, pingErr
	}

	return handle, nil
}
// Open connects to the database and returns a Generic whose statements are
// rewritten for the driver's placeholder syntax.
//
// The connection is attempted up to 300 times with a one-second pause after
// each failure. If ctx is done during a pause, ctx.Err() is returned. If every
// attempt fails, the last connection error is returned.
//
// paramCharacter and numbered describe the placeholder style and are passed to
// q for every statement. CompactSQL and GetSizeSQL are not populated here.
func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool) (*Generic, error) {
	var (
		db  *sql.DB
		err error
	)

	for i := 0; i < 300; i++ {
		db, err = openAndTest(driverName, dataSourceName)
		if err == nil {
			break
		}

		logrus.Errorf("failed to ping connection: %v", err)
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(time.Second):
		}
	}

	// When every attempt failed db is nil, so it must not be configured or
	// wrapped in a Generic.
	if err != nil {
		return nil, err
	}

	configureConnectionPooling(connPoolConfig, db, driverName)

	return &Generic{
		DB: db,

		GetRevisionSQL: q(fmt.Sprintf(`
SELECT
0, 0, %s
FROM kine AS kv
WHERE kv.id = ?`, columns), paramCharacter, numbered),

		GetCurrentSQL:        q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered),
		ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
		GetRevisionAfterSQL:  q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered),

		CountSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered),

		AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
FROM kine AS kv
WHERE
kv.name LIKE ? AND
kv.id > ?
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns), paramCharacter, numbered),

		DeleteSQL: q(`
DELETE FROM kine AS kv
WHERE kv.id = ?`, paramCharacter, numbered),

		UpdateCompactSQL: q(`
UPDATE kine
SET prev_revision = ?
WHERE name = 'compact_rev_key'`, paramCharacter, numbered),

		InsertLastInsertIDSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),

		InsertSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, paramCharacter, numbered),

		FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),
	}, nil
}
// query trace-logs the statement and its arguments, then runs it on the pool
// and returns the resulting rows. The caller is responsible for closing them.
func (d *Generic) query(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
	logrus.Tracef("QUERY %v : %s", args, Stripped(sql))

	rows, err := d.DB.QueryContext(ctx, sql, args...)
	return rows, err
}
// queryRow trace-logs the statement and its arguments, then runs it as a
// single-row query. Any error is deferred until the returned row is scanned.
func (d *Generic) queryRow(ctx context.Context, sql string, args ...interface{}) *sql.Row {
	logrus.Tracef("QUERY ROW %v : %s", args, Stripped(sql))

	row := d.DB.QueryRowContext(ctx, sql, args...)
	return row
}
// execute runs a write statement, retrying it while d.Retry classifies the
// failure as retryable, for at most 20 attempts. When LockWrites is set the
// embedded mutex is held for the whole call, including the waits between
// attempts. It returns the result and error of the last attempt.
func (d *Generic) execute(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) {
	if d.LockWrites {
		d.Lock()
		defer d.Unlock()
	}

	// Linear backoff with a 100ms factor. The factor must be a product:
	// `100 + time.Millisecond` would only be about one millisecond.
	wait := strategy.Backoff(backoff.Linear(100 * time.Millisecond))
	for i := uint(0); i < 20; i++ {
		logrus.Tracef("EXEC (try: %d) %v : %s", i, args, Stripped(sql))
		result, err = d.DB.ExecContext(ctx, sql, args...)
		if err != nil && d.Retry != nil && d.Retry(err) {
			wait(i)
			continue
		}
		return result, err
	}

	// Attempts exhausted: surface the outcome of the final try.
	return result, err
}
// GetCompactRevision returns the compact revision selected by compactRevSQL,
// or 0 when none has been recorded.
//
// compactRevSQL is a MAX() aggregate, so with no matching rows the query
// yields one row holding NULL, not an empty result. The value is therefore
// scanned into a sql.NullInt64, whose Int64 is 0 for NULL.
func (d *Generic) GetCompactRevision(ctx context.Context) (int64, error) {
	var id sql.NullInt64

	err := d.queryRow(ctx, compactRevSQL).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, nil
	}
	return id.Int64, err
}
// SetCompactRevision runs UpdateCompactSQL with revision as its only
// parameter and returns any execution error.
func (d *Generic) SetCompactRevision(ctx context.Context, revision int64) error {
	logrus.Tracef("SETCOMPACTREVISION %v", revision)

	if _, err := d.execute(ctx, d.UpdateCompactSQL, revision); err != nil {
		return err
	}
	return nil
}
// Compact runs CompactSQL, passing revision for both of its parameters, and
// returns the number of rows the statement affected.
func (d *Generic) Compact(ctx context.Context, revision int64) (int64, error) {
	logrus.Tracef("COMPACT %v", revision)

	result, err := d.execute(ctx, d.CompactSQL, revision, revision)
	if err == nil {
		return result.RowsAffected()
	}
	return 0, err
}
// GetRevision queries the row whose id equals revision using GetRevisionSQL.
// The caller is responsible for closing the returned rows.
func (d *Generic) GetRevision(ctx context.Context, revision int64) (*sql.Rows, error) {
	rows, err := d.query(ctx, d.GetRevisionSQL, revision)
	return rows, err
}
// DeleteRevision runs DeleteSQL for the row whose id equals revision and
// returns any execution error.
func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error {
	logrus.Tracef("DELETEREVISION %v", revision)

	if _, err := d.execute(ctx, d.DeleteSQL, revision); err != nil {
		return err
	}
	return nil
}
// ListCurrent runs GetCurrentSQL with prefix as the name LIKE pattern.
// includeDeleted is bound to the statement's deleted-row flag, and a LIMIT
// clause is appended only when limit is positive.
func (d *Generic) ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) {
	stmt := d.GetCurrentSQL
	if limit > 0 {
		stmt = fmt.Sprintf("%s LIMIT %d", stmt, limit)
	}

	return d.query(ctx, stmt, prefix, includeDeleted)
}
// List runs one of two listing statements depending on startKey. With an empty
// startKey it uses ListRevisionStartSQL, bounded by revision. Otherwise it
// uses GetRevisionAfterSQL, which additionally takes startKey and revision for
// the idOfKey filter. A LIMIT clause is appended only when limit is positive.
func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) {
	var (
		stmt string
		args []interface{}
	)

	if startKey == "" {
		stmt = d.ListRevisionStartSQL
		args = []interface{}{prefix, revision, includeDeleted}
	} else {
		stmt = d.GetRevisionAfterSQL
		args = []interface{}{prefix, revision, startKey, revision, includeDeleted}
	}

	if limit > 0 {
		stmt = fmt.Sprintf("%s LIMIT %d", stmt, limit)
	}

	return d.query(ctx, stmt, args...)
}
// Count runs CountSQL for prefix with the deleted-row flag set to false. It
// returns the revision selected by the statement (0 when it is NULL), the row
// count, and any scan error.
func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) {
	var (
		currentRev sql.NullInt64
		total      int64
	)

	err := d.queryRow(ctx, d.CountSQL, prefix, false).Scan(&currentRev, &total)
	return currentRev.Int64, total, err
}
// CurrentRevision returns the highest id in the kine table as selected by
// revSQL, or 0 when the table is empty.
//
// revSQL is a MAX() aggregate, so an empty table yields one row holding NULL,
// not an empty result. The value is therefore scanned into a sql.NullInt64,
// whose Int64 is 0 for NULL.
func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) {
	var id sql.NullInt64

	err := d.queryRow(ctx, revSQL).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, nil
	}
	return id.Int64, err
}
// After runs AfterSQL, which selects rows whose name matches prefix and whose
// id is greater than rev, ordered by id. A LIMIT clause is appended only when
// limit is positive.
func (d *Generic) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) {
	stmt := d.AfterSQL
	if limit > 0 {
		stmt = fmt.Sprintf("%s LIMIT %d", stmt, limit)
	}

	return d.query(ctx, stmt, prefix, rev)
}
// Fill inserts a placeholder row with id revision and name "gap-<revision>".
// The row is written with created=0, deleted=1, zero revisions and lease, and
// NULL value columns.
func (d *Generic) Fill(ctx context.Context, revision int64) error {
	gapName := fmt.Sprintf("gap-%d", revision)

	_, err := d.execute(ctx, d.FillSQL, revision, gapName, 0, 1, 0, 0, 0, nil, nil)
	return err
}
// IsFill reports whether key carries the "gap-" prefix that Fill gives to its
// placeholder rows.
func (d *Generic) IsFill(key string) bool {
	const gapPrefix = "gap-"
	return strings.HasPrefix(key, gapPrefix)
}
// Insert writes a new row for key and returns its id. create and delete are
// stored as 0/1 flags. When LastInsertID is set, the row is written through
// execute and the id comes from Result.LastInsertId; otherwise InsertSQL is
// run as a single-row query and the returned id is scanned. Any non-nil error
// is passed through TranslateErr, when configured, before being returned.
func (d *Generic) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (id int64, err error) {
	if d.TranslateErr != nil {
		// Runs after the named results are set, so it sees the final error.
		defer func() {
			if err != nil {
				err = d.TranslateErr(err)
			}
		}()
	}

	var createdFlag, deletedFlag int
	if create {
		createdFlag = 1
	}
	if delete {
		deletedFlag = 1
	}

	if !d.LastInsertID {
		err = d.queryRow(ctx, d.InsertSQL, key, createdFlag, deletedFlag, createRevision, previousRevision, ttl, value, prevValue).Scan(&id)
		return id, err
	}

	res, execErr := d.execute(ctx, d.InsertLastInsertIDSQL, key, createdFlag, deletedFlag, createRevision, previousRevision, ttl, value, prevValue)
	if execErr != nil {
		return 0, execErr
	}
	return res.LastInsertId()
}
// GetSize runs GetSizeSQL and returns the single integer it selects. It
// returns an error when GetSizeSQL is empty, i.e. the driver has not provided
// a size query.
func (d *Generic) GetSize(ctx context.Context) (int64, error) {
	if d.GetSizeSQL == "" {
		return 0, errors.New("driver does not support size reporting")
	}

	var size int64
	err := d.queryRow(ctx, d.GetSizeSQL).Scan(&size)
	if err != nil {
		return 0, err
	}
	return size, nil
}