mirror of https://github.com/hashicorp/consul
213 lines
5.7 KiB
Go
213 lines
5.7 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package retry
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestJitter(t *testing.T) {
|
|
repeat(t, "0 percent", func(t *testing.T) {
|
|
jitter := NewJitter(0)
|
|
for i := 0; i < 10; i++ {
|
|
baseTime := time.Duration(i) * time.Second
|
|
require.Equal(t, baseTime, jitter(baseTime))
|
|
}
|
|
})
|
|
|
|
repeat(t, "10 percent", func(t *testing.T) {
|
|
jitter := NewJitter(10)
|
|
baseTime := 5000 * time.Millisecond
|
|
maxTime := 5500 * time.Millisecond
|
|
newTime := jitter(baseTime)
|
|
require.True(t, newTime > baseTime)
|
|
require.True(t, newTime <= maxTime)
|
|
})
|
|
|
|
repeat(t, "100 percent", func(t *testing.T) {
|
|
jitter := NewJitter(100)
|
|
baseTime := 1234 * time.Millisecond
|
|
maxTime := 2468 * time.Millisecond
|
|
newTime := jitter(baseTime)
|
|
require.True(t, newTime > baseTime)
|
|
require.True(t, newTime <= maxTime)
|
|
})
|
|
|
|
repeat(t, "overflow", func(t *testing.T) {
|
|
jitter := NewJitter(100)
|
|
baseTime := time.Duration(math.MaxInt64) - 2*time.Hour
|
|
newTime := jitter(baseTime)
|
|
require.Equal(t, baseTime, newTime)
|
|
})
|
|
}
|
|
|
|
// repeat runs fn 1000 times inside a single subtest called name. All
// iterations share the same subtest *testing.T.
func repeat(t *testing.T, name string, fn func(t *testing.T)) {
	const iterations = 1000

	body := func(t *testing.T) {
		for remaining := iterations; remaining > 0; remaining-- {
			fn(t)
		}
	}
	t.Run(name, body)
}
|
|
|
|
func TestWaiter_Delay(t *testing.T) {
|
|
t.Run("zero value", func(t *testing.T) {
|
|
w := &Waiter{}
|
|
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("with minimum wait", func(t *testing.T) {
|
|
w := &Waiter{MinWait: 5 * time.Second}
|
|
for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("with maximum wait", func(t *testing.T) {
|
|
w := &Waiter{MaxWait: 20 * time.Second}
|
|
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("with minimum failures", func(t *testing.T) {
|
|
w := &Waiter{MinFailures: 4}
|
|
for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("with factor", func(t *testing.T) {
|
|
w := &Waiter{Factor: time.Millisecond}
|
|
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("with all settings", func(t *testing.T) {
|
|
w := &Waiter{
|
|
MinFailures: 2,
|
|
MinWait: 4 * time.Millisecond,
|
|
MaxWait: 20 * time.Millisecond,
|
|
Factor: time.Millisecond,
|
|
}
|
|
for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} {
|
|
w.failures = uint(i)
|
|
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
|
|
}
|
|
})
|
|
|
|
t.Run("jitter can exceed MaxWait", func(t *testing.T) {
|
|
w := &Waiter{
|
|
MaxWait: 20 * time.Second,
|
|
Jitter: NewJitter(300),
|
|
|
|
failures: 30,
|
|
}
|
|
|
|
delay := w.delay()
|
|
require.True(t, delay > 20*time.Second, "expected delay %v to be greater than MaxWait %v", delay, w.MaxWait)
|
|
})
|
|
}
|
|
|
|
func TestWaiter_Wait(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("too slow for testing.Short")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
t.Run("first failure", func(t *testing.T) {
|
|
w := &Waiter{MinWait: time.Millisecond, Factor: 1}
|
|
elapsed, err := runWait(ctx, w)
|
|
require.NoError(t, err)
|
|
assertApproximateDuration(t, elapsed, time.Millisecond)
|
|
require.Equal(t, w.failures, uint(1))
|
|
})
|
|
|
|
t.Run("max failures", func(t *testing.T) {
|
|
w := &Waiter{
|
|
MaxWait: 100 * time.Millisecond,
|
|
failures: 200,
|
|
}
|
|
elapsed, err := runWait(ctx, w)
|
|
require.NoError(t, err)
|
|
assertApproximateDuration(t, elapsed, 100*time.Millisecond)
|
|
require.Equal(t, w.failures, uint(201))
|
|
})
|
|
|
|
t.Run("context deadline", func(t *testing.T) {
|
|
w := &Waiter{failures: 200, MinWait: time.Second}
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
|
|
t.Cleanup(cancel)
|
|
|
|
elapsed, err := runWait(ctx, w)
|
|
require.Equal(t, err, context.DeadlineExceeded)
|
|
assertApproximateDuration(t, elapsed, 5*time.Millisecond)
|
|
require.Equal(t, w.failures, uint(201))
|
|
})
|
|
}
|
|
|
|
func TestWaiter_RetryLoop(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("too slow for testing.Short")
|
|
}
|
|
// Change the default factor so that we retry faster.
|
|
w := &Waiter{Factor: 1 * time.Millisecond}
|
|
|
|
t.Run("exits if operation is successful after a few reties", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
t.Cleanup(cancel)
|
|
numRetries := 0
|
|
err := w.RetryLoop(ctx, func() error {
|
|
if numRetries < 2 {
|
|
numRetries++
|
|
return fmt.Errorf("operation not successful")
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("errors if operation is never successful", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
t.Cleanup(cancel)
|
|
err := w.RetryLoop(ctx, func() error {
|
|
return fmt.Errorf("operation not successful")
|
|
})
|
|
require.NotNil(t, err)
|
|
require.EqualError(t, err, "could not retry operation: operation not successful")
|
|
})
|
|
}
|
|
|
|
func runWait(ctx context.Context, w *Waiter) (time.Duration, error) {
|
|
before := time.Now()
|
|
err := w.Wait(ctx)
|
|
return time.Since(before), err
|
|
}
|
|
|
|
// assertApproximateDuration stops the test with t.Fatalf unless actual lies
// within 20ms of expected (inclusive on both ends). The lower bound is
// clamped to zero, so expectations shorter than the tolerance accept any
// duration from zero up to expected+20ms.
func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) {
	t.Helper()

	const tolerance = 20 * time.Millisecond

	lower := expected - tolerance
	if lower < 0 {
		lower = 0
	}
	upper := expected + tolerance

	if lower <= actual && actual <= upper {
		return
	}
	t.Fatalf("expected %v to be between %v and %v", actual, lower, upper)
}
|