lib/retry: Refactor to reduce the interface surface

Reduce Jitter to one function

Rename NewRetryWaiter

Fix a bug in calculateWait where maxWait was applied before jitter, which would make it
possible to wait longer than maxWait.
pull/8802/head
Daniel Nephin 2020-10-01 01:14:21 -04:00
parent 7b4aca2088
commit e54567223b
8 changed files with 243 additions and 331 deletions
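
The standalone sketch below (editorial, not part of the commit) illustrates the calculateWait bug described above: when maxWait is applied before jitter, the jittered wait can exceed the cap, while clamping after jitter, as the new delay() further down does, keeps it within maxWait. The fixed +25% closure is a stand-in for the randomized retry.NewJitter(25).

package main

import (
	"fmt"
	"time"
)

// capped clamps d to max.
func capped(d, max time.Duration) time.Duration {
	if d > max {
		return max
	}
	return d
}

func main() {
	maxWait := 2 * time.Minute
	backoff := 4 * time.Minute // exponential backoff already past the cap

	// Stand-in for retry.NewJitter(25) at its upper bound: +25%.
	jitter := func(d time.Duration) time.Duration { return d + d/4 }

	oldWait := jitter(capped(backoff, maxWait)) // old order: clamp, then jitter -> 2m30s, over the cap
	newWait := capped(jitter(backoff), maxWait) // new order: jitter, then clamp -> 2m, within the cap

	fmt.Println(oldWait > maxWait, newWait <= maxWait) // prints: true true
}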

View File

@@ -85,7 +85,11 @@ func New(config Config) (*AutoConfig, error) {
 	}
 	if config.Waiter == nil {
-		config.Waiter = retry.NewRetryWaiter(1, 0, 10*time.Minute, retry.NewJitterRandomStagger(25))
+		config.Waiter = &retry.Waiter{
+			MinFailures: 1,
+			MaxWait:     10 * time.Minute,
+			Jitter:      retry.NewJitter(25),
+		}
 	}
 
 	return &AutoConfig{
@@ -306,23 +310,21 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf.
 		return nil, err
 	}
 
-	// this resets the failures so that we will perform immediate request
-	wait := ac.acConfig.Waiter.Success()
+	ac.acConfig.Waiter.Reset()
 	for {
-		select {
-		case <-wait:
-			if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil {
-				return resp, nil
-			} else if err != nil {
-				ac.logger.Error(err.Error())
-			} else {
-				ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
-			}
-			wait = ac.acConfig.Waiter.Failed()
-		case <-ctx.Done():
-			ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
-			return nil, ctx.Err()
+		resp, err := ac.getInitialConfigurationOnce(ctx, csr, key)
+		switch {
+		case err == nil && resp != nil:
+			return resp, nil
+		case err != nil:
+			ac.logger.Error(err.Error())
+		default:
+			ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
+		}
+
+		if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
+			ac.logger.Info("interrupted during initial auto configuration", "err", err)
+			return nil, err
 		}
 	}
 }

View File

@@ -413,7 +413,7 @@ func TestInitialConfiguration_retries(t *testing.T) {
 	mcfg.Config.Loader = loader.Load
 
 	// reduce the retry wait times to make this test run faster
-	mcfg.Config.Waiter = retry.NewWaiter(2, 0, 1*time.Millisecond, nil)
+	mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
 
 	indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", "secret")

View File

@@ -16,23 +16,21 @@ func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.Sig
 		return nil, err
 	}
 
-	// this resets the failures so that we will perform immediate request
-	wait := ac.acConfig.Waiter.Success()
+	ac.acConfig.Waiter.Reset()
 	for {
-		select {
-		case <-wait:
-			if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil {
-				return resp, nil
-			} else if err != nil {
-				ac.logger.Error(err.Error())
-			} else {
-				ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
-			}
-			wait = ac.acConfig.Waiter.Failed()
-		case <-ctx.Done():
-			ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err())
-			return nil, ctx.Err()
+		resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key)
+		switch {
+		case err == nil && resp != nil:
+			return resp, nil
+		case err != nil:
+			ac.logger.Error(err.Error())
+		default:
+			ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
+		}
+
+		if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
+			ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", err)
+			return nil, err
 		}
 	}
 }

View File

@@ -248,7 +248,7 @@ func TestAutoEncrypt_InitialCerts(t *testing.T) {
 		resp.VerifyServerHostname = true
 	})
 
-	mcfg.Config.Waiter = retry.NewRetryWaiter(2, 0, 1*time.Millisecond, nil)
+	mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
 
 	ac := AutoConfig{
 		config: &config.RuntimeConfig{

View File

@@ -68,12 +68,12 @@ type Config struct {
 	// known servers during fallback operations.
 	ServerProvider ServerProvider
 
-	// Waiter is a RetryWaiter to be used during retrieval of the
-	// initial configuration. When a round of requests fails we will
+	// Waiter is used during retrieval of the initial configuration.
+	// When a round of requests fails we will
 	// wait and eventually make another round of requests (1 round
 	// is trying the RPC once against each configured server addr). The
 	// waiting implements some backoff to prevent from retrying these RPCs
-	// to often. This field is not required and if left unset a waiter will
+	// too often. This field is not required and if left unset a waiter will
 	// be used that has a max wait duration of 10 minutes and a randomized
 	// jitter of 25% of the wait time. Setting this is mainly useful for
 	// testing purposes to allow testing out the retrying functionality without
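
As a rough, editorial illustration of the default documented here (up to 25% randomized jitter on each wait), the standalone program below exercises retry.NewJitter directly; the import path is an assumption based on the lib/retry package name.

package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/consul/lib/retry" // import path assumed from the lib/retry package name
)

func main() {
	// The default waiter described above adds up to 25% random jitter to each wait.
	jitter := retry.NewJitter(25)
	base := 10 * time.Second
	for i := 0; i < 3; i++ {
		d := jitter(base)
		// Each call returns a duration in [base, base+25%).
		fmt.Println(d >= base, d < base+base/4)
	}
}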

View File

@@ -18,8 +18,6 @@ const (
 	// replicationMaxRetryWait is the maximum number of seconds to wait between
 	// failed blocking queries when backing off.
 	replicationDefaultMaxRetryWait = 120 * time.Second
-
-	replicationDefaultRate = 1
 )
 
 type ReplicatorDelegate interface {
@@ -36,7 +34,7 @@ type ReplicatorConfig struct {
 	// The number of replication rounds that can be done in a burst
 	Burst int
 	// Minimum number of RPC failures to ignore before backing off
-	MinFailures int
+	MinFailures uint
 	// Maximum wait time between failing RPCs
 	MaxRetryWait time.Duration
 	// Where to send our logs
@@ -71,12 +69,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) {
 	if maxWait == 0 {
 		maxWait = replicationDefaultMaxRetryWait
 	}
-
-	minFailures := config.MinFailures
-	if minFailures < 0 {
-		minFailures = 0
-	}
-	waiter := retry.NewRetryWaiter(minFailures, 0*time.Second, maxWait, retry.NewJitterRandomStagger(10))
+	waiter := &retry.Waiter{
+		MinFailures: config.MinFailures,
+		MaxWait:     maxWait,
+		Jitter:      retry.NewJitter(10),
+	}
 
 	return &Replicator{
 		limiter: limiter,
 		waiter:  waiter,
@@ -100,10 +97,8 @@ func (r *Replicator) Run(ctx context.Context) error {
 		// Perform a single round of replication
 		index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex), r.logger)
 		if exit {
-			// the replication function told us to exit
 			return nil
 		}
 
 		if err != nil {
 			// reset the lastRemoteIndex when there is an RPC failure. This should cause a full sync to be done during
 			// the next round of replication
@@ -112,18 +107,16 @@ func (r *Replicator) Run(ctx context.Context) error {
 			if r.suppressErrorLog != nil && !r.suppressErrorLog(err) {
 				r.logger.Warn("replication error (will retry if still leader)", "error", err)
 			}
-		} else {
-			atomic.StoreUint64(&r.lastRemoteIndex, index)
-			r.logger.Debug("replication completed through remote index", "index", index)
-		}
-
-		select {
-		case <-ctx.Done():
-			return nil
-		// wait some amount of time to prevent churning through many replication rounds while replication is failing
-		case <-r.waiter.WaitIfErr(err):
-			// do nothing
-		}
+
+			if err := r.waiter.Wait(ctx); err != nil {
+				return nil
+			}
+			continue
+		}
+
+		atomic.StoreUint64(&r.lastRemoteIndex, index)
+		r.logger.Debug("replication completed through remote index", "index", index)
+		r.waiter.Reset()
 	}
 }

View File

@@ -1,9 +1,9 @@
 package retry
 
 import (
+	"context"
+	"math/rand"
 	"time"
-
-	"github.com/hashicorp/consul/lib"
 )
 
 const (
@@ -11,153 +11,96 @@ const (
 	defaultMaxWait = 2 * time.Minute
 )
 
-// Interface used for offloading jitter calculations from the RetryWaiter
-type Jitter interface {
-	AddJitter(baseTime time.Duration) time.Duration
-}
-
-// Calculates a random jitter between 0 and up to a specific percentage of the baseTime
-type JitterRandomStagger struct {
-	// int64 because we are going to be doing math against an int64 to represent nanoseconds
-	percent int64
-}
-
-// Creates a new JitterRandomStagger
-func NewJitterRandomStagger(percent int) *JitterRandomStagger {
-	if percent < 0 {
-		percent = 0
-	}
-
-	return &JitterRandomStagger{
-		percent: int64(percent),
-	}
-}
-
-// Implments the Jitter interface
-func (j *JitterRandomStagger) AddJitter(baseTime time.Duration) time.Duration {
-	if j.percent == 0 {
-		return baseTime
-	}
-
-	// time.Duration is actually a type alias for int64 which is why casting
-	// to the duration type and then dividing works
-	return baseTime + lib.RandomStagger((baseTime*time.Duration(j.percent))/100)
-}
-
-// RetryWaiter will record failed and successful operations and provide
-// a channel to wait on before a failed operation can be retried.
-type Waiter struct {
-	MinFailures uint
-	MinWait     time.Duration
-	MaxWait     time.Duration
-	Jitter      Jitter
-	failures    uint
-}
-
-// Creates a new RetryWaiter
-func NewRetryWaiter(minFailures int, minWait, maxWait time.Duration, jitter Jitter) *Waiter {
-	if minFailures < 0 {
-		minFailures = defaultMinFailures
-	}
-
-	if maxWait <= 0 {
-		maxWait = defaultMaxWait
-	}
-
-	if minWait <= 0 {
-		minWait = 0 * time.Nanosecond
-	}
-
-	return &Waiter{
-		MinFailures: uint(minFailures),
-		MinWait:     minWait,
-		MaxWait:     maxWait,
-		failures:    0,
-		Jitter:      jitter,
-	}
-}
-
-// calculates the necessary wait time before the
-// next operation should be allowed.
-func (rw *Waiter) calculateWait() time.Duration {
-	waitTime := rw.MinWait
-	if rw.failures > rw.MinFailures {
-		shift := rw.failures - rw.MinFailures - 1
-		waitTime = rw.MaxWait
-		if shift < 31 {
-			waitTime = (1 << shift) * time.Second
-		}
-		if waitTime > rw.MaxWait {
-			waitTime = rw.MaxWait
-		}
-
-		if rw.Jitter != nil {
-			waitTime = rw.Jitter.AddJitter(waitTime)
-		}
-	}
-
-	if waitTime < rw.MinWait {
-		waitTime = rw.MinWait
-	}
-
-	return waitTime
-}
-
-// calculates the waitTime and returns a chan
-// that will become selectable once that amount
-// of time has elapsed.
-func (rw *Waiter) wait() <-chan struct{} {
-	waitTime := rw.calculateWait()
-	ch := make(chan struct{})
-	if waitTime > 0 {
-		time.AfterFunc(waitTime, func() { close(ch) })
-	} else {
-		// if there should be 0 wait time then we ensure
-		// that the chan will be immediately selectable
-		close(ch)
-	}
-	return ch
-}
-
-// Marks that an operation is successful which resets the failure count.
-// The chan that is returned will be immediately selectable
-func (rw *Waiter) Success() <-chan struct{} {
-	rw.Reset()
-	return rw.wait()
-}
-
-// Marks that an operation failed. The chan returned will be selectable
-// once the calculated retry wait amount of time has elapsed
-func (rw *Waiter) Failed() <-chan struct{} {
-	rw.failures += 1
-	ch := rw.wait()
-	return ch
-}
-
-// Resets the internal failure counter.
-func (rw *Waiter) Reset() {
-	rw.failures = 0
-}
-
-// Failures returns the current number of consecutive failures recorded.
-func (rw *Waiter) Failures() int {
-	return int(rw.failures)
-}
-
-// WaitIf is a convenice method to record whether the last
-// operation was a success or failure and return a chan that
-// will be selectablw when the next operation can be done.
-func (rw *Waiter) WaitIf(failure bool) <-chan struct{} {
-	if failure {
-		return rw.Failed()
-	}
-
-	return rw.Success()
-}
-
-// WaitIfErr is a convenience method to record whether the last
-// operation was a success or failure based on whether the err
-// is nil and then return a chan that will be selectable when
-// the next operation can be done.
-func (rw *Waiter) WaitIfErr(err error) <-chan struct{} {
-	return rw.WaitIf(err != nil)
-}
+// Jitter should return a new wait duration optionally with some time added or
+// removed to create some randomness in wait time.
+type Jitter func(baseTime time.Duration) time.Duration
+
+// NewJitter returns a new random Jitter that is up to percent longer than the
+// original wait time.
+func NewJitter(percent int64) Jitter {
+	if percent < 0 {
+		percent = 0
+	}
+
+	return func(baseTime time.Duration) time.Duration {
+		if percent == 0 {
+			return baseTime
+		}
+		max := (int64(baseTime) * percent) / 100
+		if max < 0 { // overflow
+			return baseTime
+		}
+		return baseTime + time.Duration(rand.Int63n(max))
+	}
+}
+
+// Waiter records the number of failures and performs exponential backoff
+// when there are consecutive failures.
+type Waiter struct {
+	// MinFailures before exponential backoff starts. Any failures before
+	// MinFailures is reached will wait MinWait time.
+	MinFailures uint
+	// MinWait time. Returned after the first failure.
+	MinWait time.Duration
+	// MaxWait time.
+	MaxWait time.Duration
+	// Jitter to add to each wait time.
+	Jitter Jitter
+	// Factor is the multiplier to use when calculating the delay. Defaults to
+	// 1 second.
+	Factor time.Duration
+
+	failures uint
+}
+
+// delay calculates the time to wait based on the number of failures
+func (w *Waiter) delay() time.Duration {
+	if w.failures <= w.MinFailures {
+		return w.MinWait
+	}
+
+	factor := w.Factor
+	if factor == 0 {
+		factor = time.Second
+	}
+
+	shift := w.failures - w.MinFailures - 1
+	waitTime := w.MaxWait
+	if shift < 31 {
+		waitTime = (1 << shift) * factor
+	}
+	if w.Jitter != nil {
+		waitTime = w.Jitter(waitTime)
+	}
+	if w.MaxWait != 0 && waitTime > w.MaxWait {
+		return w.MaxWait
+	}
+	if waitTime < w.MinWait {
+		return w.MinWait
+	}
+	return waitTime
+}
+
+// Reset the failure count to 0.
+func (w *Waiter) Reset() {
+	w.failures = 0
+}
+
+// Failures returns the count of consecutive failures.
+func (w *Waiter) Failures() int {
+	return int(w.failures)
+}
+
+// Wait increases the number of failures by one, and then blocks until the context
+// is cancelled, or until the wait time is reached.
+// The wait time increases exponentially as the number of failures increases.
+// Wait will return ctx.Err() if the context is cancelled.
+func (w *Waiter) Wait(ctx context.Context) error {
+	w.failures++
+	timer := time.NewTimer(w.delay())
+	select {
+	case <-ctx.Done():
+		timer.Stop()
+		return ctx.Err()
+	case <-timer.C:
+		return nil
+	}
+}

View File

@@ -1,184 +1,160 @@
 package retry
 
 import (
-	"fmt"
+	"context"
+	"math"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/require"
 )
 
-func TestJitterRandomStagger(t *testing.T) {
-	t.Parallel()
-
-	t.Run("0 percent", func(t *testing.T) {
-		t.Parallel()
-
-		jitter := NewJitterRandomStagger(0)
+func TestJitter(t *testing.T) {
+	repeat(t, "0 percent", func(t *testing.T) {
+		jitter := NewJitter(0)
 		for i := 0; i < 10; i++ {
 			baseTime := time.Duration(i) * time.Second
-			require.Equal(t, baseTime, jitter.AddJitter(baseTime))
+			require.Equal(t, baseTime, jitter(baseTime))
 		}
 	})
 
-	t.Run("10 percent", func(t *testing.T) {
-		t.Parallel()
-
-		jitter := NewJitterRandomStagger(10)
-		for i := 0; i < 10; i++ {
-			baseTime := 5000 * time.Millisecond
-			maxTime := 5500 * time.Millisecond
-			newTime := jitter.AddJitter(baseTime)
-			require.True(t, newTime > baseTime)
-			require.True(t, newTime <= maxTime)
-		}
-	})
+	repeat(t, "10 percent", func(t *testing.T) {
+		jitter := NewJitter(10)
+		baseTime := 5000 * time.Millisecond
+		maxTime := 5500 * time.Millisecond
+		newTime := jitter(baseTime)
+		require.True(t, newTime > baseTime)
+		require.True(t, newTime <= maxTime)
+	})
 
-	t.Run("100 percent", func(t *testing.T) {
-		t.Parallel()
-
-		jitter := NewJitterRandomStagger(100)
-		for i := 0; i < 10; i++ {
-			baseTime := 1234 * time.Millisecond
-			maxTime := 2468 * time.Millisecond
-			newTime := jitter.AddJitter(baseTime)
-			require.True(t, newTime > baseTime)
-			require.True(t, newTime <= maxTime)
+	repeat(t, "100 percent", func(t *testing.T) {
+		jitter := NewJitter(100)
+		baseTime := 1234 * time.Millisecond
+		maxTime := 2468 * time.Millisecond
+		newTime := jitter(baseTime)
+		require.True(t, newTime > baseTime)
+		require.True(t, newTime <= maxTime)
+	})
+
+	repeat(t, "overflow", func(t *testing.T) {
+		jitter := NewJitter(100)
+		baseTime := time.Duration(math.MaxInt64) - 2*time.Hour
+		newTime := jitter(baseTime)
+		require.Equal(t, baseTime, newTime)
+	})
+}
+
+func repeat(t *testing.T, name string, fn func(t *testing.T)) {
+	t.Run(name, func(t *testing.T) {
+		for i := 0; i < 1000; i++ {
+			fn(t)
 		}
 	})
 }
-func TestRetryWaiter_calculateWait(t *testing.T) {
-	t.Parallel()
-
-	t.Run("Defaults", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 0, 0, nil)
-		require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 1*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 2*time.Second, rw.calculateWait())
-		rw.failures = 31
-		require.Equal(t, defaultMaxWait, rw.calculateWait())
-	})
-
-	t.Run("Minimum Wait", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 5*time.Second, 0, nil)
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 8*time.Second, rw.calculateWait())
-	})
-
-	t.Run("Minimum Failures", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(5, 0, 0, nil)
-		require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
-		rw.failures += 5
-		require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 1*time.Second, rw.calculateWait())
-	})
-
-	t.Run("Maximum Wait", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 0, 5*time.Second, nil)
-		require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 1*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 2*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 4*time.Second, rw.calculateWait())
-		rw.failures += 1
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-		rw.failures = 31
-		require.Equal(t, 5*time.Second, rw.calculateWait())
-	})
-}
-
-func TestRetryWaiter_WaitChans(t *testing.T) {
-	t.Parallel()
-
-	t.Run("Minimum Wait - Success", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
-
-		select {
-		case <-time.After(200 * time.Millisecond):
-		case <-rw.Success():
-			require.Fail(t, "minimum wait not respected")
+func TestWaiter_Delay(t *testing.T) {
+	t.Run("zero value", func(t *testing.T) {
+		w := &Waiter{}
+		for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
 		}
 	})
 
-	t.Run("Minimum Wait - WaitIf", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
-
-		select {
-		case <-time.After(200 * time.Millisecond):
-		case <-rw.WaitIf(false):
-			require.Fail(t, "minimum wait not respected")
+	t.Run("with minimum wait", func(t *testing.T) {
+		w := &Waiter{MinWait: 5 * time.Second}
+		for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
 		}
 	})
 
-	t.Run("Minimum Wait - WaitIfErr", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
-
-		select {
-		case <-time.After(200 * time.Millisecond):
-		case <-rw.WaitIfErr(nil):
-			require.Fail(t, "minimum wait not respected")
+	t.Run("with maximum wait", func(t *testing.T) {
+		w := &Waiter{MaxWait: 20 * time.Second}
+		for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
 		}
 	})
 
-	t.Run("Maximum Wait - Failed", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
-
-		select {
-		case <-time.After(500 * time.Millisecond):
-			require.Fail(t, "maximum wait not respected")
-		case <-rw.Failed():
+	t.Run("with minimum failures", func(t *testing.T) {
+		w := &Waiter{MinFailures: 4}
+		for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
 		}
 	})
 
-	t.Run("Maximum Wait - WaitIf", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
-
-		select {
-		case <-time.After(500 * time.Millisecond):
-			require.Fail(t, "maximum wait not respected")
-		case <-rw.WaitIf(true):
+	t.Run("with factor", func(t *testing.T) {
+		w := &Waiter{Factor: time.Millisecond}
+		for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
 		}
 	})
 
-	t.Run("Maximum Wait - WaitIfErr", func(t *testing.T) {
-		t.Parallel()
-
-		rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
-
-		select {
-		case <-time.After(500 * time.Millisecond):
-			require.Fail(t, "maximum wait not respected")
-		case <-rw.WaitIfErr(fmt.Errorf("Fake Error")):
+	t.Run("with all settings", func(t *testing.T) {
+		w := &Waiter{
+			MinFailures: 2,
+			MinWait:     4 * time.Millisecond,
+			MaxWait:     20 * time.Millisecond,
+			Factor:      time.Millisecond,
+		}
+		for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} {
+			w.failures = uint(i)
+			require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
 		}
 	})
 }
+
+func TestWaiter_Wait(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("first failure", func(t *testing.T) {
+		w := &Waiter{MinWait: time.Millisecond, Factor: 1}
+		elapsed, err := runWait(ctx, w)
+		require.NoError(t, err)
+		assertApproximateDuration(t, elapsed, time.Millisecond)
+		require.Equal(t, w.failures, uint(1))
+	})
+
+	t.Run("max failures", func(t *testing.T) {
+		w := &Waiter{
+			MaxWait:  100 * time.Millisecond,
+			failures: 200,
+		}
+		elapsed, err := runWait(ctx, w)
+		require.NoError(t, err)
+		assertApproximateDuration(t, elapsed, 100*time.Millisecond)
+		require.Equal(t, w.failures, uint(201))
+	})
+
+	t.Run("context deadline", func(t *testing.T) {
+		w := &Waiter{failures: 200, MinWait: time.Second}
+		ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
+		t.Cleanup(cancel)
+		elapsed, err := runWait(ctx, w)
+		require.Equal(t, err, context.DeadlineExceeded)
+		assertApproximateDuration(t, elapsed, 5*time.Millisecond)
+		require.Equal(t, w.failures, uint(201))
+	})
+}
+
+func runWait(ctx context.Context, w *Waiter) (time.Duration, error) {
+	before := time.Now()
+	err := w.Wait(ctx)
+	return time.Since(before), err
+}
+
+func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) {
+	t.Helper()
+	delta := 20 * time.Millisecond
+	min, max := expected-delta, expected+delta
+	if min < 0 {
+		min = 0
+	}
+	if actual < min || actual > max {
+		t.Fatalf("expected %v to be between %v and %v", actual, min, max)
+	}
+}