From f2732f18a775aec8e0d66816e9f72ee46324bbf0 Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Wed, 17 Sep 2014 11:13:17 -0400 Subject: [PATCH] Match any Connection header that contains the Upgrade token for websockets --- pkg/apiserver/watch.go | 10 +++++++- pkg/apiserver/watch_test.go | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 29c2941365..8421df3c69 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -20,7 +20,9 @@ import ( "encoding/json" "net/http" "net/url" + "regexp" "strconv" + "strings" "code.google.com/p/go.net/websocket" "github.com/GoogleCloudPlatform/kubernetes/pkg/api" @@ -52,6 +54,12 @@ func getWatchParams(query url.Values) (label, field labels.Selector, resourceVer return label, field, resourceVersion } +var connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)") + +func isWebsocketRequest(req *http.Request) bool { + return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) && strings.ToLower(req.Header.Get("Upgrade")) == "websocket" +} + // ServeHTTP processes watch requests. func (h *WatchHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { parts := splitPath(req.URL.Path) @@ -75,7 +83,7 @@ func (h *WatchHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // TODO: This is one watch per connection. We want to multiplex, so that // multiple watches of the same thing don't create two watches downstream. watchServer := &WatchServer{watching, h.codec} - if req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" { + if isWebsocketRequest(req) { websocket.Handler(watchServer.HandleWS).ServeHTTP(httplog.Unlogged(w), req) } else { watchServer.ServeHTTP(w, req) diff --git a/pkg/apiserver/watch_test.go b/pkg/apiserver/watch_test.go index 51d5228554..40039418c4 100644 --- a/pkg/apiserver/watch_test.go +++ b/pkg/apiserver/watch_test.go @@ -200,3 +200,51 @@ func TestWatchParamParsing(t *testing.T) { } } } + +func TestWatchProtocolSelection(t *testing.T) { + simpleStorage := &SimpleRESTStorage{} + handler := Handle(map[string]RESTStorage{ + "foo": simpleStorage, + }, codec, "/prefix/version") + server := httptest.NewServer(handler) + client := http.Client{} + + dest, _ := url.Parse(server.URL) + dest.Path = "/prefix/version/watch/foo" + dest.RawQuery = "" + + table := []struct { + isWebsocket bool + connHeader string + }{ + {true, "Upgrade"}, + {true, "keep-alive, Upgrade"}, + {true, "upgrade"}, + {false, "keep-alive"}, + } + + for _, item := range table { + request, err := http.NewRequest("GET", dest.String(), nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Connection", item.connHeader) + request.Header.Set("Upgrade", "websocket") + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // The requests recognized as websocket requests based on connection + // and upgrade headers will not also have the necessary Sec-Websocket-* + // headers so it is expected to throw a 400 + if item.isWebsocket && response.StatusCode != http.StatusBadRequest { + t.Errorf("Unexpected response %#v", response) + } + + if !item.isWebsocket && response.StatusCode != http.StatusOK { + t.Errorf("Unexpected response %#v", response) + } + } +}