mirror of https://github.com/hashicorp/consul
Merge pull request #8802 from hashicorp/dnephin/extract-lib-retry
lib/retry - extract a new package from lib/retry.gopull/8821/head
commit
5a5fd4f0b1
|
@ -7,13 +7,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/retry"
|
||||
"github.com/hashicorp/consul/logging"
|
||||
"github.com/hashicorp/consul/proto/pbautoconf"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
// AutoConfig is all the state necessary for being able to parse a configuration
|
||||
|
@ -24,7 +25,7 @@ type AutoConfig struct {
|
|||
acConfig Config
|
||||
logger hclog.Logger
|
||||
cache Cache
|
||||
waiter *lib.RetryWaiter
|
||||
waiter *retry.Waiter
|
||||
config *config.RuntimeConfig
|
||||
autoConfigResponse *pbautoconf.AutoConfigResponse
|
||||
autoConfigSource config.Source
|
||||
|
@ -84,7 +85,11 @@ func New(config Config) (*AutoConfig, error) {
|
|||
}
|
||||
|
||||
if config.Waiter == nil {
|
||||
config.Waiter = lib.NewRetryWaiter(1, 0, 10*time.Minute, lib.NewJitterRandomStagger(25))
|
||||
config.Waiter = &retry.Waiter{
|
||||
MinFailures: 1,
|
||||
MaxWait: 10 * time.Minute,
|
||||
Jitter: retry.NewJitter(25),
|
||||
}
|
||||
}
|
||||
|
||||
return &AutoConfig{
|
||||
|
@ -305,23 +310,21 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// this resets the failures so that we will perform immediate request
|
||||
wait := ac.acConfig.Waiter.Success()
|
||||
ac.acConfig.Waiter.Reset()
|
||||
for {
|
||||
select {
|
||||
case <-wait:
|
||||
if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil {
|
||||
resp, err := ac.getInitialConfigurationOnce(ctx, csr, key)
|
||||
switch {
|
||||
case err == nil && resp != nil:
|
||||
return resp, nil
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
ac.logger.Error(err.Error())
|
||||
} else {
|
||||
default:
|
||||
ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
|
||||
}
|
||||
|
||||
wait = ac.acConfig.Waiter.Failed()
|
||||
case <-ctx.Done():
|
||||
ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
|
||||
ac.logger.Info("interrupted during initial auto configuration", "err", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
|
@ -18,13 +21,11 @@ import (
|
|||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/retry"
|
||||
"github.com/hashicorp/consul/proto/pbautoconf"
|
||||
"github.com/hashicorp/consul/proto/pbconfig"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
testretry "github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
)
|
||||
|
||||
type configLoader struct {
|
||||
|
@ -412,7 +413,7 @@ func TestInitialConfiguration_retries(t *testing.T) {
|
|||
mcfg.Config.Loader = loader.Load
|
||||
|
||||
// reduce the retry wait times to make this test run faster
|
||||
mcfg.Config.Waiter = lib.NewRetryWaiter(2, 0, 1*time.Millisecond, nil)
|
||||
mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
|
||||
|
||||
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", "secret")
|
||||
|
||||
|
@ -927,7 +928,7 @@ func TestRootsUpdate(t *testing.T) {
|
|||
// however there is no deterministic way to know once its been written outside of maybe a filesystem
|
||||
// event notifier. That seems a little heavy handed just for this and especially to do in any sort
|
||||
// of cross platform way.
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
testretry.Run(t, func(r *testretry.R) {
|
||||
resp, err := testAC.ac.readPersistedAutoConfig()
|
||||
require.NoError(r, err)
|
||||
require.Equal(r, secondRoots.ActiveRootID, resp.CARoots.GetActiveRootID())
|
||||
|
@ -972,7 +973,7 @@ func TestCertUpdate(t *testing.T) {
|
|||
// persisting these to disk happens after all the things we would wait for in assertCertUpdated
|
||||
// will have fired. There is no deterministic way to know once its been written so we wrap
|
||||
// this in a retry.
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
testretry.Run(t, func(r *testretry.R) {
|
||||
resp, err := testAC.ac.readPersistedAutoConfig()
|
||||
require.NoError(r, err)
|
||||
|
||||
|
@ -1099,7 +1100,7 @@ func TestFallback(t *testing.T) {
|
|||
|
||||
// persisting these to disk happens after the RPC we waited on above will have fired
|
||||
// There is no deterministic way to know once its been written so we wrap this in a retry.
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
testretry.Run(t, func(r *testretry.R) {
|
||||
resp, err := testAC.ac.readPersistedAutoConfig()
|
||||
require.NoError(r, err)
|
||||
|
||||
|
|
|
@ -16,23 +16,21 @@ func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.Sig
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// this resets the failures so that we will perform immediate request
|
||||
wait := ac.acConfig.Waiter.Success()
|
||||
ac.acConfig.Waiter.Reset()
|
||||
for {
|
||||
select {
|
||||
case <-wait:
|
||||
if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil {
|
||||
resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key)
|
||||
switch {
|
||||
case err == nil && resp != nil:
|
||||
return resp, nil
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
ac.logger.Error(err.Error())
|
||||
} else {
|
||||
default:
|
||||
ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
|
||||
}
|
||||
|
||||
wait = ac.acConfig.Waiter.Failed()
|
||||
case <-ctx.Done():
|
||||
ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
|
||||
ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,16 +11,17 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/retry"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutoEncrypt_generateCSR(t *testing.T) {
|
||||
|
@ -247,7 +248,7 @@ func TestAutoEncrypt_InitialCerts(t *testing.T) {
|
|||
resp.VerifyServerHostname = true
|
||||
})
|
||||
|
||||
mcfg.Config.Waiter = lib.NewRetryWaiter(2, 0, 1*time.Millisecond, nil)
|
||||
mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
|
||||
|
||||
ac := AutoConfig{
|
||||
config: &config.RuntimeConfig{
|
||||
|
|
|
@ -5,12 +5,13 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/consul/lib/retry"
|
||||
)
|
||||
|
||||
// DirectRPC is the interface that needs to be satisifed for AutoConfig to be able to perform
|
||||
|
@ -67,17 +68,17 @@ type Config struct {
|
|||
// known servers during fallback operations.
|
||||
ServerProvider ServerProvider
|
||||
|
||||
// Waiter is a RetryWaiter to be used during retrieval of the
|
||||
// initial configuration. When a round of requests fails we will
|
||||
// Waiter is used during retrieval of the initial configuration.
|
||||
// When around of requests fails we will
|
||||
// wait and eventually make another round of requests (1 round
|
||||
// is trying the RPC once against each configured server addr). The
|
||||
// waiting implements some backoff to prevent from retrying these RPCs
|
||||
// to often. This field is not required and if left unset a waiter will
|
||||
// too often. This field is not required and if left unset a waiter will
|
||||
// be used that has a max wait duration of 10 minutes and a randomized
|
||||
// jitter of 25% of the wait time. Setting this is mainly useful for
|
||||
// testing purposes to allow testing out the retrying functionality without
|
||||
// having the test take minutes/hours to complete.
|
||||
Waiter *lib.RetryWaiter
|
||||
Waiter *retry.Waiter
|
||||
|
||||
// Loader merges source with the existing FileSources and returns the complete
|
||||
// RuntimeConfig.
|
||||
|
|
|
@ -7,18 +7,17 @@ import (
|
|||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/logging"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/hashicorp/consul/lib/retry"
|
||||
"github.com/hashicorp/consul/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// replicationMaxRetryWait is the maximum number of seconds to wait between
|
||||
// failed blocking queries when backing off.
|
||||
replicationDefaultMaxRetryWait = 120 * time.Second
|
||||
|
||||
replicationDefaultRate = 1
|
||||
)
|
||||
|
||||
type ReplicatorDelegate interface {
|
||||
|
@ -35,7 +34,7 @@ type ReplicatorConfig struct {
|
|||
// The number of replication rounds that can be done in a burst
|
||||
Burst int
|
||||
// Minimum number of RPC failures to ignore before backing off
|
||||
MinFailures int
|
||||
MinFailures uint
|
||||
// Maximum wait time between failing RPCs
|
||||
MaxRetryWait time.Duration
|
||||
// Where to send our logs
|
||||
|
@ -46,7 +45,7 @@ type ReplicatorConfig struct {
|
|||
|
||||
type Replicator struct {
|
||||
limiter *rate.Limiter
|
||||
waiter *lib.RetryWaiter
|
||||
waiter *retry.Waiter
|
||||
delegate ReplicatorDelegate
|
||||
logger hclog.Logger
|
||||
lastRemoteIndex uint64
|
||||
|
@ -70,12 +69,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) {
|
|||
if maxWait == 0 {
|
||||
maxWait = replicationDefaultMaxRetryWait
|
||||
}
|
||||
|
||||
minFailures := config.MinFailures
|
||||
if minFailures < 0 {
|
||||
minFailures = 0
|
||||
waiter := &retry.Waiter{
|
||||
MinFailures: config.MinFailures,
|
||||
MaxWait: maxWait,
|
||||
Jitter: retry.NewJitter(10),
|
||||
}
|
||||
waiter := lib.NewRetryWaiter(minFailures, 0*time.Second, maxWait, lib.NewJitterRandomStagger(10))
|
||||
return &Replicator{
|
||||
limiter: limiter,
|
||||
waiter: waiter,
|
||||
|
@ -99,10 +97,8 @@ func (r *Replicator) Run(ctx context.Context) error {
|
|||
// Perform a single round of replication
|
||||
index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex), r.logger)
|
||||
if exit {
|
||||
// the replication function told us to exit
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// reset the lastRemoteIndex when there is an RPC failure. This should cause a full sync to be done during
|
||||
// the next round of replication
|
||||
|
@ -111,18 +107,16 @@ func (r *Replicator) Run(ctx context.Context) error {
|
|||
if r.suppressErrorLog != nil && !r.suppressErrorLog(err) {
|
||||
r.logger.Warn("replication error (will retry if still leader)", "error", err)
|
||||
}
|
||||
} else {
|
||||
atomic.StoreUint64(&r.lastRemoteIndex, index)
|
||||
r.logger.Debug("replication completed through remote index", "index", index)
|
||||
|
||||
if err := r.waiter.Wait(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
// wait some amount of time to prevent churning through many replication rounds while replication is failing
|
||||
case <-r.waiter.WaitIfErr(err):
|
||||
// do nothing
|
||||
}
|
||||
atomic.StoreUint64(&r.lastRemoteIndex, index)
|
||||
r.logger.Debug("replication completed through remote index", "index", index)
|
||||
r.waiter.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
156
lib/retry.go
156
lib/retry.go
|
@ -1,156 +0,0 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMinFailures = 0
|
||||
defaultMaxWait = 2 * time.Minute
|
||||
)
|
||||
|
||||
// Interface used for offloading jitter calculations from the RetryWaiter
|
||||
type Jitter interface {
|
||||
AddJitter(baseTime time.Duration) time.Duration
|
||||
}
|
||||
|
||||
// Calculates a random jitter between 0 and up to a specific percentage of the baseTime
|
||||
type JitterRandomStagger struct {
|
||||
// int64 because we are going to be doing math against an int64 to represent nanoseconds
|
||||
percent int64
|
||||
}
|
||||
|
||||
// Creates a new JitterRandomStagger
|
||||
func NewJitterRandomStagger(percent int) *JitterRandomStagger {
|
||||
if percent < 0 {
|
||||
percent = 0
|
||||
}
|
||||
|
||||
return &JitterRandomStagger{
|
||||
percent: int64(percent),
|
||||
}
|
||||
}
|
||||
|
||||
// Implments the Jitter interface
|
||||
func (j *JitterRandomStagger) AddJitter(baseTime time.Duration) time.Duration {
|
||||
if j.percent == 0 {
|
||||
return baseTime
|
||||
}
|
||||
|
||||
// time.Duration is actually a type alias for int64 which is why casting
|
||||
// to the duration type and then dividing works
|
||||
return baseTime + RandomStagger((baseTime*time.Duration(j.percent))/100)
|
||||
}
|
||||
|
||||
// RetryWaiter will record failed and successful operations and provide
|
||||
// a channel to wait on before a failed operation can be retried.
|
||||
type RetryWaiter struct {
|
||||
minFailures uint
|
||||
minWait time.Duration
|
||||
maxWait time.Duration
|
||||
jitter Jitter
|
||||
failures uint
|
||||
}
|
||||
|
||||
// Creates a new RetryWaiter
|
||||
func NewRetryWaiter(minFailures int, minWait, maxWait time.Duration, jitter Jitter) *RetryWaiter {
|
||||
if minFailures < 0 {
|
||||
minFailures = defaultMinFailures
|
||||
}
|
||||
|
||||
if maxWait <= 0 {
|
||||
maxWait = defaultMaxWait
|
||||
}
|
||||
|
||||
if minWait <= 0 {
|
||||
minWait = 0 * time.Nanosecond
|
||||
}
|
||||
|
||||
return &RetryWaiter{
|
||||
minFailures: uint(minFailures),
|
||||
minWait: minWait,
|
||||
maxWait: maxWait,
|
||||
failures: 0,
|
||||
jitter: jitter,
|
||||
}
|
||||
}
|
||||
|
||||
// calculates the necessary wait time before the
|
||||
// next operation should be allowed.
|
||||
func (rw *RetryWaiter) calculateWait() time.Duration {
|
||||
waitTime := rw.minWait
|
||||
if rw.failures > rw.minFailures {
|
||||
shift := rw.failures - rw.minFailures - 1
|
||||
waitTime = rw.maxWait
|
||||
if shift < 31 {
|
||||
waitTime = (1 << shift) * time.Second
|
||||
}
|
||||
if waitTime > rw.maxWait {
|
||||
waitTime = rw.maxWait
|
||||
}
|
||||
|
||||
if rw.jitter != nil {
|
||||
waitTime = rw.jitter.AddJitter(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
if waitTime < rw.minWait {
|
||||
waitTime = rw.minWait
|
||||
}
|
||||
|
||||
return waitTime
|
||||
}
|
||||
|
||||
// calculates the waitTime and returns a chan
|
||||
// that will become selectable once that amount
|
||||
// of time has elapsed.
|
||||
func (rw *RetryWaiter) wait() <-chan struct{} {
|
||||
waitTime := rw.calculateWait()
|
||||
ch := make(chan struct{})
|
||||
if waitTime > 0 {
|
||||
time.AfterFunc(waitTime, func() { close(ch) })
|
||||
} else {
|
||||
// if there should be 0 wait time then we ensure
|
||||
// that the chan will be immediately selectable
|
||||
close(ch)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// Marks that an operation is successful which resets the failure count.
|
||||
// The chan that is returned will be immediately selectable
|
||||
func (rw *RetryWaiter) Success() <-chan struct{} {
|
||||
rw.Reset()
|
||||
return rw.wait()
|
||||
}
|
||||
|
||||
// Marks that an operation failed. The chan returned will be selectable
|
||||
// once the calculated retry wait amount of time has elapsed
|
||||
func (rw *RetryWaiter) Failed() <-chan struct{} {
|
||||
rw.failures += 1
|
||||
ch := rw.wait()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Resets the internal failure counter
|
||||
func (rw *RetryWaiter) Reset() {
|
||||
rw.failures = 0
|
||||
}
|
||||
|
||||
// WaitIf is a convenice method to record whether the last
|
||||
// operation was a success or failure and return a chan that
|
||||
// will be selectablw when the next operation can be done.
|
||||
func (rw *RetryWaiter) WaitIf(failure bool) <-chan struct{} {
|
||||
if failure {
|
||||
return rw.Failed()
|
||||
}
|
||||
return rw.Success()
|
||||
}
|
||||
|
||||
// WaitIfErr is a convenience method to record whether the last
|
||||
// operation was a success or failure based on whether the err
|
||||
// is nil and then return a chan that will be selectable when
|
||||
// the next operation can be done.
|
||||
func (rw *RetryWaiter) WaitIfErr(err error) <-chan struct{} {
|
||||
return rw.WaitIf(err != nil)
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMinFailures = 0
|
||||
defaultMaxWait = 2 * time.Minute
|
||||
)
|
||||
|
||||
// Jitter should return a new wait duration optionally with some time added or
|
||||
// removed to create some randomness in wait time.
|
||||
type Jitter func(baseTime time.Duration) time.Duration
|
||||
|
||||
// NewJitter returns a new random Jitter that is up to percent longer than the
|
||||
// original wait time.
|
||||
func NewJitter(percent int64) Jitter {
|
||||
if percent < 0 {
|
||||
percent = 0
|
||||
}
|
||||
|
||||
return func(baseTime time.Duration) time.Duration {
|
||||
if percent == 0 {
|
||||
return baseTime
|
||||
}
|
||||
max := (int64(baseTime) * percent) / 100
|
||||
if max < 0 { // overflow
|
||||
return baseTime
|
||||
}
|
||||
return baseTime + time.Duration(rand.Int63n(max))
|
||||
}
|
||||
}
|
||||
|
||||
// Waiter records the number of failures and performs exponential backoff when
|
||||
// when there are consecutive failures.
|
||||
type Waiter struct {
|
||||
// MinFailures before exponential backoff starts. Any failures before
|
||||
// MinFailures is reached will wait MinWait time.
|
||||
MinFailures uint
|
||||
// MinWait time. Returned after the first failure.
|
||||
MinWait time.Duration
|
||||
// MaxWait time.
|
||||
MaxWait time.Duration
|
||||
// Jitter to add to each wait time.
|
||||
Jitter Jitter
|
||||
// Factor is the multiplier to use when calculating the delay. Defaults to
|
||||
// 1 second.
|
||||
Factor time.Duration
|
||||
failures uint
|
||||
}
|
||||
|
||||
// delay calculates the time to wait based on the number of failures
|
||||
func (w *Waiter) delay() time.Duration {
|
||||
if w.failures <= w.MinFailures {
|
||||
return w.MinWait
|
||||
}
|
||||
factor := w.Factor
|
||||
if factor == 0 {
|
||||
factor = time.Second
|
||||
}
|
||||
|
||||
shift := w.failures - w.MinFailures - 1
|
||||
waitTime := w.MaxWait
|
||||
if shift < 31 {
|
||||
waitTime = (1 << shift) * factor
|
||||
}
|
||||
if w.Jitter != nil {
|
||||
waitTime = w.Jitter(waitTime)
|
||||
}
|
||||
if w.MaxWait != 0 && waitTime > w.MaxWait {
|
||||
return w.MaxWait
|
||||
}
|
||||
if waitTime < w.MinWait {
|
||||
return w.MinWait
|
||||
}
|
||||
return waitTime
|
||||
}
|
||||
|
||||
// Reset the failure count to 0.
|
||||
func (w *Waiter) Reset() {
|
||||
w.failures = 0
|
||||
}
|
||||
|
||||
// Failures returns the count of consecutive failures.
|
||||
func (w *Waiter) Failures() int {
|
||||
return int(w.failures)
|
||||
}
|
||||
|
||||
// Wait increase the number of failures by one, and then blocks until the context
|
||||
// is cancelled, or until the wait time is reached.
|
||||
// The wait time increases exponentially as the number of failures increases.
|
||||
// Wait will return ctx.Err() if the context is cancelled.
|
||||
func (w *Waiter) Wait(ctx context.Context) error {
|
||||
w.failures++
|
||||
timer := time.NewTimer(w.delay())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJitter(t *testing.T) {
|
||||
repeat(t, "0 percent", func(t *testing.T) {
|
||||
jitter := NewJitter(0)
|
||||
for i := 0; i < 10; i++ {
|
||||
baseTime := time.Duration(i) * time.Second
|
||||
require.Equal(t, baseTime, jitter(baseTime))
|
||||
}
|
||||
})
|
||||
|
||||
repeat(t, "10 percent", func(t *testing.T) {
|
||||
jitter := NewJitter(10)
|
||||
baseTime := 5000 * time.Millisecond
|
||||
maxTime := 5500 * time.Millisecond
|
||||
newTime := jitter(baseTime)
|
||||
require.True(t, newTime > baseTime)
|
||||
require.True(t, newTime <= maxTime)
|
||||
})
|
||||
|
||||
repeat(t, "100 percent", func(t *testing.T) {
|
||||
jitter := NewJitter(100)
|
||||
baseTime := 1234 * time.Millisecond
|
||||
maxTime := 2468 * time.Millisecond
|
||||
newTime := jitter(baseTime)
|
||||
require.True(t, newTime > baseTime)
|
||||
require.True(t, newTime <= maxTime)
|
||||
})
|
||||
|
||||
repeat(t, "overflow", func(t *testing.T) {
|
||||
jitter := NewJitter(100)
|
||||
baseTime := time.Duration(math.MaxInt64) - 2*time.Hour
|
||||
newTime := jitter(baseTime)
|
||||
require.Equal(t, baseTime, newTime)
|
||||
})
|
||||
}
|
||||
|
||||
func repeat(t *testing.T, name string, fn func(t *testing.T)) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
fn(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaiter_Delay(t *testing.T) {
|
||||
t.Run("zero value", func(t *testing.T) {
|
||||
w := &Waiter{}
|
||||
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with minimum wait", func(t *testing.T) {
|
||||
w := &Waiter{MinWait: 5 * time.Second}
|
||||
for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with maximum wait", func(t *testing.T) {
|
||||
w := &Waiter{MaxWait: 20 * time.Second}
|
||||
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with minimum failures", func(t *testing.T) {
|
||||
w := &Waiter{MinFailures: 4}
|
||||
for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with factor", func(t *testing.T) {
|
||||
w := &Waiter{Factor: time.Millisecond}
|
||||
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with all settings", func(t *testing.T) {
|
||||
w := &Waiter{
|
||||
MinFailures: 2,
|
||||
MinWait: 4 * time.Millisecond,
|
||||
MaxWait: 20 * time.Millisecond,
|
||||
Factor: time.Millisecond,
|
||||
}
|
||||
for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} {
|
||||
w.failures = uint(i)
|
||||
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaiter_Wait(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("first failure", func(t *testing.T) {
|
||||
w := &Waiter{MinWait: time.Millisecond, Factor: 1}
|
||||
elapsed, err := runWait(ctx, w)
|
||||
require.NoError(t, err)
|
||||
assertApproximateDuration(t, elapsed, time.Millisecond)
|
||||
require.Equal(t, w.failures, uint(1))
|
||||
})
|
||||
|
||||
t.Run("max failures", func(t *testing.T) {
|
||||
w := &Waiter{
|
||||
MaxWait: 100 * time.Millisecond,
|
||||
failures: 200,
|
||||
}
|
||||
elapsed, err := runWait(ctx, w)
|
||||
require.NoError(t, err)
|
||||
assertApproximateDuration(t, elapsed, 100*time.Millisecond)
|
||||
require.Equal(t, w.failures, uint(201))
|
||||
})
|
||||
|
||||
t.Run("context deadline", func(t *testing.T) {
|
||||
w := &Waiter{failures: 200, MinWait: time.Second}
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
elapsed, err := runWait(ctx, w)
|
||||
require.Equal(t, err, context.DeadlineExceeded)
|
||||
assertApproximateDuration(t, elapsed, 5*time.Millisecond)
|
||||
require.Equal(t, w.failures, uint(201))
|
||||
})
|
||||
}
|
||||
|
||||
func runWait(ctx context.Context, w *Waiter) (time.Duration, error) {
|
||||
before := time.Now()
|
||||
err := w.Wait(ctx)
|
||||
return time.Since(before), err
|
||||
}
|
||||
|
||||
func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) {
|
||||
t.Helper()
|
||||
delta := 20 * time.Millisecond
|
||||
min, max := expected-delta, expected+delta
|
||||
if min < 0 {
|
||||
min = 0
|
||||
}
|
||||
if actual < min || actual > max {
|
||||
t.Fatalf("expected %v to be between %v and %v", actual, min, max)
|
||||
}
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJitterRandomStagger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("0 percent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
jitter := NewJitterRandomStagger(0)
|
||||
for i := 0; i < 10; i++ {
|
||||
baseTime := time.Duration(i) * time.Second
|
||||
require.Equal(t, baseTime, jitter.AddJitter(baseTime))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("10 percent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
jitter := NewJitterRandomStagger(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
baseTime := 5000 * time.Millisecond
|
||||
maxTime := 5500 * time.Millisecond
|
||||
newTime := jitter.AddJitter(baseTime)
|
||||
require.True(t, newTime > baseTime)
|
||||
require.True(t, newTime <= maxTime)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("100 percent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
jitter := NewJitterRandomStagger(100)
|
||||
for i := 0; i < 10; i++ {
|
||||
baseTime := 1234 * time.Millisecond
|
||||
maxTime := 2468 * time.Millisecond
|
||||
newTime := jitter.AddJitter(baseTime)
|
||||
require.True(t, newTime > baseTime)
|
||||
require.True(t, newTime <= maxTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWaiter_calculateWait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 0, 0, nil)
|
||||
|
||||
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 1*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 2*time.Second, rw.calculateWait())
|
||||
rw.failures = 31
|
||||
require.Equal(t, defaultMaxWait, rw.calculateWait())
|
||||
})
|
||||
|
||||
t.Run("Minimum Wait", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 5*time.Second, 0, nil)
|
||||
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 8*time.Second, rw.calculateWait())
|
||||
})
|
||||
|
||||
t.Run("Minimum Failures", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(5, 0, 0, nil)
|
||||
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
|
||||
rw.failures += 5
|
||||
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 1*time.Second, rw.calculateWait())
|
||||
})
|
||||
|
||||
t.Run("Maximum Wait", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 0, 5*time.Second, nil)
|
||||
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 1*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 2*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 4*time.Second, rw.calculateWait())
|
||||
rw.failures += 1
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
rw.failures = 31
|
||||
require.Equal(t, 5*time.Second, rw.calculateWait())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWaiter_WaitChans(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Minimum Wait - Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-rw.Success():
|
||||
require.Fail(t, "minimum wait not respected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Minimum Wait - WaitIf", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-rw.WaitIf(false):
|
||||
require.Fail(t, "minimum wait not respected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Minimum Wait - WaitIfErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-rw.WaitIfErr(nil):
|
||||
require.Fail(t, "minimum wait not respected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Maximum Wait - Failed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
require.Fail(t, "maximum wait not respected")
|
||||
case <-rw.Failed():
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Maximum Wait - WaitIf", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
require.Fail(t, "maximum wait not respected")
|
||||
case <-rw.WaitIf(true):
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Maximum Wait - WaitIfErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
|
||||
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
require.Fail(t, "maximum wait not respected")
|
||||
case <-rw.WaitIfErr(fmt.Errorf("Fake Error")):
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue