Merge pull request #9159 from csrwng/remove_cors_headers

Remove CORS headers from pod proxy responses
pull/6/head
Quinton Hoole 2015-06-05 11:40:09 -07:00
commit bc59e69ff0
2 changed files with 82 additions and 14 deletions

View File

@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
suffix += "/"
}
pathPrepend := strings.TrimSuffix(url.Path, suffix)
return &proxy.Transport{
internalTransport := &proxy.Transport{
Scheme: scheme,
Host: host,
PathPrepend: pathPrepend,
}
return &corsRemovingTransport{
RoundTripper: internalTransport,
}
}
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
// from the internal response.
type corsRemovingTransport struct {
http.RoundTripper
}
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := p.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
removeCORSHeaders(resp)
return resp, nil
}
// removeCORSHeaders strip CORS headers sent from the backend
// This should be called on all responses before returning
func removeCORSHeaders(resp *http.Response) {
resp.Header.Del("Access-Control-Allow-Credentials")
resp.Header.Del("Access-Control-Allow-Headers")
resp.Header.Del("Access-Control-Allow-Methods")
resp.Header.Del("Access-Control-Allow-Origin")
}

View File

@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
return
}
for k, v := range s.responseHeader {
w.Header().Add(k, v)
if s.responseHeader != nil {
for k, v := range s.responseHeader {
w.Header().Add(k, v)
}
}
w.Write([]byte(s.responseBody))
}
@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
}
}
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) {
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
for k, v := range expected {
actualValue, ok := actual[k]
if !ok {
@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
name, k, actualValue, v)
}
}
if notExpected == nil {
return
}
for _, h := range notExpected {
if _, present := actual[h]; present {
t.Errorf("%s: unexpected header: %s", name, h)
}
}
}
func TestServeHTTP(t *testing.T) {
tests := []struct {
name string
method string
requestPath string
expectedPath string
requestBody string
requestParams map[string]string
requestHeader map[string]string
name string
method string
requestPath string
expectedPath string
requestBody string
requestParams map[string]string
requestHeader map[string]string
responseHeader map[string]string
expectedRespHeader map[string]string
notExpectedRespHeader []string
}{
{
name: "root path, simple get",
@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
requestPath: "",
expectedPath: "/",
},
{
name: "remove CORS headers",
method: "GET",
requestPath: "/some/path",
expectedPath: "/some/path",
responseHeader: map[string]string{
"Header1": "value1",
"Access-Control-Allow-Origin": "some.server",
"Access-Control-Allow-Methods": "GET"},
expectedRespHeader: map[string]string{
"Header1": "value1",
},
notExpectedRespHeader: []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
},
},
}
for _, test := range tests {
func() {
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
backendResponseHeader := test.responseHeader
// Test a simple header if not specified in the test
if backendResponseHeader == nil && test.expectedRespHeader == nil {
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
}
backendHandler := &SimpleBackendHandler{
responseBody: backendResponse,
responseHeader: map[string]string{"Content-Type": "text/html"},
responseHeader: backendResponseHeader,
}
backendServer := httptest.NewServer(backendHandler)
defer backendServer.Close()
@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
// Headers
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
test.requestHeader)
test.requestHeader, nil)
// Validate proxy response
// Response Headers
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
// Validate Body
responseBody, err := ioutil.ReadAll(res.Body)
if err != nil {
@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
Location: locURL,
}
result := h.defaultProxyTransport(URL)
transport := result.(*proxy.Transport)
transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
if transport.Scheme != test.expectedScheme {
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
}