package generic import ( "context" "database/sql" "fmt" "regexp" "strconv" "strings" "sync" "time" "github.com/Rican7/retry/backoff" "github.com/Rican7/retry/strategy" "github.com/sirupsen/logrus" ) var ( columns = "kv.id as theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value" revSQL = ` SELECT rkv.id FROM kine rkv ORDER BY rkv.id DESC LIMIT 1` compactRevSQL = ` SELECT crkv.prev_revision FROM kine crkv WHERE crkv.name = 'compact_rev_key' ORDER BY crkv.id DESC LIMIT 1` idOfKey = ` AND mkv.id <= ? AND mkv.id > ( SELECT ikv.id FROM kine ikv WHERE ikv.name = ? AND ikv.id <= ? ORDER BY ikv.id DESC LIMIT 1)` listSQL = fmt.Sprintf(`SELECT (%s), (%s), %s FROM kine kv JOIN ( SELECT MAX(mkv.id) as id FROM kine mkv WHERE mkv.name LIKE ? %%s GROUP BY mkv.name) maxkv ON maxkv.id = kv.id WHERE (kv.deleted = 0 OR ?) ORDER BY kv.id ASC `, revSQL, compactRevSQL, columns) ) type Stripped string func (s Stripped) String() string { str := strings.ReplaceAll(string(s), "\n", "") return regexp.MustCompile("[\t ]+").ReplaceAllString(str, " ") } type ErrRetry func(error) bool type TranslateErr func(error) error type Generic struct { sync.Mutex LockWrites bool LastInsertID bool DB *sql.DB GetCurrentSQL string GetRevisionSQL string RevisionSQL string ListRevisionStartSQL string GetRevisionAfterSQL string CountSQL string AfterSQL string DeleteSQL string UpdateCompactSQL string InsertSQL string FillSQL string InsertLastInsertIDSQL string Retry ErrRetry TranslateErr TranslateErr } func q(sql, param string, numbered bool) string { if param == "?" && !numbered { return sql } regex := regexp.MustCompile(`\?`) n := 0 return regex.ReplaceAllStringFunc(sql, func(string) string { if numbered { n++ return param + strconv.Itoa(n) } return param }) } func (d *Generic) Migrate(ctx context.Context) { var ( count = 0 countKV = d.queryRow(ctx, "SELECT COUNT(*) FROM key_value") countKine = d.queryRow(ctx, "SELECT COUNT(*) FROM kine") ) if err := countKV.Scan(&count); err != nil || count == 0 { return } if err := countKine.Scan(&count); err != nil || count != 0 { return } logrus.Infof("Migrating content from old table") _, err := d.execute(ctx, `INSERT INTO kine(deleted, create_revision, prev_revision, name, value, created, lease) SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END FROM key_value kv WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`) if err != nil { logrus.Errorf("Migration failed: %v", err) } } func openAndTest(driverName, dataSourceName string) (*sql.DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { return nil, err } for i := 0; i < 3; i++ { if err := db.Ping(); err != nil { db.Close() return nil, err } } return db, nil } func Open(ctx context.Context, driverName, dataSourceName string, paramCharacter string, numbered bool) (*Generic, error) { var ( db *sql.DB err error ) for i := 0; i < 300; i++ { db, err = openAndTest(driverName, dataSourceName) if err == nil { break } logrus.Errorf("failed to ping connection: %v", err) select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(time.Second): } } return &Generic{ DB: db, GetRevisionSQL: q(fmt.Sprintf(` SELECT 0, 0, %s FROM kine kv WHERE kv.id = ?`, columns), paramCharacter, numbered), GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered), ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered), CountSQL: q(fmt.Sprintf(` SELECT (%s), COUNT(c.theid) FROM ( %s ) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered), AfterSQL: q(fmt.Sprintf(` SELECT (%s), (%s), %s FROM kine kv WHERE kv.name LIKE ? AND kv.id > ? ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns), paramCharacter, numbered), DeleteSQL: q(` DELETE FROM kine WHERE id = ?`, paramCharacter, numbered), UpdateCompactSQL: q(` UPDATE kine SET prev_revision = ? WHERE name = 'compact_rev_key'`, paramCharacter, numbered), InsertLastInsertIDSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), InsertSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, paramCharacter, numbered), FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value) values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), }, err } func (d *Generic) query(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { logrus.Tracef("QUERY %v : %s", args, Stripped(sql)) return d.DB.QueryContext(ctx, sql, args...) } func (d *Generic) queryRow(ctx context.Context, sql string, args ...interface{}) *sql.Row { logrus.Tracef("QUERY ROW %v : %s", args, Stripped(sql)) return d.DB.QueryRowContext(ctx, sql, args...) } func (d *Generic) execute(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) { if d.LockWrites { d.Lock() defer d.Unlock() } wait := strategy.Backoff(backoff.Linear(100 + time.Millisecond)) for i := uint(0); i < 20; i++ { if i > 2 { logrus.Debugf("EXEC (try: %d) %v : %s", i, args, Stripped(sql)) } else { logrus.Tracef("EXEC (try: %d) %v : %s", i, args, Stripped(sql)) } result, err = d.DB.ExecContext(ctx, sql, args...) if err != nil && d.Retry != nil && d.Retry(err) { wait(i) continue } return result, err } return } func (d *Generic) GetCompactRevision(ctx context.Context) (int64, error) { var id int64 row := d.queryRow(ctx, compactRevSQL) err := row.Scan(&id) if err == sql.ErrNoRows { return 0, nil } return id, err } func (d *Generic) SetCompactRevision(ctx context.Context, revision int64) error { _, err := d.execute(ctx, d.UpdateCompactSQL, revision) return err } func (d *Generic) GetRevision(ctx context.Context, revision int64) (*sql.Rows, error) { return d.query(ctx, d.GetRevisionSQL, revision) } func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error { _, err := d.execute(ctx, d.DeleteSQL, revision) return err } func (d *Generic) ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) { sql := d.GetCurrentSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, includeDeleted) } func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) { if startKey == "" { sql := d.ListRevisionStartSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, revision, includeDeleted) } sql := d.GetRevisionAfterSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted) } func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) { var ( rev sql.NullInt64 id int64 ) row := d.queryRow(ctx, d.CountSQL, prefix, false) err := row.Scan(&rev, &id) return rev.Int64, id, err } func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) { var id int64 row := d.queryRow(ctx, revSQL) err := row.Scan(&id) if err == sql.ErrNoRows { return 0, nil } return id, err } func (d *Generic) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) { sql := d.AfterSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } return d.query(ctx, sql, prefix, rev) } func (d *Generic) Fill(ctx context.Context, revision int64) error { _, err := d.execute(ctx, d.FillSQL, revision, fmt.Sprintf("gap-%d", revision), 0, 1, 0, 0, 0, nil, nil) return err } func (d *Generic) IsFill(key string) bool { return strings.HasPrefix(key, "gap-") } func (d *Generic) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (id int64, err error) { if d.TranslateErr != nil { defer func() { if err != nil { err = d.TranslateErr(err) } }() } cVal := 0 dVal := 0 if create { cVal = 1 } if delete { dVal = 1 } if d.LastInsertID { row, err := d.execute(ctx, d.InsertLastInsertIDSQL, key, cVal, dVal, createRevision, previousRevision, ttl, value, prevValue) if err != nil { return 0, err } return row.LastInsertId() } row := d.queryRow(ctx, d.InsertSQL, key, cVal, dVal, createRevision, previousRevision, ttl, value, prevValue) err = row.Scan(&id) return id, err }