mirror of https://github.com/k3s-io/k3s
Merge pull request #9159 from csrwng/remove_cors_headers
Remove CORS headers from pod proxy responsespull/6/head
commit
bc59e69ff0
|
@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
|
|||
suffix += "/"
|
||||
}
|
||||
pathPrepend := strings.TrimSuffix(url.Path, suffix)
|
||||
return &proxy.Transport{
|
||||
internalTransport := &proxy.Transport{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
PathPrepend: pathPrepend,
|
||||
}
|
||||
return &corsRemovingTransport{
|
||||
RoundTripper: internalTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
|
||||
// from the internal response.
|
||||
type corsRemovingTransport struct {
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := p.RoundTripper.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removeCORSHeaders(resp)
|
||||
return resp, nil
|
||||
|
||||
}
|
||||
|
||||
// removeCORSHeaders strip CORS headers sent from the backend
|
||||
// This should be called on all responses before returning
|
||||
func removeCORSHeaders(resp *http.Response) {
|
||||
resp.Header.Del("Access-Control-Allow-Credentials")
|
||||
resp.Header.Del("Access-Control-Allow-Headers")
|
||||
resp.Header.Del("Access-Control-Allow-Methods")
|
||||
resp.Header.Del("Access-Control-Allow-Origin")
|
||||
}
|
||||
|
|
|
@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
|
|||
return
|
||||
}
|
||||
|
||||
for k, v := range s.responseHeader {
|
||||
w.Header().Add(k, v)
|
||||
if s.responseHeader != nil {
|
||||
for k, v := range s.responseHeader {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.Write([]byte(s.responseBody))
|
||||
}
|
||||
|
@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
|
|||
}
|
||||
}
|
||||
|
||||
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) {
|
||||
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
|
||||
for k, v := range expected {
|
||||
actualValue, ok := actual[k]
|
||||
if !ok {
|
||||
|
@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
|
|||
name, k, actualValue, v)
|
||||
}
|
||||
}
|
||||
if notExpected == nil {
|
||||
return
|
||||
}
|
||||
for _, h := range notExpected {
|
||||
if _, present := actual[h]; present {
|
||||
t.Errorf("%s: unexpected header: %s", name, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
requestPath string
|
||||
expectedPath string
|
||||
requestBody string
|
||||
requestParams map[string]string
|
||||
requestHeader map[string]string
|
||||
name string
|
||||
method string
|
||||
requestPath string
|
||||
expectedPath string
|
||||
requestBody string
|
||||
requestParams map[string]string
|
||||
requestHeader map[string]string
|
||||
responseHeader map[string]string
|
||||
expectedRespHeader map[string]string
|
||||
notExpectedRespHeader []string
|
||||
}{
|
||||
{
|
||||
name: "root path, simple get",
|
||||
|
@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
|
|||
requestPath: "",
|
||||
expectedPath: "/",
|
||||
},
|
||||
{
|
||||
name: "remove CORS headers",
|
||||
method: "GET",
|
||||
requestPath: "/some/path",
|
||||
expectedPath: "/some/path",
|
||||
responseHeader: map[string]string{
|
||||
"Header1": "value1",
|
||||
"Access-Control-Allow-Origin": "some.server",
|
||||
"Access-Control-Allow-Methods": "GET"},
|
||||
expectedRespHeader: map[string]string{
|
||||
"Header1": "value1",
|
||||
},
|
||||
notExpectedRespHeader: []string{
|
||||
"Access-Control-Allow-Origin",
|
||||
"Access-Control-Allow-Methods",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
func() {
|
||||
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
|
||||
backendResponseHeader := test.responseHeader
|
||||
// Test a simple header if not specified in the test
|
||||
if backendResponseHeader == nil && test.expectedRespHeader == nil {
|
||||
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
|
||||
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
|
||||
}
|
||||
backendHandler := &SimpleBackendHandler{
|
||||
responseBody: backendResponse,
|
||||
responseHeader: map[string]string{"Content-Type": "text/html"},
|
||||
responseHeader: backendResponseHeader,
|
||||
}
|
||||
backendServer := httptest.NewServer(backendHandler)
|
||||
defer backendServer.Close()
|
||||
|
@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
|
|||
|
||||
// Headers
|
||||
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
|
||||
test.requestHeader)
|
||||
test.requestHeader, nil)
|
||||
|
||||
// Validate proxy response
|
||||
|
||||
// Response Headers
|
||||
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
|
||||
|
||||
// Validate Body
|
||||
responseBody, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
|
@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
|
|||
Location: locURL,
|
||||
}
|
||||
result := h.defaultProxyTransport(URL)
|
||||
transport := result.(*proxy.Transport)
|
||||
transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
|
||||
if transport.Scheme != test.expectedScheme {
|
||||
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue