state: use an error to indicate compare failed

Errors are values. We can use the error value to identify the 'comparison failed' case which makes the function easier to use and should make it harder to miss handle the error case
pull/7933/head
Daniel Nephin 2020-05-20 12:43:33 -04:00
parent 3dd8b66aa2
commit 3f607d9ef0
3 changed files with 24 additions and 26 deletions

View File

@ -1,6 +1,7 @@
package state
import (
"errors"
"fmt"
"reflect"
"strings"
@ -736,34 +737,31 @@ func (s *Store) EnsureService(idx uint64, node string, svc *structs.NodeService)
return nil
}
var errCASCompareFailed = errors.New("compare-and-set: comparison failed")
// ensureServiceCASTxn updates a service only if the existing index matches the given index.
// Returns a bool indicating if a write happened and any error.
func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) (bool, error) {
func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error {
// Retrieve the existing service.
_, existing, err := firstWatchCompoundWithTxn(tx, "services", "id", &svc.EnterpriseMeta, node, svc.ID)
if err != nil {
return false, fmt.Errorf("failed service lookup: %s", err)
return fmt.Errorf("failed service lookup: %s", err)
}
// Check if the we should do the set. A ModifyIndex of 0 means that
// we are doing a set-if-not-exists.
if svc.ModifyIndex == 0 && existing != nil {
return false, nil
return errCASCompareFailed
}
if svc.ModifyIndex != 0 && existing == nil {
return false, nil
return errCASCompareFailed
}
e, ok := existing.(*structs.ServiceNode)
if ok && svc.ModifyIndex != 0 && svc.ModifyIndex != e.ModifyIndex {
return false, nil
return errCASCompareFailed
}
// Perform the update.
if err := s.ensureServiceTxn(tx, idx, node, svc); err != nil {
return false, err
}
return true, nil
return s.ensureServiceTxn(tx, idx, node, svc)
}
// ensureServiceTxn is used to upsert a service registration within an

View File

@ -2,6 +2,11 @@ package state
import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
@ -11,10 +16,6 @@ import (
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"reflect"
"sort"
"strings"
"testing"
)
func makeRandomNodeID(t *testing.T) types.NodeID {
@ -4395,9 +4396,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) {
// attempt to update with a 0 index
tx := s.db.Txn(true)
update, err := s.ensureServiceCASTxn(tx, 3, "node1", &ns)
require.False(t, update)
require.NoError(t, err)
err := s.ensureServiceCASTxn(tx, 3, "node1", &ns)
require.Equal(t, err, errCASCompareFailed)
tx.Commit()
// ensure no update happened
@ -4411,9 +4411,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) {
ns.ModifyIndex = 99
// attempt to update with a non-matching index
tx = s.db.Txn(true)
update, err = s.ensureServiceCASTxn(tx, 4, "node1", &ns)
require.False(t, update)
require.NoError(t, err)
err = s.ensureServiceCASTxn(tx, 4, "node1", &ns)
require.Equal(t, err, errCASCompareFailed)
tx.Commit()
// ensure no update happened
@ -4427,8 +4426,7 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) {
ns.ModifyIndex = 2
// update with the matching modify index
tx = s.db.Txn(true)
update, err = s.ensureServiceCASTxn(tx, 7, "node1", &ns)
require.True(t, update)
err = s.ensureServiceCASTxn(tx, 7, "node1", &ns)
require.NoError(t, err)
tx.Commit()

View File

@ -230,11 +230,13 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp)
return newTxnResultFromNodeServiceEntry(entry), err
case api.ServiceCAS:
ok, err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service)
// TODO: err != nil case is ignored
if !ok && err == nil {
err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service)
switch {
case err == errCASCompareFailed:
err := fmt.Errorf("failed to set service %q on node %q, index is stale", op.Service.ID, op.Node)
return nil, err
case err != nil:
return nil, err
}
entry, err := s.getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta)