Fix wait_test flakes

pull/6/head
Wojciech Tyczynski 2016-01-03 09:56:57 +01:00
parent 05609cbf44
commit 1ad524ac24
1 changed files with 27 additions and 22 deletions

View File

@ -18,6 +18,7 @@ package wait
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
@ -48,10 +49,17 @@ DRAIN:
}
}
func fakeTicker(max int, used *int32) WaitFunc {
type fakePoller struct {
max int
used int32 // accessed with atomics
wg sync.WaitGroup
}
func fakeTicker(max int, used *int32, doneFunc func()) WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
go func() {
defer doneFunc()
defer close(ch)
for i := 0; i < max; i++ {
select {
@ -68,13 +76,9 @@ func fakeTicker(max int, used *int32) WaitFunc {
}
}
type fakePoller struct {
max int
used int32 // accessed with atomics
}
func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc {
return fakeTicker(fp.max, &fp.used)
func (fp *fakePoller) GetWaitFunc() WaitFunc {
fp.wg.Add(1)
return fakeTicker(fp.max, &fp.used, fp.wg.Done)
}
func TestPoll(t *testing.T) {
@ -83,10 +87,11 @@ func TestPoll(t *testing.T) {
invocations++
return true, nil
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err != nil {
fp := fakePoller{max: 1, wg: sync.WaitGroup{}}
if err := pollInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err)
}
fp.wg.Wait()
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
@ -101,10 +106,11 @@ func TestPollError(t *testing.T) {
f := ConditionFunc(func() (bool, error) {
return false, expectedError
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err == nil || err != expectedError {
fp := fakePoller{max: 1, wg: sync.WaitGroup{}}
if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
fp.wg.Wait()
used := atomic.LoadInt32(&fp.used)
if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used)
@ -117,8 +123,8 @@ func TestPollImmediate(t *testing.T) {
invocations++
return true, nil
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil {
fp := fakePoller{max: 0, wg: sync.WaitGroup{}}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err)
}
if invocations != 1 {
@ -128,19 +134,18 @@ func TestPollImmediate(t *testing.T) {
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
}
func TestPollImmediateError(t *testing.T) {
expectedError := errors.New("Expected error")
f = ConditionFunc(func() (bool, error) {
f := ConditionFunc(func() (bool, error) {
return false, expectedError
})
fp = fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError {
fp := fakePoller{max: 0, wg: sync.WaitGroup{}}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used = atomic.LoadInt32(&fp.used)
used := atomic.LoadInt32(&fp.used)
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
@ -236,7 +241,7 @@ func TestWaitFor(t *testing.T) {
}
for k, c := range testCases {
invocations = 0
ticker := fakeTicker(c.Ticks, nil)
ticker := fakeTicker(c.Ticks, nil, func() {})
err := func() error {
done := make(chan struct{})
defer close(done)