diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 2d5e31338e..bd4f3946a2 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun suffix += "/" } pathPrepend := strings.TrimSuffix(url.Path, suffix) - return &proxy.Transport{ + internalTransport := &proxy.Transport{ Scheme: scheme, Host: host, PathPrepend: pathPrepend, } + return &corsRemovingTransport{ + RoundTripper: internalTransport, + } +} + +// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers +// from the internal response. +type corsRemovingTransport struct { + http.RoundTripper +} + +func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := p.RoundTripper.RoundTrip(req) + if err != nil { + return nil, err + } + removeCORSHeaders(resp) + return resp, nil + +} + +// removeCORSHeaders strip CORS headers sent from the backend +// This should be called on all responses before returning +func removeCORSHeaders(resp *http.Response) { + resp.Header.Del("Access-Control-Allow-Credentials") + resp.Header.Del("Access-Control-Allow-Headers") + resp.Header.Del("Access-Control-Allow-Methods") + resp.Header.Del("Access-Control-Allow-Origin") } diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 7b4eca07cd..0aadeb3e79 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques return } - for k, v := range s.responseHeader { - w.Header().Add(k, v) + if s.responseHeader != nil { + for k, v := range s.responseHeader { + w.Header().Add(k, v) + } } w.Write([]byte(s.responseBody)) } @@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m } } -func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) { +func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) { for k, v := range expected { actualValue, ok := actual[k] if !ok { @@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map name, k, actualValue, v) } } + if notExpected == nil { + return + } + for _, h := range notExpected { + if _, present := actual[h]; present { + t.Errorf("%s: unexpected header: %s", name, h) + } + } } func TestServeHTTP(t *testing.T) { tests := []struct { - name string - method string - requestPath string - expectedPath string - requestBody string - requestParams map[string]string - requestHeader map[string]string + name string + method string + requestPath string + expectedPath string + requestBody string + requestParams map[string]string + requestHeader map[string]string + responseHeader map[string]string + expectedRespHeader map[string]string + notExpectedRespHeader []string }{ { name: "root path, simple get", @@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) { requestPath: "", expectedPath: "/", }, + { + name: "remove CORS headers", + method: "GET", + requestPath: "/some/path", + expectedPath: "/some/path", + responseHeader: map[string]string{ + "Header1": "value1", + "Access-Control-Allow-Origin": "some.server", + "Access-Control-Allow-Methods": "GET"}, + expectedRespHeader: map[string]string{ + "Header1": "value1", + }, + notExpectedRespHeader: []string{ + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + }, + }, } for _, test := range tests { func() { backendResponse := "Hello" + backendResponseHeader := test.responseHeader + // Test a simple header if not specified in the test + if backendResponseHeader == nil && test.expectedRespHeader == nil { + backendResponseHeader = map[string]string{"Content-Type": "text/html"} + test.expectedRespHeader = map[string]string{"Content-Type": "text/html"} + } backendHandler := &SimpleBackendHandler{ responseBody: backendResponse, - responseHeader: map[string]string{"Content-Type": "text/html"}, + responseHeader: backendResponseHeader, } backendServer := httptest.NewServer(backendHandler) defer backendServer.Close() @@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) { // Headers validateHeaders(t, test.name+" backend request", backendHandler.requestHeader, - test.requestHeader) + test.requestHeader, nil) // Validate proxy response + + // Response Headers + validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader) + // Validate Body responseBody, err := ioutil.ReadAll(res.Body) if err != nil { @@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) { Location: locURL, } result := h.defaultProxyTransport(URL) - transport := result.(*proxy.Transport) + transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport) if transport.Scheme != test.expectedScheme { t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme) }