From a5e45defb15e7f418de4a9d06ef63dd4b7e8fe21 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 2 Jul 2020 16:47:54 -0400 Subject: [PATCH 1/2] agent/http: un-embed the HTTPServer The embedded HTTPServer struct is not used by the large HTTPServer struct. It is used by tests and the agent. This change is a small first step in the process of removing that field. The eventual goal is to reduce the scope of HTTPServer making it easier to test, and split into separate packages. --- agent/agent.go | 21 ++++++++-------- agent/agent_endpoint_test.go | 3 ++- agent/http.go | 3 ++- agent/http_oss_test.go | 4 +-- agent/http_test.go | 47 ++++++++++++++++++------------------ agent/testagent.go | 2 +- agent/ui_endpoint_test.go | 2 +- 7 files changed, 42 insertions(+), 40 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 9e5666a0f7..14d3502df9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { l = tls.NewListener(l, tlscfg) } + httpServer := &http.Server{ + Addr: l.Addr().String(), + TLSConfig: tlscfg, + } srv := &HTTPServer{ - Server: &http.Server{ - Addr: l.Addr().String(), - TLSConfig: tlscfg, - }, + Server: httpServer, ln: l, agent: a, denylist: NewDenylist(a.config.HTTPBlockEndpoints), proto: proto, } - srv.Server.Handler = srv.handler(a.config.EnableDebug) + httpServer.Handler = srv.handler(a.config.EnableDebug) // Load the connlimit helper into the server connLimitFn := a.httpConnLimiter.HTTPConnStateFunc() if proto == "https" { // Enforce TLS handshake timeout - srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { + httpServer.ConnState = func(conn net.Conn, state http.ConnState) { switch state { case http.StateNew: // Set deadline to prevent slow send before TLS handshake or first @@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { // This will enable upgrading connections to HTTP/2 as // part of TLS negotiation. - err = http2.ConfigureServer(srv.Server, nil) + err = http2.ConfigureServer(httpServer, nil) if err != nil { return err } } else { - srv.Server.ConnState = connLimitFn + httpServer.ConnState = connLimitFn } ln = append(ln, l) @@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error { go func() { defer a.wgServers.Done() notif <- srv.ln.Addr() - err := srv.Serve(srv.ln) + err := srv.Server.Serve(srv.ln) if err != nil && err != http.ErrServerClosed { a.logger.Error("error closing server", "error", err) } @@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() { ) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - srv.Shutdown(ctx) + srv.Server.Shutdown(ctx) if ctx.Err() == context.DeadlineExceeded { a.logger.Warn("Timeout stopping server", "protocol", strings.ToUpper(srv.proto), diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 31f2501f8d..c9aef60e3f 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) { req = req.WithContext(cancelCtx) resp := httptest.NewRecorder() - go a.srv.Handler.ServeHTTP(resp, req) + handler := a.srv.handler(true) + go handler.ServeHTTP(resp, req) args := &structs.ServiceDefinition{ Name: "monitor", diff --git a/agent/http.go b/agent/http.go index 7402ef5f95..4cc754dfe1 100644 --- a/agent/http.go +++ b/agent/http.go @@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string { // HTTPServer provides an HTTP api for an agent. type HTTPServer struct { - *http.Server + // TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods + Server *http.Server ln net.Listener agent *Agent denylist *Denylist diff --git a/agent/http_oss_test.go b/agent/http_oss_test.go index 62dafc61a4..8e936d9380 100644 --- a/agent/http_oss_test.go +++ b/agent/http_oss_test.go @@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) { uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) req, _ := http.NewRequest("OPTIONS", uri, nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) allMethods := append([]string{"OPTIONS"}, methods...) if resp.Code != http.StatusOK { @@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) { req, _ := http.NewRequest(method, uri, nil) req.RemoteAddr = "192.168.1.2:5555" resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path) }) diff --git a/agent/http_test.go b/agent/http_test.go index e64edf0ef2..b90715e298 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { } } -func TestHTTPServer_H2(t *testing.T) { +func TestHTTPServer_HTTP2(t *testing.T) { t.Parallel() // Fire up an agent with TLS enabled. @@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) { if err := http2.ConfigureTransport(transport); err != nil { t.Fatalf("err: %v", err) } - hc := &http.Client{ - Transport: transport, - } + httpClient := &http.Client{Transport: transport} // Hook a handler that echoes back the protocol. handler := func(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(http.StatusOK) fmt.Fprint(resp, req.Proto) } - w, ok := a.srv.Handler.(*wrappedMux) + + w, ok := a.srv.Server.Handler.(*wrappedMux) if !ok { t.Fatalf("handler is not expected type") } @@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) { // Call it and make sure we see HTTP/2. url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) - resp, err := hc.Get(url) + resp, err := httpClient.Get(url) if err != nil { t.Fatalf("err: %v", err) } @@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) { cfg := &api.Config{ Address: a.srv.ln.Addr().String(), Scheme: "https", - HttpClient: hc, + HttpClient: httpClient, } client, err := api.NewClient(cfg) if err != nil { @@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) { t.Fatal(err) } resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) if got, want := resp.Code, http.StatusBadRequest; got != want { t.Fatalf("bad response code got %d want %d", got, want) } @@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) { t.Fatal(err) } resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) // Key doesn't actually exist so we should get 404 if got, want := resp.Code, http.StatusNotFound; got != want { t.Fatalf("bad response code got %d want %d", got, want) @@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) { // negotiation, but since this call doesn't go through a real // transport, the header has to be set manually req.Header["Accept-Encoding"] = []string{"gzip"} - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "", resp.Header().Get("Content-Encoding")) resp = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/v1/kv/long", nil) req.Header["Accept-Encoding"] = []string{"gzip"} - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "gzip", resp.Header().Get("Content-Encoding")) } @@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) { } } -func TestPProfHandlers_EnableDebug(t *testing.T) { +func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) { t.Parallel() - require := require.New(t) - a := NewTestAgent(t, "enable_debug = true") + a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) - a.srv.Handler.ServeHTTP(resp, req) + httpServer := &HTTPServer{agent: a.Agent} + httpServer.handler(true).ServeHTTP(resp, req) - require.Equal(http.StatusOK, resp.Code) + require.Equal(t, http.StatusOK, resp.Code) } -func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { +func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) { t.Parallel() - require := require.New(t) - a := NewTestAgent(t, "enable_debug = false") + a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil) - a.srv.Handler.ServeHTTP(resp, req) + httpServer := &HTTPServer{agent: a.Agent} + httpServer.handler(false).ServeHTTP(resp, req) - require.Equal(http.StatusUnauthorized, resp.Code) + require.Equal(t, http.StatusUnauthorized, resp.Code) } -func TestPProfHandlers_ACLs(t *testing.T) { +func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) { t.Parallel() assert := assert.New(t) dc1 := "dc1" @@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) { t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) assert.Equal(c.code, resp.Code) }) } @@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) { req, _ := http.NewRequest("GET", "/ui/", nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("should handle ui") } diff --git a/agent/testagent.go b/agent/testagent.go index 635ae6e581..3ddcc2de6f 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string { if a.srv == nil { return "" } - return a.srv.Addr + return a.srv.Server.Addr } func (a *TestAgent) SegmentAddr(name string) string { diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index b139ba7240..4694e392b8 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) { // Register node req, _ := http.NewRequest("GET", "/ui/my-file", nil) req.URL.Scheme = "http" - req.URL.Host = a.srv.Addr + req.URL.Host = a.srv.Server.Addr // Make the request client := cleanhttp.DefaultClient() From df4088291c04df7171e6c2d1e135c73f2ea0e995 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 2 Jul 2020 17:51:25 -0400 Subject: [PATCH 2/2] agent/http: Update TestSetupHTTPServer_HTTP2 To remove the need to store the http.Server. This will allow us to remove the http.Server field from the HTTPServer struct. --- agent/agent.go | 51 +++++++++++++++++++++++++--------------------- agent/http_test.go | 31 ++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 14d3502df9..521f4f86ef 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1168,29 +1168,7 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { connLimitFn := a.httpConnLimiter.HTTPConnStateFunc() if proto == "https" { - // Enforce TLS handshake timeout - httpServer.ConnState = func(conn net.Conn, state http.ConnState) { - switch state { - case http.StateNew: - // Set deadline to prevent slow send before TLS handshake or first - // byte of request. - conn.SetReadDeadline(time.Now().Add(a.config.HTTPSHandshakeTimeout)) - case http.StateActive: - // Clear read deadline. We should maybe set read timeouts more - // generally but that's a bigger task as some HTTP endpoints may - // stream large requests and responses (e.g. snapshot) so we can't - // set sensible blanket timeouts here. - conn.SetReadDeadline(time.Time{}) - } - // Pass through to conn limit. This is OK because we didn't change - // state (i.e. Close conn). - connLimitFn(conn, state) - } - - // This will enable upgrading connections to HTTP/2 as - // part of TLS negotiation. - err = http2.ConfigureServer(httpServer, nil) - if err != nil { + if err := setupHTTPS(httpServer, connLimitFn, a.config.HTTPSHandshakeTimeout); err != nil { return err } } else { @@ -1218,6 +1196,33 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { return servers, nil } +// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout +// to the http.Server. +func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error { + // Enforce TLS handshake timeout + server.ConnState = func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateNew: + // Set deadline to prevent slow send before TLS handshake or first + // byte of request. + conn.SetReadDeadline(time.Now().Add(timeout)) + case http.StateActive: + // Clear read deadline. We should maybe set read timeouts more + // generally but that's a bigger task as some HTTP endpoints may + // stream large requests and responses (e.g. snapshot) so we can't + // set sensible blanket timeouts here. + conn.SetReadDeadline(time.Time{}) + } + // Pass through to conn limit. This is OK because we didn't change + // state (i.e. Close conn). + connState(conn, state) + } + + // This will enable upgrading connections to HTTP/2 as + // part of TLS negotiation. + return http2.ConfigureServer(server, nil) +} + // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used so dead TCP connections eventually go away. type tcpKeepAliveListener struct { diff --git a/agent/http_test.go b/agent/http_test.go index b90715e298..9136c1eddd 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -3,6 +3,7 @@ package agent import ( "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -129,7 +130,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { } } -func TestHTTPServer_HTTP2(t *testing.T) { +func TestSetupHTTPServer_HTTP2(t *testing.T) { t.Parallel() // Fire up an agent with TLS enabled. @@ -169,14 +170,28 @@ func TestHTTPServer_HTTP2(t *testing.T) { fmt.Fprint(resp, req.Proto) } - w, ok := a.srv.Server.Handler.(*wrappedMux) - if !ok { - t.Fatalf("handler is not expected type") - } - w.mux.HandleFunc("/echo", handler) + // Create an httpServer to be configured with setupHTTPS, and add our + // custom handler. + httpServer := &http.Server{} + noopConnState := func(net.Conn, http.ConnState) {} + err = setupHTTPS(httpServer, noopConnState, time.Second) + require.NoError(t, err) + + srvHandler := a.srv.handler(true) + mux, ok := srvHandler.(*wrappedMux) + require.True(t, ok, "expected a *wrappedMux, got %T", handler) + mux.mux.HandleFunc("/echo", handler) + httpServer.Handler = mux + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tlsListener := tls.NewListener(listener, a.tlsConfigurator.IncomingHTTPSConfig()) + + go httpServer.Serve(tlsListener) + defer httpServer.Shutdown(context.Background()) // Call it and make sure we see HTTP/2. - url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) + url := fmt.Sprintf("https://%s/echo", listener.Addr().String()) resp, err := httpClient.Get(url) if err != nil { t.Fatalf("err: %v", err) @@ -194,7 +209,7 @@ func TestHTTPServer_HTTP2(t *testing.T) { // some other endpoint, but configure an API client and make a call // just as a sanity check. cfg := &api.Config{ - Address: a.srv.ln.Addr().String(), + Address: listener.Addr().String(), Scheme: "https", HttpClient: httpClient, }