diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index eb9673ca60..8a0f8469aa 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -1,6 +1,7 @@ package state import ( + "errors" "fmt" "reflect" "strings" @@ -736,34 +737,31 @@ func (s *Store) EnsureService(idx uint64, node string, svc *structs.NodeService) return nil } +var errCASCompareFailed = errors.New("compare-and-set: comparison failed") + // ensureServiceCASTxn updates a service only if the existing index matches the given index. -// Returns a bool indicating if a write happened and any error. -func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) (bool, error) { +// Returns an error if the write didn't happen and nil if write was successful. +func (s *Store) ensureServiceCASTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { // Retrieve the existing service. _, existing, err := firstWatchCompoundWithTxn(tx, "services", "id", &svc.EnterpriseMeta, node, svc.ID) if err != nil { - return false, fmt.Errorf("failed service lookup: %s", err) + return fmt.Errorf("failed service lookup: %s", err) } // Check if the we should do the set. A ModifyIndex of 0 means that // we are doing a set-if-not-exists. if svc.ModifyIndex == 0 && existing != nil { - return false, nil + return errCASCompareFailed } if svc.ModifyIndex != 0 && existing == nil { - return false, nil + return errCASCompareFailed } e, ok := existing.(*structs.ServiceNode) if ok && svc.ModifyIndex != 0 && svc.ModifyIndex != e.ModifyIndex { - return false, nil + return errCASCompareFailed } - // Perform the update. - if err := s.ensureServiceTxn(tx, idx, node, svc); err != nil { - return false, err - } - - return true, nil + return s.ensureServiceTxn(tx, idx, node, svc) } // ensureServiceTxn is used to upsert a service registration within an diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index 2627eab614..b92ada564d 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -2,6 +2,11 @@ package state import ( "fmt" + "reflect" + "sort" + "strings" + "testing" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" @@ -11,10 +16,6 @@ import ( "github.com/pascaldekloe/goe/verify" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "reflect" - "sort" - "strings" - "testing" ) func makeRandomNodeID(t *testing.T) types.NodeID { @@ -4395,9 +4396,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { // attempt to update with a 0 index tx := s.db.Txn(true) - update, err := s.ensureServiceCASTxn(tx, 3, "node1", &ns) - require.False(t, update) - require.NoError(t, err) + err := s.ensureServiceCASTxn(tx, 3, "node1", &ns) + require.Equal(t, err, errCASCompareFailed) tx.Commit() // ensure no update happened @@ -4411,9 +4411,8 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { ns.ModifyIndex = 99 // attempt to update with a non-matching index tx = s.db.Txn(true) - update, err = s.ensureServiceCASTxn(tx, 4, "node1", &ns) - require.False(t, update) - require.NoError(t, err) + err = s.ensureServiceCASTxn(tx, 4, "node1", &ns) + require.Equal(t, err, errCASCompareFailed) tx.Commit() // ensure no update happened @@ -4427,8 +4426,7 @@ func TestStateStore_ensureServiceCASTxn(t *testing.T) { ns.ModifyIndex = 2 // update with the matching modify index tx = s.db.Txn(true) - update, err = s.ensureServiceCASTxn(tx, 7, "node1", &ns) - require.True(t, update) + err = s.ensureServiceCASTxn(tx, 7, "node1", &ns) require.NoError(t, err) tx.Commit() diff --git a/agent/consul/state/txn.go b/agent/consul/state/txn.go index 06f20681e9..f8c02e25a9 100644 --- a/agent/consul/state/txn.go +++ b/agent/consul/state/txn.go @@ -230,11 +230,13 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp) return newTxnResultFromNodeServiceEntry(entry), err case api.ServiceCAS: - ok, err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service) - // TODO: err != nil case is ignored - if !ok && err == nil { + err := s.ensureServiceCASTxn(tx, idx, op.Node, &op.Service) + switch { + case err == errCASCompareFailed: err := fmt.Errorf("failed to set service %q on node %q, index is stale", op.Service.ID, op.Node) return nil, err + case err != nil: + return nil, err } entry, err := s.getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta)