diff --git a/agent/http.go b/agent/http.go index 930f003472..1b869102e9 100644 --- a/agent/http.go +++ b/agent/http.go @@ -15,6 +15,7 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-cleanhttp" "github.com/mitchellh/mapstructure" ) @@ -62,6 +63,17 @@ func registerEndpoint(pattern string, fn unboundEndpoint) { endpoints[pattern] = fn } +// wrappedMux hangs on to the underlying mux for unit tests. +type wrappedMux struct { + mux *http.ServeMux + handler http.Handler +} + +// ServeHTTP implements the http.Handler interface. +func (w *wrappedMux) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + w.handler.ServeHTTP(resp, req) +} + // handler is used to attach our handlers to the mux func (s *HTTPServer) handler(enableDebug bool) http.Handler { mux := http.NewServeMux() @@ -118,7 +130,13 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { } else if s.agent.config.EnableUI { mux.Handle("/ui/", http.StripPrefix("/ui/", http.FileServer(assetFS()))) } - return mux + + // Wrap the whole mux with a handler that bans URLs with non-printable + // characters. + return &wrappedMux{ + mux: mux, + handler: cleanhttp.PrintablePathCheckHandler(mux, nil), + } } // aclEndpointRE is used to find old ACL endpoints that take tokens in the URL diff --git a/agent/http_test.go b/agent/http_test.go index 8a347d046e..12ce2a8138 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -165,11 +165,11 @@ func TestHTTPServer_H2(t *testing.T) { resp.WriteHeader(http.StatusOK) fmt.Fprint(resp, req.Proto) } - mux, ok := a.srv.Handler.(*http.ServeMux) + w, ok := a.srv.Handler.(*wrappedMux) if !ok { t.Fatalf("handler is not expected type") } - mux.HandleFunc("/echo", handler) + w.mux.HandleFunc("/echo", handler) // Call it and make sure we see HTTP/2. url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) @@ -315,6 +315,18 @@ func TestHTTPAPI_BlockEndpoints(t *testing.T) { } } +func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) { + a := NewTestAgent(t.Name(), "") + defer a.Shutdown() + + req, _ := http.NewRequest("GET", "/v1/kv/bad\x00ness", nil) + resp := httptest.NewRecorder() + a.srv.Handler.ServeHTTP(resp, req) + if got, want := resp.Code, http.StatusBadRequest; got != want { + t.Fatalf("bad response code got %d want %d", got, want) + } +} + func TestHTTPAPI_TranslateAddrHeader(t *testing.T) { t.Parallel() // Header should not be present if address translation is off.