mirror of https://github.com/hashicorp/consul
agent/http: un-embed the HTTPServer
The embedded HTTPServer struct is not used by the large HTTPServer struct. It is used by tests and the agent. This change is a small first step in the process of removing that field. The eventual goal is to reduce the scope of HTTPServer making it easier to test, and split into separate packages.pull/8231/head
parent
db387eccd6
commit
a5e45defb1
|
@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
||||||
l = tls.NewListener(l, tlscfg)
|
l = tls.NewListener(l, tlscfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: l.Addr().String(),
|
||||||
|
TLSConfig: tlscfg,
|
||||||
|
}
|
||||||
srv := &HTTPServer{
|
srv := &HTTPServer{
|
||||||
Server: &http.Server{
|
Server: httpServer,
|
||||||
Addr: l.Addr().String(),
|
|
||||||
TLSConfig: tlscfg,
|
|
||||||
},
|
|
||||||
ln: l,
|
ln: l,
|
||||||
agent: a,
|
agent: a,
|
||||||
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
|
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
|
||||||
proto: proto,
|
proto: proto,
|
||||||
}
|
}
|
||||||
srv.Server.Handler = srv.handler(a.config.EnableDebug)
|
httpServer.Handler = srv.handler(a.config.EnableDebug)
|
||||||
|
|
||||||
// Load the connlimit helper into the server
|
// Load the connlimit helper into the server
|
||||||
connLimitFn := a.httpConnLimiter.HTTPConnStateFunc()
|
connLimitFn := a.httpConnLimiter.HTTPConnStateFunc()
|
||||||
|
|
||||||
if proto == "https" {
|
if proto == "https" {
|
||||||
// Enforce TLS handshake timeout
|
// Enforce TLS handshake timeout
|
||||||
srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
|
httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||||
switch state {
|
switch state {
|
||||||
case http.StateNew:
|
case http.StateNew:
|
||||||
// Set deadline to prevent slow send before TLS handshake or first
|
// Set deadline to prevent slow send before TLS handshake or first
|
||||||
|
@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
||||||
|
|
||||||
// This will enable upgrading connections to HTTP/2 as
|
// This will enable upgrading connections to HTTP/2 as
|
||||||
// part of TLS negotiation.
|
// part of TLS negotiation.
|
||||||
err = http2.ConfigureServer(srv.Server, nil)
|
err = http2.ConfigureServer(httpServer, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
srv.Server.ConnState = connLimitFn
|
httpServer.ConnState = connLimitFn
|
||||||
}
|
}
|
||||||
|
|
||||||
ln = append(ln, l)
|
ln = append(ln, l)
|
||||||
|
@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error {
|
||||||
go func() {
|
go func() {
|
||||||
defer a.wgServers.Done()
|
defer a.wgServers.Done()
|
||||||
notif <- srv.ln.Addr()
|
notif <- srv.ln.Addr()
|
||||||
err := srv.Serve(srv.ln)
|
err := srv.Server.Serve(srv.ln)
|
||||||
if err != nil && err != http.ErrServerClosed {
|
if err != nil && err != http.ErrServerClosed {
|
||||||
a.logger.Error("error closing server", "error", err)
|
a.logger.Error("error closing server", "error", err)
|
||||||
}
|
}
|
||||||
|
@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() {
|
||||||
)
|
)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
srv.Shutdown(ctx)
|
srv.Server.Shutdown(ctx)
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
a.logger.Warn("Timeout stopping server",
|
a.logger.Warn("Timeout stopping server",
|
||||||
"protocol", strings.ToUpper(srv.proto),
|
"protocol", strings.ToUpper(srv.proto),
|
||||||
|
|
|
@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) {
|
||||||
req = req.WithContext(cancelCtx)
|
req = req.WithContext(cancelCtx)
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
go a.srv.Handler.ServeHTTP(resp, req)
|
handler := a.srv.handler(true)
|
||||||
|
go handler.ServeHTTP(resp, req)
|
||||||
|
|
||||||
args := &structs.ServiceDefinition{
|
args := &structs.ServiceDefinition{
|
||||||
Name: "monitor",
|
Name: "monitor",
|
||||||
|
|
|
@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string {
|
||||||
|
|
||||||
// HTTPServer provides an HTTP api for an agent.
|
// HTTPServer provides an HTTP api for an agent.
|
||||||
type HTTPServer struct {
|
type HTTPServer struct {
|
||||||
*http.Server
|
// TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
|
||||||
|
Server *http.Server
|
||||||
ln net.Listener
|
ln net.Listener
|
||||||
agent *Agent
|
agent *Agent
|
||||||
denylist *Denylist
|
denylist *Denylist
|
||||||
|
|
|
@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
|
||||||
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
|
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
|
||||||
req, _ := http.NewRequest("OPTIONS", uri, nil)
|
req, _ := http.NewRequest("OPTIONS", uri, nil)
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
allMethods := append([]string{"OPTIONS"}, methods...)
|
allMethods := append([]string{"OPTIONS"}, methods...)
|
||||||
|
|
||||||
if resp.Code != http.StatusOK {
|
if resp.Code != http.StatusOK {
|
||||||
|
@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
|
||||||
req, _ := http.NewRequest(method, uri, nil)
|
req, _ := http.NewRequest(method, uri, nil)
|
||||||
req.RemoteAddr = "192.168.1.2:5555"
|
req.RemoteAddr = "192.168.1.2:5555"
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
|
|
||||||
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
|
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
|
||||||
})
|
})
|
||||||
|
|
|
@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPServer_H2(t *testing.T) {
|
func TestHTTPServer_HTTP2(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Fire up an agent with TLS enabled.
|
// Fire up an agent with TLS enabled.
|
||||||
|
@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) {
|
||||||
if err := http2.ConfigureTransport(transport); err != nil {
|
if err := http2.ConfigureTransport(transport); err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
hc := &http.Client{
|
httpClient := &http.Client{Transport: transport}
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hook a handler that echoes back the protocol.
|
// Hook a handler that echoes back the protocol.
|
||||||
handler := func(resp http.ResponseWriter, req *http.Request) {
|
handler := func(resp http.ResponseWriter, req *http.Request) {
|
||||||
resp.WriteHeader(http.StatusOK)
|
resp.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprint(resp, req.Proto)
|
fmt.Fprint(resp, req.Proto)
|
||||||
}
|
}
|
||||||
w, ok := a.srv.Handler.(*wrappedMux)
|
|
||||||
|
w, ok := a.srv.Server.Handler.(*wrappedMux)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("handler is not expected type")
|
t.Fatalf("handler is not expected type")
|
||||||
}
|
}
|
||||||
|
@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) {
|
||||||
|
|
||||||
// Call it and make sure we see HTTP/2.
|
// Call it and make sure we see HTTP/2.
|
||||||
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String())
|
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String())
|
||||||
resp, err := hc.Get(url)
|
resp, err := httpClient.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) {
|
||||||
cfg := &api.Config{
|
cfg := &api.Config{
|
||||||
Address: a.srv.ln.Addr().String(),
|
Address: a.srv.ln.Addr().String(),
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
HttpClient: hc,
|
HttpClient: httpClient,
|
||||||
}
|
}
|
||||||
client, err := api.NewClient(cfg)
|
client, err := api.NewClient(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
if got, want := resp.Code, http.StatusBadRequest; got != want {
|
if got, want := resp.Code, http.StatusBadRequest; got != want {
|
||||||
t.Fatalf("bad response code got %d want %d", got, want)
|
t.Fatalf("bad response code got %d want %d", got, want)
|
||||||
}
|
}
|
||||||
|
@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
// Key doesn't actually exist so we should get 404
|
// Key doesn't actually exist so we should get 404
|
||||||
if got, want := resp.Code, http.StatusNotFound; got != want {
|
if got, want := resp.Code, http.StatusNotFound; got != want {
|
||||||
t.Fatalf("bad response code got %d want %d", got, want)
|
t.Fatalf("bad response code got %d want %d", got, want)
|
||||||
|
@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) {
|
||||||
// negotiation, but since this call doesn't go through a real
|
// negotiation, but since this call doesn't go through a real
|
||||||
// transport, the header has to be set manually
|
// transport, the header has to be set manually
|
||||||
req.Header["Accept-Encoding"] = []string{"gzip"}
|
req.Header["Accept-Encoding"] = []string{"gzip"}
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
require.Equal(t, 200, resp.Code)
|
require.Equal(t, 200, resp.Code)
|
||||||
require.Equal(t, "", resp.Header().Get("Content-Encoding"))
|
require.Equal(t, "", resp.Header().Get("Content-Encoding"))
|
||||||
|
|
||||||
resp = httptest.NewRecorder()
|
resp = httptest.NewRecorder()
|
||||||
req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
|
req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
|
||||||
req.Header["Accept-Encoding"] = []string{"gzip"}
|
req.Header["Accept-Encoding"] = []string{"gzip"}
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
require.Equal(t, 200, resp.Code)
|
require.Equal(t, 200, resp.Code)
|
||||||
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
|
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
|
||||||
}
|
}
|
||||||
|
@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPProfHandlers_EnableDebug(t *testing.T) {
|
func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
require := require.New(t)
|
a := NewTestAgent(t, ``)
|
||||||
a := NewTestAgent(t, "enable_debug = true")
|
|
||||||
defer a.Shutdown()
|
defer a.Shutdown()
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
|
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
|
||||||
|
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
httpServer := &HTTPServer{agent: a.Agent}
|
||||||
|
httpServer.handler(true).ServeHTTP(resp, req)
|
||||||
|
|
||||||
require.Equal(http.StatusOK, resp.Code)
|
require.Equal(t, http.StatusOK, resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) {
|
func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
require := require.New(t)
|
a := NewTestAgent(t, ``)
|
||||||
a := NewTestAgent(t, "enable_debug = false")
|
|
||||||
defer a.Shutdown()
|
defer a.Shutdown()
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
|
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
|
||||||
|
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
httpServer := &HTTPServer{agent: a.Agent}
|
||||||
|
httpServer.handler(false).ServeHTTP(resp, req)
|
||||||
|
|
||||||
require.Equal(http.StatusUnauthorized, resp.Code)
|
require.Equal(t, http.StatusUnauthorized, resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPProfHandlers_ACLs(t *testing.T) {
|
func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
dc1 := "dc1"
|
dc1 := "dc1"
|
||||||
|
@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) {
|
||||||
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
|
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
|
||||||
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
|
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
assert.Equal(c.code, resp.Code)
|
assert.Equal(c.code, resp.Code)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) {
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "/ui/", nil)
|
req, _ := http.NewRequest("GET", "/ui/", nil)
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
a.srv.Handler.ServeHTTP(resp, req)
|
a.srv.handler(true).ServeHTTP(resp, req)
|
||||||
if resp.Code != 200 {
|
if resp.Code != 200 {
|
||||||
t.Fatalf("should handle ui")
|
t.Fatalf("should handle ui")
|
||||||
}
|
}
|
||||||
|
|
|
@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string {
|
||||||
if a.srv == nil {
|
if a.srv == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return a.srv.Addr
|
return a.srv.Server.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TestAgent) SegmentAddr(name string) string {
|
func (a *TestAgent) SegmentAddr(name string) string {
|
||||||
|
|
|
@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) {
|
||||||
// Register node
|
// Register node
|
||||||
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
|
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
|
||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
req.URL.Host = a.srv.Addr
|
req.URL.Host = a.srv.Server.Addr
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := cleanhttp.DefaultClient()
|
client := cleanhttp.DefaultClient()
|
||||||
|
|
Loading…
Reference in New Issue