From c995050ee34cd0c11652007b5411557688667413 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 25 Jul 2016 20:43:47 -0700 Subject: [PATCH] apiserver: fix timeout handler Protect access of the original writer. Panics if anything has been written into the original writer or the writer has been hijacked when the handler times out. --- pkg/apiserver/handlers.go | 69 ++++++++++++++++++++---- pkg/genericapiserver/genericapiserver.go | 8 +-- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index 8d465fcf77..696a3522d7 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -164,6 +164,8 @@ func RecoverPanics(handler http.Handler) http.Handler { }) } +var errConnKilled = fmt.Errorf("kill connection/stream") + // TimeoutHandler returns an http.Handler that runs h with a timeout // determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle // each request, but if a call runs for longer than its time limit, the @@ -188,11 +190,11 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - done := make(chan struct{}, 1) + done := make(chan struct{}) tw := newTimeoutWriter(w) go func() { t.handler.ServeHTTP(tw, r) - done <- struct{}{} + close(done) }() select { case <-done: @@ -228,26 +230,38 @@ func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { type baseTimeoutWriter struct { w http.ResponseWriter - mu sync.Mutex - timedOut bool + mu sync.Mutex + // if the timeout handler has timed out + timedOut bool + // if this timeout writer has written the header wroteHeader bool - hijacked bool + // if this timeout writer has been hijacked + hijacked bool } func (tw *baseTimeoutWriter) Header() http.Header { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return http.Header{} + } + return tw.w.Header() } func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - tw.wroteHeader = true - if tw.hijacked { - return 0, 
http.ErrHijacked - } + if tw.timedOut { return 0, http.ErrHandlerTimeout } + if tw.hijacked { + return 0, http.ErrHijacked + } + + tw.wroteHeader = true return tw.w.Write(p) } @@ -255,6 +269,10 @@ func (tw *baseTimeoutWriter) Flush() { tw.mu.Lock() defer tw.mu.Unlock() + if tw.timedOut { + return + } + if flusher, ok := tw.w.(http.Flusher); ok { flusher.Flush() } @@ -263,9 +281,11 @@ func (tw *baseTimeoutWriter) Flush() { func (tw *baseTimeoutWriter) WriteHeader(code int) { tw.mu.Lock() defer tw.mu.Unlock() + if tw.timedOut || tw.wroteHeader || tw.hijacked { return } + tw.wroteHeader = true tw.w.WriteHeader(code) } @@ -273,6 +293,12 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) { func (tw *baseTimeoutWriter) timeout(msg string) { tw.mu.Lock() defer tw.mu.Unlock() + + tw.timedOut = true + + // The timeout writer has not been used by the inner handler. + // We can safely time out the HTTP request by sending a gateway + // timeout response. if !tw.wroteHeader && !tw.hijacked { tw.w.WriteHeader(http.StatusGatewayTimeout) if msg != "" { @@ -281,17 +307,40 @@ func (tw *baseTimeoutWriter) timeout(msg string) { enc := json.NewEncoder(tw.w) enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0)) } + } else { + // The timeout writer has been used by the inner handler. There is + // no way to time out the HTTP request at this point. We have to shut down + // the connection for HTTP/1 or reset the stream for HTTP/2. + // + // Note from Brad Fitzpatrick: + // if the ServeHTTP goroutine panics, that will do the best possible thing for both + // HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and + // you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP + // connection immediately without a proper 0-byte EOF chunk, so the peer will recognize + // the response as bogus. 
In HTTP/2 the server will just RST_STREAM the stream, leaving + // the TCP connection open, but resetting the stream to the peer so it'll have an error, + // like the HTTP/1 case. + panic(errConnKilled) } - tw.timedOut = true } func (tw *baseTimeoutWriter) closeNotify() <-chan bool { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + done := make(chan bool) + close(done) + return done + } + return tw.w.(http.CloseNotifier).CloseNotify() } func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { tw.mu.Lock() defer tw.mu.Unlock() + if tw.timedOut { return nil, nil, http.ErrHandlerTimeout } diff --git a/pkg/genericapiserver/genericapiserver.go b/pkg/genericapiserver/genericapiserver.go index 3693bbded9..a03b91e1f2 100644 --- a/pkg/genericapiserver/genericapiserver.go +++ b/pkg/genericapiserver/genericapiserver.go @@ -633,10 +633,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { } if secureLocation != "" { - handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout) + handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.Handler), longRunningTimeout) secureServer := &http.Server{ Addr: secureLocation, - Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, apiserver.RecoverPanics(handler)), + Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, handler), MaxHeaderBytes: 1 << 20, TLSConfig: &tls.Config{ // Can't use SSLv3 because of POODLE and BEAST @@ -696,10 +696,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { } } - handler := apiserver.TimeoutHandler(s.InsecureHandler, longRunningTimeout) + handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.InsecureHandler), longRunningTimeout) http := &http.Server{ Addr: insecureLocation, - Handler: apiserver.RecoverPanics(handler), + Handler: handler, MaxHeaderBytes: 1 << 20, }