package generic import ( "context" "database/sql" "errors" "fmt" "regexp" "strconv" "strings" "sync" "time" "github.com/Rican7/retry/backoff" "github.com/Rican7/retry/strategy" "github.com/k3s-io/kine/pkg/server" "github.com/sirupsen/logrus" ) const ( defaultMaxIdleConns = 2 // copied from database/sql ) // explicit interface check var _ server.Dialect = (*Generic)(nil) var ( columns = "kv.id AS theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value" revSQL = ` SELECT MAX(rkv.id) AS id FROM kine AS rkv` compactRevSQL = ` SELECT MAX(crkv.prev_revision) AS prev_revision FROM kine AS crkv WHERE crkv.name = 'compact_rev_key'` idOfKey = ` AND mkv.id <= ? AND mkv.id > ( SELECT MAX(ikv.id) AS id FROM kine AS ikv WHERE ikv.name = ? AND ikv.id <= ?)` listSQL = fmt.Sprintf(` SELECT (%s), (%s), %s FROM kine AS kv JOIN ( SELECT MAX(mkv.id) AS id FROM kine AS mkv WHERE mkv.name LIKE ? %%s GROUP BY mkv.name) maxkv ON maxkv.id = kv.id WHERE (kv.deleted = 0 OR ?) ORDER BY kv.id ASC `, revSQL, compactRevSQL, columns) ) type Stripped string func (s Stripped) String() string { str := strings.ReplaceAll(string(s), "\n", "") return regexp.MustCompile("[\t ]+").ReplaceAllString(str, " ") } type ErrRetry func(error) bool type TranslateErr func(error) error type ConnectionPoolConfig struct { MaxIdle int // zero means defaultMaxIdleConns; negative means 0 MaxOpen int // <= 0 means unlimited MaxLifetime time.Duration // maximum amount of time a connection may be reused } type Generic struct { sync.Mutex LockWrites bool LastInsertID bool DB *sql.DB GetCurrentSQL string GetRevisionSQL string RevisionSQL string ListRevisionStartSQL string GetRevisionAfterSQL string CountSQL string AfterSQL string DeleteSQL string CompactSQL string UpdateCompactSQL string InsertSQL string FillSQL string InsertLastInsertIDSQL string GetSizeSQL string Retry ErrRetry TranslateErr TranslateErr } func q(sql, param string, numbered bool) string { if param == "?" && !numbered { return sql } regex := regexp.MustCompile(`\?`) n := 0 return regex.ReplaceAllStringFunc(sql, func(string) string { if numbered { n++ return param + strconv.Itoa(n) } return param }) } func (d *Generic) Migrate(ctx context.Context) { var ( count = 0 countKV = d.queryRow(ctx, "SELECT COUNT(*) FROM key_value") countKine = d.queryRow(ctx, "SELECT COUNT(*) FROM kine") ) if err := countKV.Scan(&count); err != nil || count == 0 { return } if err := countKine.Scan(&count); err != nil || count != 0 { return } logrus.Infof("Migrating content from old table") _, err := d.execute(ctx, `INSERT INTO kine(deleted, create_revision, prev_revision, name, value, created, lease) SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END FROM key_value kv WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`) if err != nil { logrus.Errorf("Migration failed: %v", err) } } func configureConnectionPooling(connPoolConfig ConnectionPoolConfig, db *sql.DB, driverName string) { // behavior copied from database/sql - zero means defaultMaxIdleConns; negative means 0 if connPoolConfig.MaxIdle < 0 { connPoolConfig.MaxIdle = 0 } else if connPoolConfig.MaxIdle == 0 { connPoolConfig.MaxIdle = defaultMaxIdleConns } logrus.Infof("Configuring %s database connection pooling: maxIdleConns=%d, maxOpenConns=%d, connMaxLifetime=%s", driverName, connPoolConfig.MaxIdle, connPoolConfig.MaxOpen, connPoolConfig.MaxLifetime) db.SetMaxIdleConns(connPoolConfig.MaxIdle) db.SetMaxOpenConns(connPoolConfig.MaxOpen) db.SetConnMaxLifetime(connPoolConfig.MaxLifetime) } func openAndTest(driverName, dataSourceName string) (*sql.DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { return nil, err } for i := 0; i < 3; i++ { if err := db.Ping(); err != nil { db.Close() return nil, err } } return db, nil } func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool) (*Generic, error) { var ( db *sql.DB err error ) for i := 0; i < 300; i++ { db, err = openAndTest(driverName, dataSourceName) if err == nil { break } logrus.Errorf("failed to ping connection: %v", err) select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(time.Second): } } configureConnectionPooling(connPoolConfig, db, driverName) return &Generic{ DB: db, GetRevisionSQL: q(fmt.Sprintf(` SELECT 0, 0, %s FROM kine AS kv WHERE kv.id = ?`, columns), paramCharacter, numbered), GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered), ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered), CountSQL: q(fmt.Sprintf(` SELECT (%s), COUNT(c.theid) FROM ( %s ) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered), AfterSQL: q(fmt.Sprintf(` SELECT (%s), (%s), %s FROM kine AS kv WHERE kv.name LIKE ? AND kv.id > ? ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns), paramCharacter, numbered), DeleteSQL: q(` DELETE FROM kine AS kv WHERE kv.id = ?`, paramCharacter, numbered), UpdateCompactSQL: q(` UPDATE kine SET prev_revision = ? WHERE name = 'compact_rev_key'`, paramCharacter, numbered), InsertLastInsertIDSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), InsertSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, paramCharacter, numbered), FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), }, err } func (d *Generic) query(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { logrus.Tracef("QUERY %v : %s", args, Stripped(sql)) return d.DB.QueryContext(ctx, sql, args...) } func (d *Generic) queryRow(ctx context.Context, sql string, args ...interface{}) *sql.Row { logrus.Tracef("QUERY ROW %v : %s", args, Stripped(sql)) return d.DB.QueryRowContext(ctx, sql, args...) } func (d *Generic) execute(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) { if d.LockWrites { d.Lock() defer d.Unlock() } wait := strategy.Backoff(backoff.Linear(100 + time.Millisecond)) for i := uint(0); i < 20; i++ { logrus.Tracef("EXEC (try: %d) %v : %s", i, args, Stripped(sql)) result, err = d.DB.ExecContext(ctx, sql, args...) if err != nil && d.Retry != nil && d.Retry(err) { wait(i) continue } return result, err } return } func (d *Generic) GetCompactRevision(ctx context.Context) (int64, error) { var id int64 row := d.queryRow(ctx, compactRevSQL) err := row.Scan(&id) if err == sql.ErrNoRows { return 0, nil } return id, err } func (d *Generic) SetCompactRevision(ctx context.Context, revision int64) error { logrus.Tracef("SETCOMPACTREVISION %v", revision) _, err := d.execute(ctx, d.UpdateCompactSQL, revision) return err } func (d *Generic) Compact(ctx context.Context, revision int64) (int64, error) { logrus.Tracef("COMPACT %v", revision) res, err := d.execute(ctx, d.CompactSQL, revision, revision) if err != nil { return 0, err } return res.RowsAffected() } func (d *Generic) GetRevision(ctx context.Context, revision int64) (*sql.Rows, error) { return d.query(ctx, d.GetRevisionSQL, revision) } func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error { logrus.Tracef("DELETEREVISION %v", revision) _, err := d.execute(ctx, d.DeleteSQL, revision) return err } func (d *Generic) ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) { sql := d.GetCurrentSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, includeDeleted) } func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) { if startKey == "" { sql := d.ListRevisionStartSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, revision, includeDeleted) } sql := d.GetRevisionAfterSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted) } func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) { var ( rev sql.NullInt64 id int64 ) row := d.queryRow(ctx, d.CountSQL, prefix, false) err := row.Scan(&rev, &id) return rev.Int64, id, err } func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) { var id int64 row := d.queryRow(ctx, revSQL) err := row.Scan(&id) if err == sql.ErrNoRows { return 0, nil } return id, err } func (d *Generic) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) { sql := d.AfterSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, rev) } func (d *Generic) Fill(ctx context.Context, revision int64) error { _, err := d.execute(ctx, d.FillSQL, revision, fmt.Sprintf("gap-%d", revision), 0, 1, 0, 0, 0, nil, nil) return err } func (d *Generic) IsFill(key string) bool { return strings.HasPrefix(key, "gap-") } func (d *Generic) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (id int64, err error) { if d.TranslateErr != nil { defer func() { if err != nil { err = d.TranslateErr(err) } }() } cVal := 0 dVal := 0 if create { cVal = 1 } if delete { dVal = 1 } if d.LastInsertID { row, err := d.execute(ctx, d.InsertLastInsertIDSQL, key, cVal, dVal, createRevision, previousRevision, ttl, value, prevValue) if err != nil { return 0, err } return row.LastInsertId() } row := d.queryRow(ctx, d.InsertSQL, key, cVal, dVal, createRevision, previousRevision, ttl, value, prevValue) err = row.Scan(&id) return id, err } func (d *Generic) GetSize(ctx context.Context) (int64, error) { if d.GetSizeSQL == "" { return 0, errors.New("driver does not support size reporting") } var size int64 row := d.queryRow(ctx, d.GetSizeSQL) if err := row.Scan(&size); err != nil { return 0, err } return size, nil }