From e8af67c180c7eeaa199d8f1c61bd436d07083c3f Mon Sep 17 00:00:00 2001 From: Cesar Wong Date: Tue, 2 Jun 2015 20:58:06 -0400 Subject: [PATCH] Remove CORS headers from pod proxy responses The API server sends its own CORS headers in its response, and if the proxied pod response also includes its own headers, it confuses clients. --- pkg/registry/generic/rest/proxy.go | 30 ++++++++++- pkg/registry/generic/rest/proxy_test.go | 66 ++++++++++++++++++++----- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 2d5e31338e..bd4f3946a2 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun suffix += "/" } pathPrepend := strings.TrimSuffix(url.Path, suffix) - return &proxy.Transport{ + internalTransport := &proxy.Transport{ Scheme: scheme, Host: host, PathPrepend: pathPrepend, } + return &corsRemovingTransport{ + RoundTripper: internalTransport, + } +} + +// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers +// from the internal response. +type corsRemovingTransport struct { + http.RoundTripper +} + +func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := p.RoundTripper.RoundTrip(req) + if err != nil { + return nil, err + } + removeCORSHeaders(resp) + return resp, nil + +} + +// removeCORSHeaders strip CORS headers sent from the backend +// This should be called on all responses before returning +func removeCORSHeaders(resp *http.Response) { + resp.Header.Del("Access-Control-Allow-Credentials") + resp.Header.Del("Access-Control-Allow-Headers") + resp.Header.Del("Access-Control-Allow-Methods") + resp.Header.Del("Access-Control-Allow-Origin") } diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 7b4eca07cd..0aadeb3e79 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques return } - for k, v := range s.responseHeader { - w.Header().Add(k, v) + if s.responseHeader != nil { + for k, v := range s.responseHeader { + w.Header().Add(k, v) + } } w.Write([]byte(s.responseBody)) } @@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m } } -func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) { +func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) { for k, v := range expected { actualValue, ok := actual[k] if !ok { @@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map name, k, actualValue, v) } } + if notExpected == nil { + return + } + for _, h := range notExpected { + if _, present := actual[h]; present { + t.Errorf("%s: unexpected header: %s", name, h) + } + } } func TestServeHTTP(t *testing.T) { tests := []struct { - name string - method string - requestPath string - expectedPath string - requestBody string - requestParams map[string]string - requestHeader map[string]string + name string + method string + requestPath string + expectedPath string + requestBody string + requestParams map[string]string + requestHeader map[string]string + responseHeader map[string]string + expectedRespHeader map[string]string + notExpectedRespHeader []string }{ { name: "root path, simple get", @@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) { requestPath: "", expectedPath: "/", }, + { + name: "remove CORS headers", + method: "GET", + requestPath: "/some/path", + expectedPath: "/some/path", + responseHeader: map[string]string{ + "Header1": "value1", + "Access-Control-Allow-Origin": "some.server", + "Access-Control-Allow-Methods": "GET"}, + expectedRespHeader: map[string]string{ + "Header1": "value1", + }, + notExpectedRespHeader: []string{ + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + }, + }, } for _, test := range tests { func() { backendResponse := "Hello" + backendResponseHeader := test.responseHeader + // Test a simple header if not specified in the test + if backendResponseHeader == nil && test.expectedRespHeader == nil { + backendResponseHeader = map[string]string{"Content-Type": "text/html"} + test.expectedRespHeader = map[string]string{"Content-Type": "text/html"} + } backendHandler := &SimpleBackendHandler{ responseBody: backendResponse, - responseHeader: map[string]string{"Content-Type": "text/html"}, + responseHeader: backendResponseHeader, } backendServer := httptest.NewServer(backendHandler) defer backendServer.Close() @@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) { // Headers validateHeaders(t, test.name+" backend request", backendHandler.requestHeader, - test.requestHeader) + test.requestHeader, nil) // Validate proxy response + + // Response Headers + validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader) + // Validate Body responseBody, err := ioutil.ReadAll(res.Body) if err != nil { @@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) { Location: locURL, } result := h.defaultProxyTransport(URL) - transport := result.(*proxy.Transport) + transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport) if transport.Scheme != test.expectedScheme { t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme) }