agent/http: un-embed the HTTPServer

The embedded HTTPServer struct is not used by the large HTTPServer
struct. It is used by tests and the agent. This change is a small first
step in the process of removing that field.

The eventual goal is to reduce the scope of HTTPServer making it easier
to test, and split into separate packages.
pull/8231/head
Daniel Nephin 2020-07-02 16:47:54 -04:00
parent db387eccd6
commit a5e45defb1
7 changed files with 42 additions and 40 deletions

View File

@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
l = tls.NewListener(l, tlscfg) l = tls.NewListener(l, tlscfg)
} }
httpServer := &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
}
srv := &HTTPServer{ srv := &HTTPServer{
Server: &http.Server{ Server: httpServer,
Addr: l.Addr().String(),
TLSConfig: tlscfg,
},
ln: l, ln: l,
agent: a, agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints), denylist: NewDenylist(a.config.HTTPBlockEndpoints),
proto: proto, proto: proto,
} }
srv.Server.Handler = srv.handler(a.config.EnableDebug) httpServer.Handler = srv.handler(a.config.EnableDebug)
// Load the connlimit helper into the server // Load the connlimit helper into the server
connLimitFn := a.httpConnLimiter.HTTPConnStateFunc() connLimitFn := a.httpConnLimiter.HTTPConnStateFunc()
if proto == "https" { if proto == "https" {
// Enforce TLS handshake timeout // Enforce TLS handshake timeout
srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state { switch state {
case http.StateNew: case http.StateNew:
// Set deadline to prevent slow send before TLS handshake or first // Set deadline to prevent slow send before TLS handshake or first
@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
// This will enable upgrading connections to HTTP/2 as // This will enable upgrading connections to HTTP/2 as
// part of TLS negotiation. // part of TLS negotiation.
err = http2.ConfigureServer(srv.Server, nil) err = http2.ConfigureServer(httpServer, nil)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
srv.Server.ConnState = connLimitFn httpServer.ConnState = connLimitFn
} }
ln = append(ln, l) ln = append(ln, l)
@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error {
go func() { go func() {
defer a.wgServers.Done() defer a.wgServers.Done()
notif <- srv.ln.Addr() notif <- srv.ln.Addr()
err := srv.Serve(srv.ln) err := srv.Server.Serve(srv.ln)
if err != nil && err != http.ErrServerClosed { if err != nil && err != http.ErrServerClosed {
a.logger.Error("error closing server", "error", err) a.logger.Error("error closing server", "error", err)
} }
@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() {
) )
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
srv.Shutdown(ctx) srv.Server.Shutdown(ctx)
if ctx.Err() == context.DeadlineExceeded { if ctx.Err() == context.DeadlineExceeded {
a.logger.Warn("Timeout stopping server", a.logger.Warn("Timeout stopping server",
"protocol", strings.ToUpper(srv.proto), "protocol", strings.ToUpper(srv.proto),

View File

@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) {
req = req.WithContext(cancelCtx) req = req.WithContext(cancelCtx)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
go a.srv.Handler.ServeHTTP(resp, req) handler := a.srv.handler(true)
go handler.ServeHTTP(resp, req)
args := &structs.ServiceDefinition{ args := &structs.ServiceDefinition{
Name: "monitor", Name: "monitor",

View File

@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string {
// HTTPServer provides an HTTP api for an agent. // HTTPServer provides an HTTP api for an agent.
type HTTPServer struct { type HTTPServer struct {
*http.Server // TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
Server *http.Server
ln net.Listener ln net.Listener
agent *Agent agent *Agent
denylist *Denylist denylist *Denylist

View File

@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
req, _ := http.NewRequest("OPTIONS", uri, nil) req, _ := http.NewRequest("OPTIONS", uri, nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
allMethods := append([]string{"OPTIONS"}, methods...) allMethods := append([]string{"OPTIONS"}, methods...)
if resp.Code != http.StatusOK { if resp.Code != http.StatusOK {
@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
req, _ := http.NewRequest(method, uri, nil) req, _ := http.NewRequest(method, uri, nil)
req.RemoteAddr = "192.168.1.2:5555" req.RemoteAddr = "192.168.1.2:5555"
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path) require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
}) })

View File

@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
} }
} }
func TestHTTPServer_H2(t *testing.T) { func TestHTTPServer_HTTP2(t *testing.T) {
t.Parallel() t.Parallel()
// Fire up an agent with TLS enabled. // Fire up an agent with TLS enabled.
@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) {
if err := http2.ConfigureTransport(transport); err != nil { if err := http2.ConfigureTransport(transport); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
hc := &http.Client{ httpClient := &http.Client{Transport: transport}
Transport: transport,
}
// Hook a handler that echoes back the protocol. // Hook a handler that echoes back the protocol.
handler := func(resp http.ResponseWriter, req *http.Request) { handler := func(resp http.ResponseWriter, req *http.Request) {
resp.WriteHeader(http.StatusOK) resp.WriteHeader(http.StatusOK)
fmt.Fprint(resp, req.Proto) fmt.Fprint(resp, req.Proto)
} }
w, ok := a.srv.Handler.(*wrappedMux)
w, ok := a.srv.Server.Handler.(*wrappedMux)
if !ok { if !ok {
t.Fatalf("handler is not expected type") t.Fatalf("handler is not expected type")
} }
@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) {
// Call it and make sure we see HTTP/2. // Call it and make sure we see HTTP/2.
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String())
resp, err := hc.Get(url) resp, err := httpClient.Get(url)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) {
cfg := &api.Config{ cfg := &api.Config{
Address: a.srv.ln.Addr().String(), Address: a.srv.ln.Addr().String(),
Scheme: "https", Scheme: "https",
HttpClient: hc, HttpClient: httpClient,
} }
client, err := api.NewClient(cfg) client, err := api.NewClient(cfg)
if err != nil { if err != nil {
@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
if got, want := resp.Code, http.StatusBadRequest; got != want { if got, want := resp.Code, http.StatusBadRequest; got != want {
t.Fatalf("bad response code got %d want %d", got, want) t.Fatalf("bad response code got %d want %d", got, want)
} }
@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
// Key doesn't actually exist so we should get 404 // Key doesn't actually exist so we should get 404
if got, want := resp.Code, http.StatusNotFound; got != want { if got, want := resp.Code, http.StatusNotFound; got != want {
t.Fatalf("bad response code got %d want %d", got, want) t.Fatalf("bad response code got %d want %d", got, want)
@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) {
// negotiation, but since this call doesn't go through a real // negotiation, but since this call doesn't go through a real
// transport, the header has to be set manually // transport, the header has to be set manually
req.Header["Accept-Encoding"] = []string{"gzip"} req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code) require.Equal(t, 200, resp.Code)
require.Equal(t, "", resp.Header().Get("Content-Encoding")) require.Equal(t, "", resp.Header().Get("Content-Encoding"))
resp = httptest.NewRecorder() resp = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/v1/kv/long", nil) req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
req.Header["Accept-Encoding"] = []string{"gzip"} req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code) require.Equal(t, 200, resp.Code)
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding")) require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
} }
@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) {
} }
} }
func TestPProfHandlers_EnableDebug(t *testing.T) { func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) a := NewTestAgent(t, ``)
a := NewTestAgent(t, "enable_debug = true")
defer a.Shutdown() defer a.Shutdown()
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
a.srv.Handler.ServeHTTP(resp, req) httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(true).ServeHTTP(resp, req)
require.Equal(http.StatusOK, resp.Code) require.Equal(t, http.StatusOK, resp.Code)
} }
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) a := NewTestAgent(t, ``)
a := NewTestAgent(t, "enable_debug = false")
defer a.Shutdown() defer a.Shutdown()
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil) req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
a.srv.Handler.ServeHTTP(resp, req) httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(false).ServeHTTP(resp, req)
require.Equal(http.StatusUnauthorized, resp.Code) require.Equal(t, http.StatusUnauthorized, resp.Code)
} }
func TestPProfHandlers_ACLs(t *testing.T) { func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t) assert := assert.New(t)
dc1 := "dc1" dc1 := "dc1"
@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) {
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) { t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
assert.Equal(c.code, resp.Code) assert.Equal(c.code, resp.Code)
}) })
} }
@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) {
req, _ := http.NewRequest("GET", "/ui/", nil) req, _ := http.NewRequest("GET", "/ui/", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
if resp.Code != 200 { if resp.Code != 200 {
t.Fatalf("should handle ui") t.Fatalf("should handle ui")
} }

View File

@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string {
if a.srv == nil { if a.srv == nil {
return "" return ""
} }
return a.srv.Addr return a.srv.Server.Addr
} }
func (a *TestAgent) SegmentAddr(name string) string { func (a *TestAgent) SegmentAddr(name string) string {

View File

@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) {
// Register node // Register node
req, _ := http.NewRequest("GET", "/ui/my-file", nil) req, _ := http.NewRequest("GET", "/ui/my-file", nil)
req.URL.Scheme = "http" req.URL.Scheme = "http"
req.URL.Host = a.srv.Addr req.URL.Host = a.srv.Server.Addr
// Make the request // Make the request
client := cleanhttp.DefaultClient() client := cleanhttp.DefaultClient()