Fix wait_test flakes

pull/6/head
Wojciech Tyczynski 2016-01-03 09:56:57 +01:00
parent 05609cbf44
commit 1ad524ac24
1 changed files with 27 additions and 22 deletions

View File

@ -18,6 +18,7 @@ package wait
import ( import (
"errors" "errors"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -48,10 +49,17 @@ DRAIN:
} }
} }
func fakeTicker(max int, used *int32) WaitFunc { type fakePoller struct {
max int
used int32 // accessed with atomics
wg sync.WaitGroup
}
func fakeTicker(max int, used *int32, doneFunc func()) WaitFunc {
return func(done <-chan struct{}) <-chan struct{} { return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
defer doneFunc()
defer close(ch) defer close(ch)
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
select { select {
@ -68,13 +76,9 @@ func fakeTicker(max int, used *int32) WaitFunc {
} }
} }
type fakePoller struct { func (fp *fakePoller) GetWaitFunc() WaitFunc {
max int fp.wg.Add(1)
used int32 // accessed with atomics return fakeTicker(fp.max, &fp.used, fp.wg.Done)
}
func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc {
return fakeTicker(fp.max, &fp.used)
} }
func TestPoll(t *testing.T) { func TestPoll(t *testing.T) {
@ -83,10 +87,11 @@ func TestPoll(t *testing.T) {
invocations++ invocations++
return true, nil return true, nil
}) })
fp := fakePoller{max: 1} fp := fakePoller{max: 1, wg: sync.WaitGroup{}}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err != nil { if err := pollInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err) t.Fatalf("unexpected error %v", err)
} }
fp.wg.Wait()
if invocations != 1 { if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations) t.Errorf("Expected exactly one invocation, got %d", invocations)
} }
@ -101,10 +106,11 @@ func TestPollError(t *testing.T) {
f := ConditionFunc(func() (bool, error) { f := ConditionFunc(func() (bool, error) {
return false, expectedError return false, expectedError
}) })
fp := fakePoller{max: 1} fp := fakePoller{max: 1, wg: sync.WaitGroup{}}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err == nil || err != expectedError { if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err) t.Fatalf("Expected error %v, got none %v", expectedError, err)
} }
fp.wg.Wait()
used := atomic.LoadInt32(&fp.used) used := atomic.LoadInt32(&fp.used)
if used != 1 { if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used) t.Errorf("Expected exactly one tick, got %d", used)
@ -117,8 +123,8 @@ func TestPollImmediate(t *testing.T) {
invocations++ invocations++
return true, nil return true, nil
}) })
fp := fakePoller{max: 0} fp := fakePoller{max: 0, wg: sync.WaitGroup{}}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil { if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err) t.Fatalf("unexpected error %v", err)
} }
if invocations != 1 { if invocations != 1 {
@ -128,19 +134,18 @@ func TestPollImmediate(t *testing.T) {
if used != 0 { if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used) t.Errorf("Expected exactly zero ticks, got %d", used)
} }
}
func TestPollImmediateError(t *testing.T) {
expectedError := errors.New("Expected error") expectedError := errors.New("Expected error")
f = ConditionFunc(func() (bool, error) { f := ConditionFunc(func() (bool, error) {
return false, expectedError return false, expectedError
}) })
fp = fakePoller{max: 0} fp := fakePoller{max: 0, wg: sync.WaitGroup{}}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError { if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err) t.Fatalf("Expected error %v, got none %v", expectedError, err)
} }
if invocations != 1 { used := atomic.LoadInt32(&fp.used)
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used = atomic.LoadInt32(&fp.used)
if used != 0 { if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used) t.Errorf("Expected exactly zero ticks, got %d", used)
} }
@ -236,7 +241,7 @@ func TestWaitFor(t *testing.T) {
} }
for k, c := range testCases { for k, c := range testCases {
invocations = 0 invocations = 0
ticker := fakeTicker(c.Ticks, nil) ticker := fakeTicker(c.Ticks, nil, func() {})
err := func() error { err := func() error {
done := make(chan struct{}) done := make(chan struct{})
defer close(done) defer close(done)