From 15c5012933f1243bae64367223fd3ceeba12be8f Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Wed, 15 Oct 2014 17:46:28 -0700 Subject: [PATCH] Add lock to fake handler to avoid races. --- pkg/util/fake_handler.go | 23 ++++++++++++++++++++++- pkg/util/fake_handler_test.go | 2 ++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/pkg/util/fake_handler.go b/pkg/util/fake_handler.go index 2a980e7032..c9041f5da4 100644 --- a/pkg/util/fake_handler.go +++ b/pkg/util/fake_handler.go @@ -21,12 +21,14 @@ import ( "net/http" "net/url" "reflect" + "sync" ) // TestInterface is a simple interface providing Errorf, to make injection for // testing easier (insert 'yo dawg' meme here). type TestInterface interface { Errorf(format string, args ...interface{}) + Logf(format string, args ...interface{}) } // LogInterface is a simple interface to allow injection of Logf to report serving errors. @@ -45,9 +47,21 @@ type FakeHandler struct { // For logging - you can use a *testing.T // This will keep log messages associated with the test. T LogInterface + + // Enforce "only one use" constraint. + lock sync.Mutex + requestCount int + hasBeenChecked bool } func (f *FakeHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { + f.lock.Lock() + defer f.lock.Unlock() + f.requestCount++ + if f.hasBeenChecked { + panic("got request after having been validated") + } + f.RequestReceived = request response.WriteHeader(f.StatusCode) response.Write([]byte(f.ResponseBody)) @@ -60,7 +74,14 @@ func (f *FakeHandler) ServeHTTP(response http.ResponseWriter, request *http.Requ } // ValidateRequest verifies that FakeHandler received a request with expected path, method, and body. -func (f FakeHandler) ValidateRequest(t TestInterface, expectedPath, expectedMethod string, body *string) { +func (f *FakeHandler) ValidateRequest(t TestInterface, expectedPath, expectedMethod string, body *string) { + f.lock.Lock() + defer f.lock.Unlock() + if f.requestCount != 1 { + t.Logf("Expected 1 call, but got %v. Only the last call is recorded and checked.", f.requestCount) + } + f.hasBeenChecked = true + expectURL, err := url.Parse(expectedPath) if err != nil { t.Errorf("Couldn't parse %v as a URL.", expectedPath) diff --git a/pkg/util/fake_handler_test.go b/pkg/util/fake_handler_test.go index 2a45474bf8..34c0ded117 100644 --- a/pkg/util/fake_handler_test.go +++ b/pkg/util/fake_handler_test.go @@ -72,6 +72,8 @@ func (f *fakeError) Errorf(format string, args ...interface{}) { f.errors = append(f.errors, format) } +func (f *fakeError) Logf(format string, args ...interface{}) {} + func TestFakeHandlerWrongPath(t *testing.T) { handler := FakeHandler{} server := httptest.NewServer(&handler)