k3s/vendor/github.com/ibuildthecloud/kvsql/clientv3/driver/generic.go

294 lines
6.2 KiB
Go

package driver
import (
"context"
"database/sql"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/ibuildthecloud/kvsql/pkg/broadcast"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
utiltrace "k8s.io/apiserver/pkg/util/trace"
)
// Generic implements the key/value driver on top of database/sql using a set
// of dialect-specific SQL statements supplied by the caller. Every mutation
// (including deletes) executes InsertSQL with a freshly allocated revision.
type Generic struct {
	// revision is the most recently allocated revision. It is read and
	// incremented with sync/atomic, so it MUST stay the first field: the
	// 64-bit atomic functions require 64-bit alignment, which Go only
	// guarantees on 386, ARM and 32-bit MIPS for the first word of an
	// allocated struct. (Embedding Generic by value at a non-zero offset in
	// another struct would defeat this.)
	revision int64

	db *sql.DB

	// CleanupSQL is executed once a minute by Start with the current Unix
	// time as its only argument.
	CleanupSQL string
	// GetSQL is used by List for an exact key at the latest revision;
	// args (key, limit).
	GetSQL string
	// ListSQL is used by List for a key ending in "%" at the latest
	// revision; args (rangeKey, limit).
	ListSQL string
	// ListRevisionSQL is used by List when revision > 0 and no startKey is
	// given; args (revision, rangeKey, limit).
	ListRevisionSQL string
	// ListResumeSQL is used by List when revision > 0 and startKey is set;
	// args (revision, rangeKey, startKey, limit).
	ListResumeSQL string
	// ReplaySQL is used by replayEvents; args (key, revision).
	ReplaySQL string
	// InsertSQL writes one row; args (key, value, oldValue, oldRevision,
	// createRevision, revision, ttl, version, del).
	InsertSQL string
	// GetRevisionSQL returns a single, possibly NULL, integer used by Start
	// to seed revision.
	GetRevisionSQL string
	// ToDeleteSQL yields (count, name, revision) rows consumed by cleanup.
	ToDeleteSQL string
	// DeleteOldSQL is executed by cleanup once per name with args
	// (name, revision, revision).
	DeleteOldSQL string

	// changes receives every row successfully written by mod; Start creates
	// it with a buffer of 1024. broadcaster and cancel are not used in this
	// file (presumably set up by watch code elsewhere in the package).
	changes     chan *KeyValue
	broadcaster broadcast.Broadcaster
	cancel      func()
}
// Start binds g to db, seeds the in-memory revision counter from
// GetRevisionSQL, and launches a background goroutine. Until ctx is
// cancelled, that goroutine runs CleanupSQL (with the current Unix time) and
// cleanup once a minute.
//
// It returns an error only if the initial revision query fails. Failures in
// the background loop are logged, not returned.
func (g *Generic) Start(ctx context.Context, db *sql.DB) error {
	g.db = db
	g.changes = make(chan *KeyValue, 1024)

	rev := sql.NullInt64{}
	if err := db.QueryRowContext(ctx, g.GetRevisionSQL).Scan(&rev); err != nil {
		return errors.Wrap(err, "Failed to initialize revision")
	}
	// A NULL or zero result is treated as revision 1, so the counter never
	// starts at 0.
	if rev.Int64 == 0 {
		g.revision = 1
	} else {
		g.revision = rev.Int64
	}

	go func() {
		// A single ticker instead of a fresh time.After timer per iteration.
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// Best-effort maintenance: log the cause and keep looping.
				if _, err := g.ExecContext(ctx, g.CleanupSQL, time.Now().Unix()); err != nil {
					logrus.Errorf("Failed to purge expired TTL entries: %v", err)
				}
				if err := g.cleanup(ctx); err != nil {
					logrus.Errorf("Failed to cleanup duplicate entries: %v", err)
				}
			}
		}
	}()
	return nil
}
// cleanup runs ToDeleteSQL, which is expected to yield (count, name,
// revision) rows. It keeps the last revision seen for each name, then
// executes DeleteOldSQL with (name, revision, revision) for every name. It
// stops at, and returns, the first error.
func (g *Generic) cleanup(ctx context.Context) error {
	rows, err := g.QueryContext(ctx, g.ToDeleteSQL)
	if err != nil {
		return err
	}
	defer rows.Close()

	toDelete := map[string]int64{}
	for rows.Next() {
		var (
			count, revision int64
			name            string
		)
		if err := rows.Scan(&count, &name, &revision); err != nil {
			return err
		}
		toDelete[name] = revision
	}
	// rows.Next returning false may mean a failure, not the end of the
	// results. Don't delete based on a partial result set.
	if err := rows.Err(); err != nil {
		return err
	}
	// Release the result set before issuing the deletes; presumably this
	// matters for drivers limited to one connection. Rows.Close is
	// idempotent, so the deferred Close is harmless.
	rows.Close()

	for name, rev := range toDelete {
		if _, err := g.ExecContext(ctx, g.DeleteOldSQL, name, rev, rev); err != nil {
			return err
		}
	}
	return nil
}
// Get returns the first live (non-deleted) entry that List reports for key at
// the current revision. When there is none it returns (nil, nil).
func (g *Generic) Get(ctx context.Context, key string) (*KeyValue, error) {
	matches, _, err := g.List(ctx, 0, 1, key, "")
	switch {
	case err != nil:
		return nil, err
	case len(matches) == 0:
		return nil, nil
	default:
		return matches[0], nil
	}
}
// replayEvents runs ReplaySQL with (key, revision) and returns every row it
// yields. Unlike List, it does not filter out deletion tombstones (Del != 0).
func (g *Generic) replayEvents(ctx context.Context, key string, revision int64) ([]*KeyValue, error) {
	rows, err := g.QueryContext(ctx, g.ReplaySQL, key, revision)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var resp []*KeyValue
	for rows.Next() {
		value := KeyValue{}
		if err := scan(rows.Scan, &value); err != nil {
			return nil, err
		}
		resp = append(resp, &value)
	}
	// Report an iteration failure instead of returning a truncated replay.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return resp, nil
}
// List queries rows matching rangeKey. It returns the live (non-deleted)
// entries together with a revision for the listing.
//
// Query selection:
//   - revision <= 0 and rangeKey without a trailing "%": GetSQL
//   - revision <= 0 and rangeKey ending in "%": ListSQL
//   - revision > 0 and startKey non-empty: ListResumeSQL
//   - revision > 0 otherwise: ListRevisionSQL
//
// limit == 0 means effectively unlimited; any other limit is raised by one
// before querying. The returned revision starts at the current global
// revision (or at revision when resuming with startKey). It is raised to the
// highest revision among all scanned rows, including tombstones.
func (g *Generic) List(ctx context.Context, revision, limit int64, rangeKey, startKey string) ([]*KeyValue, int64, error) {
	// Every query takes a LIMIT argument, so "no limit" is a large sentinel.
	const unlimited = 1000000
	if limit == 0 {
		limit = unlimited
	} else {
		// Fetch one extra row; presumably so callers can tell that more
		// results exist. TODO confirm at call sites.
		limit++
	}

	listRevision := atomic.LoadInt64(&g.revision)

	var (
		rows *sql.Rows
		err  error
	)
	switch {
	case revision <= 0 && !strings.HasSuffix(rangeKey, "%"):
		rows, err = g.QueryContext(ctx, g.GetSQL, rangeKey, limit)
	case revision <= 0:
		rows, err = g.QueryContext(ctx, g.ListSQL, rangeKey, limit)
	case len(startKey) > 0:
		listRevision = revision
		rows, err = g.QueryContext(ctx, g.ListResumeSQL, revision, rangeKey, startKey, limit)
	default:
		rows, err = g.QueryContext(ctx, g.ListRevisionSQL, revision, rangeKey, limit)
	}
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	var resp []*KeyValue
	for rows.Next() {
		value := KeyValue{}
		if err := scan(rows.Scan, &value); err != nil {
			return nil, 0, err
		}
		// Tombstones still advance the reported revision even though they
		// are not returned.
		if value.Revision > listRevision {
			listRevision = value.Revision
		}
		if value.Del == 0 {
			resp = append(resp, &value)
		}
	}
	// An iteration failure makes rows.Next return false. Don't return a
	// truncated listing as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return resp, listRevision, nil
}
// Delete writes a tombstone for key through mod, using an empty value and no
// TTL. When revision > 0 it must match the key's current revision. The
// returned slice is always nil.
//
// Panics if key ends with "%": wildcard (prefix) deletes are not supported.
func (g *Generic) Delete(ctx context.Context, key string, revision int64) ([]*KeyValue, error) {
	isRange := strings.HasSuffix(key, "%")
	if isRange {
		panic("can not delete list revision")
	}
	_, modErr := g.mod(ctx, true, key, []byte{}, revision, 0)
	return nil, modErr
}
// Update writes value under key through mod and returns (previous, current).
//
// previous is nil when this write created the key (current.Version == 1).
// Otherwise it is a copy of current whose Revision and Value are replaced by
// current.OldRevision and current.OldValue.
func (g *Generic) Update(ctx context.Context, key string, value []byte, revision, ttl int64) (*KeyValue, *KeyValue, error) {
	current, err := g.mod(ctx, false, key, value, revision, ttl)
	switch {
	case err != nil:
		return nil, nil, err
	case current.Version == 1:
		return nil, current, nil
	}

	previous := *current
	previous.Revision, previous.Value = current.OldRevision, current.OldValue
	return &previous, current, nil
}
// ExecContext runs query against the underlying *sql.DB. The call is wrapped
// in a utiltrace trace, which is logged via LogIfLong when the call takes
// longer than 500ms.
//
// NOTE(review): the trace text includes every arg. mod passes row values to
// InsertSQL through this method, so slow inserts will log stored values.
func (g *Generic) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	const slowThreshold = 500 * time.Millisecond

	label := fmt.Sprintf("SQL DB ExecContext query: %s keys: %v", query, args)
	tr := utiltrace.New(label)
	defer tr.LogIfLong(slowThreshold)

	return g.db.ExecContext(ctx, query, args...)
}
// QueryContext runs query against the underlying *sql.DB. The call is wrapped
// in a utiltrace trace, which is logged via LogIfLong when the call takes
// longer than 500ms.
//
// Only the QueryContext call itself is timed. Iterating the returned rows
// happens after this function returns and is not covered by the trace.
func (g *Generic) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	const slowThreshold = 500 * time.Millisecond

	label := fmt.Sprintf("SQL DB QueryContext query: %s keys: %v", query, args)
	tr := utiltrace.New(label)
	defer tr.LogIfLong(slowThreshold)

	return g.db.QueryContext(ctx, query, args...)
}
// mod writes a new row for key (an update, or a tombstone when delete is
// true) via InsertSQL and publishes the resulting KeyValue on g.changes.
//
// When revision > 0 it is a precondition. ErrNotExists is returned if key has
// no live row, and ErrRevisionMatch if the live row's revision differs. A
// positive ttl is turned into an absolute Unix expiry for a new key; for an
// existing key, the previous row's TTL is carried forward instead. The global
// revision is bumped before the insert, so a failed insert leaves a gap.
//
// NOTE(review): the Get-then-insert sequence is not atomic within this
// function. Presumably the schema rejects conflicting concurrent writes;
// confirm. The send on g.changes blocks while its buffer is full.
func (g *Generic) mod(ctx context.Context, delete bool, key string, value []byte, revision int64, ttl int64) (*KeyValue, error) {
	current, err := g.Get(ctx, key)
	if err != nil {
		return nil, err
	}
	if revision > 0 {
		switch {
		case current == nil:
			return nil, ErrNotExists
		case current.Revision != revision:
			return nil, ErrRevisionMatch
		}
	}

	expires := ttl
	if expires > 0 {
		expires += time.Now().Unix()
	}

	next := atomic.AddInt64(&g.revision, 1)
	kv := &KeyValue{
		Key:      key,
		Value:    value,
		Revision: next,
	}
	if current == nil {
		kv.TTL = expires
		kv.CreateRevision = next
		kv.Version = 1
	} else {
		kv.OldRevision = current.Revision
		kv.OldValue = current.Value
		kv.TTL = current.TTL
		kv.CreateRevision = current.CreateRevision
		kv.Version = current.Version + 1
	}
	if delete {
		kv.Del = 1
	}

	if _, err := g.ExecContext(ctx, g.InsertSQL,
		kv.Key,
		kv.Value,
		kv.OldValue,
		kv.OldRevision,
		kv.CreateRevision,
		kv.Revision,
		kv.TTL,
		kv.Version,
		kv.Del,
	); err != nil {
		return nil, err
	}

	g.changes <- kv
	return kv, nil
}
// scanner matches the signature of the Scan method on *sql.Rows and *sql.Row,
// so scan can decode from either (callers here pass rows.Scan).
type scanner func(dest ...interface{}) error

// scan decodes one result row into out. The column order it expects is: ID,
// Key, Value, OldValue, OldRevision, CreateRevision, Revision, TTL, Version,
// Del. The scanner's error is returned unchanged.
func scan(s scanner, out *KeyValue) error {
	columns := []interface{}{
		&out.ID,
		&out.Key,
		&out.Value,
		&out.OldValue,
		&out.OldRevision,
		&out.CreateRevision,
		&out.Revision,
		&out.TTL,
		&out.Version,
		&out.Del,
	}
	return s(columns...)
}