diff --git a/agent/agent.go b/agent/agent.go index 14d3502df9..521f4f86ef 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1168,29 +1168,7 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { connLimitFn := a.httpConnLimiter.HTTPConnStateFunc() if proto == "https" { - // Enforce TLS handshake timeout - httpServer.ConnState = func(conn net.Conn, state http.ConnState) { - switch state { - case http.StateNew: - // Set deadline to prevent slow send before TLS handshake or first - // byte of request. - conn.SetReadDeadline(time.Now().Add(a.config.HTTPSHandshakeTimeout)) - case http.StateActive: - // Clear read deadline. We should maybe set read timeouts more - // generally but that's a bigger task as some HTTP endpoints may - // stream large requests and responses (e.g. snapshot) so we can't - // set sensible blanket timeouts here. - conn.SetReadDeadline(time.Time{}) - } - // Pass through to conn limit. This is OK because we didn't change - // state (i.e. Close conn). - connLimitFn(conn, state) - } - - // This will enable upgrading connections to HTTP/2 as - // part of TLS negotiation. - err = http2.ConfigureServer(httpServer, nil) - if err != nil { + if err := setupHTTPS(httpServer, connLimitFn, a.config.HTTPSHandshakeTimeout); err != nil { return err } } else { @@ -1218,6 +1196,33 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { return servers, nil } +// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout +// to the http.Server. +func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error { + // Enforce TLS handshake timeout + server.ConnState = func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateNew: + // Set deadline to prevent slow send before TLS handshake or first + // byte of request. + conn.SetReadDeadline(time.Now().Add(timeout)) + case http.StateActive: + // Clear read deadline. We should maybe set read timeouts more + // generally but that's a bigger task as some HTTP endpoints may + // stream large requests and responses (e.g. snapshot) so we can't + // set sensible blanket timeouts here. + conn.SetReadDeadline(time.Time{}) + } + // Pass through to conn limit. This is OK because we didn't change + // state (i.e. Close conn). + connState(conn, state) + } + + // This will enable upgrading connections to HTTP/2 as + // part of TLS negotiation. + return http2.ConfigureServer(server, nil) +} + // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used so dead TCP connections eventually go away. type tcpKeepAliveListener struct { diff --git a/agent/http_test.go b/agent/http_test.go index b90715e298..9136c1eddd 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -3,6 +3,7 @@ package agent import ( "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -129,7 +130,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { } } -func TestHTTPServer_HTTP2(t *testing.T) { +func TestSetupHTTPServer_HTTP2(t *testing.T) { t.Parallel() // Fire up an agent with TLS enabled. @@ -169,14 +170,28 @@ func TestHTTPServer_HTTP2(t *testing.T) { fmt.Fprint(resp, req.Proto) } - w, ok := a.srv.Server.Handler.(*wrappedMux) - if !ok { - t.Fatalf("handler is not expected type") - } - w.mux.HandleFunc("/echo", handler) + // Create an httpServer to be configured with setupHTTPS, and add our + // custom handler. + httpServer := &http.Server{} + noopConnState := func(net.Conn, http.ConnState) {} + err = setupHTTPS(httpServer, noopConnState, time.Second) + require.NoError(t, err) + + srvHandler := a.srv.handler(true) + mux, ok := srvHandler.(*wrappedMux) + require.True(t, ok, "expected a *wrappedMux, got %T", handler) + mux.mux.HandleFunc("/echo", handler) + httpServer.Handler = mux + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tlsListener := tls.NewListener(listener, a.tlsConfigurator.IncomingHTTPSConfig()) + + go httpServer.Serve(tlsListener) + defer httpServer.Shutdown(context.Background()) // Call it and make sure we see HTTP/2. - url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) + url := fmt.Sprintf("https://%s/echo", listener.Addr().String()) resp, err := httpClient.Get(url) if err != nil { t.Fatalf("err: %v", err) @@ -194,7 +209,7 @@ func TestHTTPServer_HTTP2(t *testing.T) { // some other endpoint, but configure an API client and make a call // just as a sanity check. cfg := &api.Config{ - Address: a.srv.ln.Addr().String(), + Address: listener.Addr().String(), Scheme: "https", HttpClient: httpClient, }