Remove CORS headers from pod proxy responses

The API server sends its own CORS headers in its response, and if the
proxied pod response also includes its own headers, it confuses clients.
pull/6/head
Cesar Wong 2015-06-02 20:58:06 -04:00
parent 521446503a
commit e8af67c180
2 changed files with 82 additions and 14 deletions

View File

@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
suffix += "/"
}
pathPrepend := strings.TrimSuffix(url.Path, suffix)
return &proxy.Transport{
internalTransport := &proxy.Transport{
Scheme: scheme,
Host: host,
PathPrepend: pathPrepend,
}
return &corsRemovingTransport{
RoundTripper: internalTransport,
}
}
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
// from the internal response.
type corsRemovingTransport struct {
http.RoundTripper
}
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := p.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
removeCORSHeaders(resp)
return resp, nil
}
// removeCORSHeaders strip CORS headers sent from the backend
// This should be called on all responses before returning
func removeCORSHeaders(resp *http.Response) {
resp.Header.Del("Access-Control-Allow-Credentials")
resp.Header.Del("Access-Control-Allow-Headers")
resp.Header.Del("Access-Control-Allow-Methods")
resp.Header.Del("Access-Control-Allow-Origin")
}

View File

@ -51,9 +51,11 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
return
}
if s.responseHeader != nil {
for k, v := range s.responseHeader {
w.Header().Add(k, v)
}
}
w.Write([]byte(s.responseBody))
}
@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
}
}
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) {
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
for k, v := range expected {
actualValue, ok := actual[k]
if !ok {
@ -83,6 +85,14 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
name, k, actualValue, v)
}
}
if notExpected == nil {
return
}
for _, h := range notExpected {
if _, present := actual[h]; present {
t.Errorf("%s: unexpected header: %s", name, h)
}
}
}
func TestServeHTTP(t *testing.T) {
@ -94,6 +104,9 @@ func TestServeHTTP(t *testing.T) {
requestBody string
requestParams map[string]string
requestHeader map[string]string
responseHeader map[string]string
expectedRespHeader map[string]string
notExpectedRespHeader []string
}{
{
name: "root path, simple get",
@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
requestPath: "",
expectedPath: "/",
},
{
name: "remove CORS headers",
method: "GET",
requestPath: "/some/path",
expectedPath: "/some/path",
responseHeader: map[string]string{
"Header1": "value1",
"Access-Control-Allow-Origin": "some.server",
"Access-Control-Allow-Methods": "GET"},
expectedRespHeader: map[string]string{
"Header1": "value1",
},
notExpectedRespHeader: []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
},
},
}
for _, test := range tests {
func() {
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
backendResponseHeader := test.responseHeader
// Test a simple header if not specified in the test
if backendResponseHeader == nil && test.expectedRespHeader == nil {
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
}
backendHandler := &SimpleBackendHandler{
responseBody: backendResponse,
responseHeader: map[string]string{"Content-Type": "text/html"},
responseHeader: backendResponseHeader,
}
backendServer := httptest.NewServer(backendHandler)
defer backendServer.Close()
@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
// Headers
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
test.requestHeader)
test.requestHeader, nil)
// Validate proxy response
// Response Headers
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
// Validate Body
responseBody, err := ioutil.ReadAll(res.Body)
if err != nil {
@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
Location: locURL,
}
result := h.defaultProxyTransport(URL)
transport := result.(*proxy.Transport)
transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
if transport.Scheme != test.expectedScheme {
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
}