diff --git a/cmd/kube-apiserver/app/server.go b/cmd/kube-apiserver/app/server.go index 84ec8c2dab..59e7628e3e 100644 --- a/cmd/kube-apiserver/app/server.go +++ b/cmd/kube-apiserver/app/server.go @@ -415,13 +415,19 @@ func (s *APIServer) Run(_ []string) error { } longRunningRE := regexp.MustCompile(s.LongRunningRequestRE) + longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) { + // TODO unify this with apiserver.MaxInFlightLimit + if longRunningRE.MatchString(req.URL.Path) || req.URL.Query().Get("watch") == "true" { + return nil, "" + } + return time.After(time.Minute), "" + } if secureLocation != "" { + handler := apiserver.TimeoutHandler(m.Handler, longRunningTimeout) secureServer := &http.Server{ Addr: secureLocation, - Handler: apiserver.MaxInFlightLimit(sem, longRunningRE, apiserver.RecoverPanics(m.Handler)), - ReadTimeout: ReadWriteTimeout, - WriteTimeout: ReadWriteTimeout, + Handler: apiserver.MaxInFlightLimit(sem, longRunningRE, apiserver.RecoverPanics(handler)), MaxHeaderBytes: 1 << 20, TLSConfig: &tls.Config{ // Change default from SSLv3 to TLSv1.0 (because of POODLE vulnerability) @@ -466,11 +472,10 @@ func (s *APIServer) Run(_ []string) error { } }() } + handler := apiserver.TimeoutHandler(m.InsecureHandler, longRunningTimeout) http := &http.Server{ Addr: insecureLocation, - Handler: apiserver.RecoverPanics(m.InsecureHandler), - ReadTimeout: ReadWriteTimeout, - WriteTimeout: ReadWriteTimeout, + Handler: apiserver.RecoverPanics(handler), MaxHeaderBytes: 1 << 20, } if secureLocation == "" { diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index 6670f155aa..0252117361 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -17,11 +17,16 @@ limitations under the License. package apiserver import ( + "bufio" + "encoding/json" "fmt" + "net" "net/http" "regexp" "runtime/debug" "strings" + "sync" + "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/api" "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" @@ -135,6 +140,163 @@ func RecoverPanics(handler http.Handler) http.Handler { }) } +// TimeoutHandler returns an http.Handler that runs h with a timeout +// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle +// each request, but if a call runs for longer than its time limit, the +// handler responds with a 503 Service Unavailable error and the message +// provided. (If msg is empty, a suitable default message with be sent.) After +// the handler times out, writes by h to its http.ResponseWriter will return +// http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no +// timeout will be enforced. +func TimeoutHandler(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, msg string)) http.Handler { + return &timeoutHandler{h, timeoutFunc} +} + +type timeoutHandler struct { + handler http.Handler + timeout func(*http.Request) (<-chan time.Time, string) +} + +func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + after, msg := t.timeout(r) + if after == nil { + t.handler.ServeHTTP(w, r) + return + } + + done := make(chan struct{}, 1) + tw := newTimeoutWriter(w) + go func() { + t.handler.ServeHTTP(tw, r) + done <- struct{}{} + }() + select { + case <-done: + return + case <-after: + tw.timeout(msg) + } +} + +type timeoutWriter interface { + http.ResponseWriter + timeout(string) +} + +func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { + base := &baseTimeoutWriter{w: w} + + _, notifiable := w.(http.CloseNotifier) + _, hijackable := w.(http.Hijacker) + + switch { + case notifiable && hijackable: + return &closeHijackTimeoutWriter{base} + case notifiable: + return &closeTimeoutWriter{base} + case hijackable: + return &hijackTimeoutWriter{base} + default: + return base + } +} + +type baseTimeoutWriter struct { + w http.ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool + hijacked bool +} + +func (tw *baseTimeoutWriter) Header() http.Header { + return tw.w.Header() +} + +func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.wroteHeader = true + if tw.hijacked { + return 0, http.ErrHijacked + } + if tw.timedOut { + return 0, http.ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *baseTimeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut || tw.wroteHeader || tw.hijacked { + return + } + tw.wroteHeader = true + tw.w.WriteHeader(code) +} + +func (tw *baseTimeoutWriter) timeout(msg string) { + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader && !tw.hijacked { + tw.w.WriteHeader(http.StatusGatewayTimeout) + if msg != "" { + tw.w.Write([]byte(msg)) + } else { + enc := json.NewEncoder(tw.w) + enc.Encode(errors.NewServerTimeout("", "", 0)) + } + } + tw.timedOut = true +} + +func (tw *baseTimeoutWriter) closeNotify() <-chan bool { + return tw.w.(http.CloseNotifier).CloseNotify() +} + +func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut { + return nil, nil, http.ErrHandlerTimeout + } + conn, rw, err := tw.w.(http.Hijacker).Hijack() + if err == nil { + tw.hijacked = true + } + return conn, rw, err +} + +type closeTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +type hijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *hijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} + +type closeHijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeHijackTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +func (tw *closeHijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} + // TODO: use restful.CrossOriginResourceSharing // Simple CORS implementation that wraps an http Handler // For a more detailed implementation use https://github.com/martini-contrib/cors diff --git a/pkg/apiserver/handlers_test.go b/pkg/apiserver/handlers_test.go index 4151689827..5ebc2a4bbb 100644 --- a/pkg/apiserver/handlers_test.go +++ b/pkg/apiserver/handlers_test.go @@ -17,6 +17,7 @@ limitations under the License. package apiserver import ( + "io/ioutil" "net/http" "net/http/httptest" "reflect" @@ -143,6 +144,62 @@ func TestReadOnly(t *testing.T) { } } +func TestTimeout(t *testing.T) { + sendResponse := make(chan struct{}, 1) + writeErrors := make(chan error, 1) + timeout := make(chan time.Time, 1) + resp := "test response" + timeoutResp := "test timeout" + + ts := httptest.NewServer(TimeoutHandler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + <-sendResponse + _, err := w.Write([]byte(resp)) + writeErrors <- err + }), + func(*http.Request) (<-chan time.Time, string) { + return timeout, timeoutResp + })) + defer ts.Close() + + // No timeouts + sendResponse <- struct{}{} + res, err := http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK) + } + body, _ := ioutil.ReadAll(res.Body) + if string(body) != resp { + t.Errorf("got body %q; expected %q", string(body), resp) + } + if err := <-writeErrors; err != nil { + t.Errorf("got unexpected Write error on first request: %v", err) + } + + // Times out + timeout <- time.Time{} + res, err = http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + body, _ = ioutil.ReadAll(res.Body) + if string(body) != timeoutResp { + t.Errorf("got body %q; expected %q", string(body), timeoutResp) + } + + // Now try to send a response + sendResponse <- struct{}{} + if err := <-writeErrors; err != http.ErrHandlerTimeout { + t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) + } +} + func TestGetAPIRequestInfo(t *testing.T) { successCases := []struct { method string diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index 32d0c3a7dd..4199fd3799 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -65,8 +65,6 @@ func ListenAndServeKubeletServer(host HostInterface, address net.IP, port uint, s := &http.Server{ Addr: net.JoinHostPort(address.String(), strconv.FormatUint(uint64(port), 10)), Handler: &handler, - ReadTimeout: 5 * time.Minute, - WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, } if tlsOptions != nil { @@ -86,8 +84,6 @@ func ListenAndServeKubeletReadOnlyServer(host HostInterface, address net.IP, por server := &http.Server{ Addr: net.JoinHostPort(address.String(), strconv.FormatUint(uint64(port), 10)), Handler: &s, - ReadTimeout: 5 * time.Minute, - WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, } glog.Fatal(server.ListenAndServe())