// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package concurrency

import (
	"context"
	"math"

	v3 "go.etcd.io/etcd/clientv3"
)

// STM is an interface for software transactional memory.
// A transaction body reads and writes keys through it; the accumulated
// changes are sent to the server only when the body returns without error.
type STM interface {
	// Get returns the value for a key and inserts the key in the txn's read set.
	// A key with no value yields the empty string.
	// If Get fails, it aborts the transaction with an error, never returning.
	Get(key ...string) string
	// Put adds a value for a key to the write set.
	Put(key, val string, opts ...v3.OpOption)
	// Rev returns the modification revision of a key, reading it into the
	// read set if necessary. It returns 0 when the key has no value.
	Rev(key string) int64
	// Del adds a deletion of a key to the write set.
	Del(key string)

	// commit attempts to apply the txn's changes to the server.
	// It returns nil when the txn's conflict checks failed.
	commit() *v3.TxnResponse
	// reset discards the read and write sets ahead of a new attempt.
	reset()
}

// Isolation is an enumeration of transactional isolation levels which
// describes how transactions should interfere and conflict.
// The zero value is SerializableSnapshot.
type Isolation int

const (
	// SerializableSnapshot provides serializable isolation and also checks
	// for write conflicts.
	SerializableSnapshot Isolation = iota
	// Serializable reads within the same transaction attempt return data
	// from the revision of the first read.
	Serializable
	// RepeatableReads reads within the same transaction attempt always
	// return the same data.
	RepeatableReads
	// ReadCommitted reads keys from any committed revision.
	ReadCommitted
)

// stmError safely passes STM errors through panic to the STM error channel.
// Wrapping the error in a dedicated type lets runSTM tell an aborted
// transaction apart from a panic raised by the client's apply function.
type stmError struct{ err error }

// stmOptions collects the settings applied by stmOption functions:
// iso selects the isolation level, ctx is the context used for the
// transaction's requests, and prefetch lists keys to read at the start of
// every attempt.
type stmOptions struct {
	iso      Isolation
	ctx      context.Context
	prefetch []string
}

// stmOption mutates an stmOptions value; NewSTM applies the options in the
// order they are given.
type stmOption func(*stmOptions)

// WithIsolation returns an option that sets the isolation level the
// transaction runs under.
func WithIsolation(lvl Isolation) stmOption {
	return func(o *stmOptions) {
		o.iso = lvl
	}
}

// WithAbortContext returns an option that sets the context used for
// permanently aborting the transaction.
func WithAbortContext(ctx context.Context) stmOption {
	return func(o *stmOptions) {
		o.ctx = ctx
	}
}

// WithPrefetch hints that the given keys should be fetched before apply runs.
// When a transaction always reads a known set of keys, prefetching them in
// one request avoids paying a round trip for each individual Get(). Keys from
// repeated WithPrefetch options accumulate.
func WithPrefetch(keys ...string) stmOption {
	return func(o *stmOptions) {
		o.prefetch = append(o.prefetch, keys...)
	}
}

// NewSTM runs apply as a software transaction against c and returns the
// response of the committed txn, or the error that ended the transaction.
// Serializable snapshot isolation is used unless an option overrides it, and
// the client's context is used unless WithAbortContext is given.
func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
	opts := &stmOptions{ctx: c.Ctx()}
	for i := range so {
		so[i](opts)
	}
	if len(opts.prefetch) == 0 {
		return runSTM(mkSTM(c, opts), apply)
	}
	// read the prefetch keys in a single Get ahead of the caller's apply on
	// every attempt
	userApply := apply
	prefetching := func(s STM) error {
		s.Get(opts.prefetch...)
		return userApply(s)
	}
	return runSTM(mkSTM(c, opts), prefetching)
}

// mkSTM builds the STM implementation matching opts.iso and installs the
// conflict checks that level evaluates at commit time. It panics on an
// unknown isolation level.
func mkSTM(c *v3.Client, opts *stmOptions) STM {
	// newSnapshot builds the STM used by both serializable levels
	newSnapshot := func() *stmSerializable {
		return &stmSerializable{
			stm:      stm{client: c, ctx: opts.ctx},
			prefetch: make(map[string]*v3.GetResponse),
		}
	}
	// newBasic builds the STM used by the weaker levels; its gets carry the
	// WithSerializable option
	newBasic := func() *stm {
		return &stm{
			client:  c,
			ctx:     opts.ctx,
			getOpts: []v3.OpOption{v3.WithSerializable()},
		}
	}

	switch opts.iso {
	case SerializableSnapshot:
		ss := newSnapshot()
		// reads must be unchanged and written keys must not have been
		// modified after the earliest read revision
		ss.conflicts = func() []v3.Cmp {
			readCmps := ss.rset.cmps()
			writeCmps := ss.wset.cmps(ss.rset.first() + 1)
			return append(readCmps, writeCmps...)
		}
		return ss
	case Serializable:
		ss := newSnapshot()
		ss.conflicts = func() []v3.Cmp { return ss.rset.cmps() }
		return ss
	case RepeatableReads:
		rr := newBasic()
		rr.conflicts = func() []v3.Cmp { return rr.rset.cmps() }
		return rr
	case ReadCommitted:
		rc := newBasic()
		// no conflict checks at this level
		rc.conflicts = func() []v3.Cmp { return nil }
		return rc
	default:
		panic("unsupported stm")
	}
}

// stmResponse carries the outcome of a transaction from the goroutine that
// runs it back to runSTM's caller: the committed txn response, or the error
// that stopped the transaction.
type stmResponse struct {
	resp *v3.TxnResponse
	err  error
}

// runSTM repeatedly resets s, runs apply against it and tries to commit,
// stopping when apply returns an error or a commit succeeds. The loop runs in
// its own goroutine so that an stmError panic can be recovered and handed
// back as an ordinary error; any other panic value is raised again.
func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
	// buffered so the goroutine can deliver its single result without waiting
	resc := make(chan stmResponse, 1)
	go func() {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			stmErr, isSTM := r.(stmError)
			if !isSTM {
				// client apply panicked
				panic(r)
			}
			resc <- stmResponse{err: stmErr.err}
		}()
		var res stmResponse
		// a nil resp after commit means the txn conflicted; try again
		for res.resp == nil {
			s.reset()
			if res.err = apply(s); res.err != nil {
				break
			}
			res.resp = s.commit()
		}
		resc <- res
	}()
	res := <-resc
	return res.resp, res.err
}

// stm implements repeatable-read software transactional memory over etcd.
// It is also used for ReadCommitted (with no conflict checks) and is embedded
// by stmSerializable.
type stm struct {
	// client issues the txn requests; ctx is the context passed to them
	client *v3.Client
	ctx    context.Context
	// rset holds read key values and revisions
	rset readSet
	// wset holds overwritten keys and their values
	wset writeSet
	// getOpts are the opts used for gets
	getOpts []v3.OpOption
	// conflicts computes the current conflicts on the txn
	conflicts func() []v3.Cmp
}

// stmPut is a pending write to one key. val is what a Get of the key returns
// while the write is pending (empty for a delete); op is the operation sent
// to the server on commit.
type stmPut struct {
	val string
	op  v3.Op
}

// readSet maps each key read during an attempt to the response it was read from.
type readSet map[string]*v3.GetResponse

// add records the range response at position i of txnresp under keys[i].
// keys must list the keys in the same order as the gets were issued.
func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
	for idx := range txnresp.Responses {
		rangeResp := txnresp.Responses[idx].GetResponseRange()
		rs[keys[idx]] = (*v3.GetResponse)(rangeResp)
	}
}

// first returns the lowest header revision among the responses in the read
// set, i.e. the store revision of the earliest fetch. An empty read set
// yields math.MaxInt64 - 1, which leaves room for callers to add one.
func (rs readSet) first() int64 {
	lowest := int64(math.MaxInt64 - 1)
	for _, r := range rs {
		if r.Header.Revision < lowest {
			lowest = r.Header.Revision
		}
	}
	return lowest
}

// cmps builds one comparison per key in the read set; together they make the
// txn fail if any read key was modified after it was read.
func (rs readSet) cmps() []v3.Cmp {
	guards := make([]v3.Cmp, len(rs))
	i := 0
	for key, resp := range rs {
		guards[i] = isKeyCurrent(key, resp)
		i++
	}
	return guards
}

// writeSet maps each key written during an attempt to its pending write.
type writeSet map[string]stmPut

// get returns a copy of the pending write for the first of keys, in argument
// order, that has one, or nil if none of them has been written.
func (ws writeSet) get(keys ...string) *stmPut {
	for i := range keys {
		pending, found := ws[keys[i]]
		if !found {
			continue
		}
		return &pending
	}
	return nil
}

// cmps builds one comparison per written key requiring the key's
// modification revision to be below rev, i.e. that no write to it has
// happened at or past rev.
func (ws writeSet) cmps(rev int64) []v3.Cmp {
	guards := make([]v3.Cmp, len(ws))
	i := 0
	for key := range ws {
		guards[i] = v3.Compare(v3.ModRevision(key), "<", rev)
		i++
	}
	return guards
}

// puts collects the operation of every pending write (puts and deletes
// alike) for submission in the commit txn.
func (ws writeSet) puts() []v3.Op {
	ops := make([]v3.Op, len(ws))
	i := 0
	for _, pending := range ws {
		ops[i] = pending.op
		i++
	}
	return ops
}

// Get returns the pending value if any of keys has been written in this
// attempt; otherwise it reads the keys through the read set and returns the
// value of the response fetch hands back.
func (s *stm) Get(keys ...string) string {
	pending := s.wset.get(keys...)
	if pending == nil {
		return respToValue(s.fetch(keys...))
	}
	return pending.val
}

// Put records a pending write of val to key, replacing any earlier pending
// write or delete of the same key. opts are passed on to the put operation.
func (s *stm) Put(key, val string, opts ...v3.OpOption) {
	s.wset[key] = stmPut{
		val: val,
		op:  v3.OpPut(key, val, opts...),
	}
}

// Del records a pending delete of key; until commit, Get of the key returns
// the empty string.
func (s *stm) Del(key string) {
	s.wset[key] = stmPut{val: "", op: v3.OpDelete(key)}
}

// Rev returns the modification revision of key as seen through the read set,
// fetching the key if it has not been read yet. It returns 0 when the
// response holds no key-value for key.
func (s *stm) Rev(key string) int64 {
	resp := s.fetch(key)
	if resp == nil || len(resp.Kvs) == 0 {
		return 0
	}
	return resp.Kvs[0].ModRevision
}

// commit submits the pending writes in a txn guarded by the conflict checks.
// It returns the response when the guards held and nil when they did not; a
// request error aborts the transaction by panicking with an stmError.
func (s *stm) commit() *v3.TxnResponse {
	txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
	txnresp, err := txn.Commit()
	if err != nil {
		panic(stmError{err})
	}
	if !txnresp.Succeeded {
		return nil
	}
	return txnresp
}

// fetch reads keys through the read set. With no keys it returns nil. If any
// of keys is already in the read set, the cached response of the first such
// key (in argument order) is returned and nothing is requested. Otherwise all
// keys are read in one txn using getOpts, stored in the read set, and the
// response for keys[0] is returned. A request error aborts the transaction by
// panicking with an stmError.
//
// NOTE(review): on a cache hit for a later key, the earlier keys are neither
// fetched nor added to the read set — confirm callers only rely on this for
// single keys or all-or-nothing batches.
func (s *stm) fetch(keys ...string) *v3.GetResponse {
	if len(keys) == 0 {
		return nil
	}
	gets := make([]v3.Op, 0, len(keys))
	for _, key := range keys {
		if cached, hit := s.rset[key]; hit {
			return cached
		}
		gets = append(gets, v3.OpGet(key, s.getOpts...))
	}
	txnresp, err := s.client.Txn(s.ctx).Then(gets...).Commit()
	if err != nil {
		panic(stmError{err})
	}
	s.rset.add(keys, txnresp)
	head := txnresp.Responses[0].GetResponseRange()
	return (*v3.GetResponse)(head)
}

// reset starts a new attempt with empty read and write sets.
func (s *stm) reset() {
	s.rset = make(readSet)
	s.wset = make(writeSet)
}

// stmSerializable extends stm so that, after the first read of an attempt,
// later gets are issued at that first read's revision.
// prefetch holds responses gathered by a failed commit; Get moves them into
// the read set on the next attempt instead of fetching the keys again.
type stmSerializable struct {
	stm
	prefetch map[string]*v3.GetResponse
}

// Get returns the pending value if any of keys has been written in this
// attempt. Otherwise it moves any prefetched responses for keys into the read
// set, reads the keys through the read set, and returns the resulting value
// (the empty string when the key has no value or no keys were given).
//
// The first read of an attempt fixes the revision used by all later gets of
// that attempt. A request error aborts the transaction by panicking with an
// stmError.
func (s *stmSerializable) Get(keys ...string) string {
	if wv := s.wset.get(keys...); wv != nil {
		return wv.val
	}
	// must be sampled before prefetched responses are moved into the read set
	firstRead := len(s.rset) == 0
	for _, key := range keys {
		if resp, ok := s.prefetch[key]; ok {
			delete(s.prefetch, key)
			s.rset[key] = resp
		}
	}
	resp := s.stm.fetch(keys...)
	// fetch returns nil when called without keys; there is then no response
	// to take a base revision from, so leave getOpts for the next real read
	if firstRead && resp != nil {
		// txn's base revision is defined by the first read
		s.getOpts = []v3.OpOption{
			v3.WithRev(resp.Header.Revision),
			v3.WithSerializable(),
		}
	}
	return respToValue(resp)
}

// Rev returns the modification revision of key. It goes through Get first so
// that prefetched data is consumed and the attempt's base revision is set
// before the revision lookup.
func (s *stmSerializable) Rev(key string) int64 {
	_ = s.Get(key)
	rev := s.stm.Rev(key)
	return rev
}

// gets returns the keys of the read set together with a get operation for
// each, in matching order.
func (s *stmSerializable) gets() ([]string, []v3.Op) {
	n := len(s.rset)
	keys := make([]string, n)
	ops := make([]v3.Op, n)
	i := 0
	for key := range s.rset {
		keys[i] = key
		ops[i] = v3.OpGet(key)
		i++
	}
	return keys, ops
}

// commit submits the pending writes in a txn guarded by the conflict checks
// and returns the response when the guards held. When they did not, it
// returns nil after storing the re-read keys as prefetch data and clearing
// the pinned get options. A request error aborts the transaction by
// panicking with an stmError.
func (s *stmSerializable) commit() *v3.TxnResponse {
	keys, getops := s.gets()
	// use Else to prefetch keys in case of conflict to save a round trip
	txn := s.client.Txn(s.ctx).
		If(s.conflicts()...).
		Then(s.wset.puts()...).
		Else(getops...)
	txnresp, err := txn.Commit()
	if err != nil {
		panic(stmError{err})
	}
	if !txnresp.Succeeded {
		// load prefetch with Else data
		s.rset.add(keys, txnresp)
		s.prefetch = s.rset
		s.getOpts = nil
		return nil
	}
	return txnresp
}

// isKeyCurrent builds a comparison that holds while k is unchanged since r
// was read: its modification revision must equal the one in r, or 0 when r
// holds no key-value for k.
func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
	if len(r.Kvs) == 0 {
		return v3.Compare(v3.ModRevision(k), "=", 0)
	}
	return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
}

// respToValue returns the value of the first key-value in resp as a string,
// or the empty string when resp is nil or holds no key-values.
func respToValue(resp *v3.GetResponse) string {
	if resp != nil && len(resp.Kvs) > 0 {
		return string(resp.Kvs[0].Value)
	}
	return ""
}

// NewSTMRepeatable runs apply as a transaction with RepeatableReads
// isolation, using ctx as the abort context.
//
// Deprecated: use NewSTM with WithIsolation(RepeatableReads) and
// WithAbortContext instead.
func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
	opts := []stmOption{WithAbortContext(ctx), WithIsolation(RepeatableReads)}
	return NewSTM(c, apply, opts...)
}

// NewSTMSerializable runs apply as a transaction with Serializable
// isolation, using ctx as the abort context.
//
// Deprecated: use NewSTM with WithIsolation(Serializable) and
// WithAbortContext instead.
func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
	opts := []stmOption{WithAbortContext(ctx), WithIsolation(Serializable)}
	return NewSTM(c, apply, opts...)
}

// NewSTMReadCommitted runs apply as a transaction with ReadCommitted
// isolation, using ctx as the abort context.
//
// Deprecated: use NewSTM with WithIsolation(ReadCommitted) and
// WithAbortContext instead.
func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
	opts := []stmOption{WithAbortContext(ctx), WithIsolation(ReadCommitted)}
	return NewSTM(c, apply, opts...)
}