From 8723eece49efd00af1089a6fd0a6f7ec0d1c65cc Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Tue, 2 Sep 2014 12:28:24 -0400 Subject: [PATCH 1/6] Add option to enable a simple CORS implementation for the api server --- cmd/apiserver/apiserver.go | 3 +- cmd/integration/integration.go | 2 +- hack/local-up-cluster.sh | 4 +- pkg/apiserver/apiserver.go | 36 ++---------- pkg/apiserver/apiserver_test.go | 67 ++++++++++++++++------ pkg/apiserver/handlers.go | 92 +++++++++++++++++++++++++++++++ pkg/apiserver/minionproxy_test.go | 2 +- pkg/apiserver/operation_test.go | 4 +- pkg/apiserver/redirect_test.go | 2 +- pkg/apiserver/watch_test.go | 6 +- test/integration/client_test.go | 2 +- 11 files changed, 161 insertions(+), 59 deletions(-) create mode 100644 pkg/apiserver/handlers.go diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index fd66438f8a..bebc3f9169 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -39,6 +39,7 @@ var ( port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") + enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") @@ -134,7 +135,7 @@ func main() { storage, codec := m.API_v1beta1() s := &http.Server{ Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))), - Handler: apiserver.Handle(storage, codec, *apiPrefix), + Handler: apiserver.Handle(storage, codec, *apiPrefix, *enableCORS), ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, diff --git a/cmd/integration/integration.go b/cmd/integration/integration.go index 14fe09fdd3..43c17dd7da 100644 --- a/cmd/integration/integration.go +++ b/cmd/integration/integration.go @@ -116,7 +116,7 @@ func startComponents(manifestURL string) (apiServerURL string) { PodInfoGetter: fakePodInfoGetter{}, }) storage, codec := m.API_v1beta1() - handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1") + handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1", false) // Scheduler scheduler.New((&factory.ConfigFactory{cl}).Create()).Run() diff --git a/hack/local-up-cluster.sh b/hack/local-up-cluster.sh index 65a794b1b5..ac564515c9 100755 --- a/hack/local-up-cluster.sh +++ b/hack/local-up-cluster.sh @@ -39,6 +39,7 @@ set +e API_PORT=${API_PORT:-8080} API_HOST=${API_HOST:-127.0.0.1} +API_ENABLE_CORS=${API_ENABLE_CORS:-false} KUBELET_PORT=${KUBELET_PORT:-10250} GO_OUT=$(dirname $0)/../_output/go/bin @@ -48,7 +49,8 @@ APISERVER_LOG=/tmp/apiserver.log --address="${API_HOST}" \ --port="${API_PORT}" \ --etcd_servers="http://127.0.0.1:4001" \ - --machines="127.0.0.1" >"${APISERVER_LOG}" 2>&1 & + --machines="127.0.0.1" \ + --enable_cors="${API_ENABLE_CORS}" >"${APISERVER_LOG}" 2>&1 & APISERVER_PID=$! CTLRMGR_LOG=/tmp/controller-manager.log diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index ebc722b861..4d6b57d9d1 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -18,10 +18,8 @@ package apiserver import ( "encoding/json" - "fmt" "io/ioutil" "net/http" - "runtime/debug" "strings" "time" @@ -57,8 +55,11 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string) mux := http.NewServeMux() group.InstallREST(mux, prefix) InstallSupport(mux) - - return &defaultAPIServer{RecoverPanics(mux), group} + handler := RecoverPanics(mux) + if enableCORS { + handler = CORS(handler, []string{".*"}, nil, nil, "true") + } + return &defaultAPIServer{handler, group} } // APIGroup is a http.Handler that exposes multiple RESTStorage objects @@ -116,33 +117,6 @@ func InstallSupport(mux mux) { mux.HandleFunc("/", handleIndex) } -// RecoverPanics wraps an http Handler to recover and log panics. -func RecoverPanics(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - defer func() { - if x := recover(); x != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprint(w, "apis panic. Look in log for details.") - glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack()) - } - }() - defer httplog.NewLogged(req, &w).StacktraceWhen( - httplog.StatusIsNot( - http.StatusOK, - http.StatusAccepted, - http.StatusMovedPermanently, - http.StatusTemporaryRedirect, - http.StatusConflict, - http.StatusNotFound, - StatusUnprocessableEntity, - ), - ).Log() - - // Dispatch to the internal handler - handler.ServeHTTP(w, req) - }) -} - // handleVersion writes the server's version information. func handleVersion(w http.ResponseWriter, req *http.Request) { writeRawJSON(http.StatusOK, version.Get(), w) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index f08dc999f9..e0f8806959 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -191,7 +191,7 @@ func TestNotFound(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &SimpleRESTStorage{}, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} for k, v := range cases { @@ -212,7 +212,7 @@ func TestNotFound(t *testing.T) { } func TestVersion(t *testing.T) { - handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version") + handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -241,7 +241,7 @@ func TestSimpleList(t *testing.T) { storage := map[string]RESTStorage{} simpleStorage := SimpleRESTStorage{} storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -260,7 +260,7 @@ func TestErrorList(t *testing.T) { errors: map[string]error{"list": fmt.Errorf("test Error")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -284,7 +284,7 @@ func TestNonEmptyList(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -319,7 +319,7 @@ func TestGet(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -340,7 +340,7 @@ func TestGetMissing(t *testing.T) { errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -358,7 +358,7 @@ func TestDelete(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -380,7 +380,7 @@ func TestDeleteMissing(t *testing.T) { errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -400,7 +400,7 @@ func TestUpdate(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) item := &Simple{ @@ -430,7 +430,7 @@ func TestUpdateMissing(t *testing.T) { errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version") + handler := Handle(storage, codec, "/prefix/version", false) server := httptest.NewServer(handler) item := &Simple{ @@ -457,7 +457,7 @@ func TestCreate(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -498,7 +498,7 @@ func TestCreateNotFound(t *testing.T) { // See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092. errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")}, }, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -540,7 +540,7 @@ func TestSyncCreate(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -609,7 +609,7 @@ func TestAsyncDelayReturnsError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) handler.(*defaultAPIServer).group.handler.asyncOpWait = time.Millisecond / 2 server := httptest.NewServer(handler) @@ -627,7 +627,7 @@ func TestAsyncCreateError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) @@ -721,7 +721,7 @@ func TestSyncCreateTimeout(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) simple := &Simple{Name: "foo"} @@ -731,3 +731,36 @@ func TestSyncCreateTimeout(t *testing.T) { t.Errorf("Unexpected status %#v", itemOut) } } + +func TestEnableCORS(t *testing.T) { + handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", true) + server := httptest.NewServer(handler) + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", "example.com") + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(response.Header.Get("Access-Control-Allow-Origin"), "example.com") { + t.Errorf("Expected %#v, Got %#v", response.Header.Get("Access-Control-Allow-Origin"), "example.com") + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } +} diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go new file mode 100644 index 0000000000..ac79f4d84d --- /dev/null +++ b/pkg/apiserver/handlers.go @@ -0,0 +1,92 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package apiserver + +import ( + "fmt" + "net/http" + "regexp" + "runtime/debug" + "strings" + + "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" + "github.com/golang/glog" +) + +// RecoverPanics wraps an http Handler to recover and log panics. +func RecoverPanics(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + defer func() { + if x := recover(); x != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, "apis panic. Look in log for details.") + glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack()) + } + }() + defer httplog.NewLogged(req, &w).StacktraceWhen( + httplog.StatusIsNot( + http.StatusOK, + http.StatusAccepted, + http.StatusMovedPermanently, + http.StatusTemporaryRedirect, + http.StatusConflict, + http.StatusNotFound, + StatusUnprocessableEntity, + ), + ).Log() + + // Dispatch to the internal handler + handler.ServeHTTP(w, req) + }) +} + +// Simple CORS implementation that wraps an http Handler +// For a more detailed implementation use https://github.com/martini-contrib/cors +// or implement CORS at your proxy layer +// Pass nil for allowedMethods and allowedHeaders to use the defaults +func CORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + origin := req.Header.Get("Origin") + if origin != "" { + allowed := false + for _, pattern := range allowedOriginPatterns { + allowed, _ = regexp.MatchString(pattern, origin) + } + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + // Set defaults for methods and headers if nothing was passed + if allowedMethods == nil { + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} + } + if allowedHeaders == nil { + allowedHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"} + } + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) + w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) + + // Stop here if its a preflight OPTIONS request + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + } + } + // Dispatch to the next handler + handler.ServeHTTP(w, req) + }) +} diff --git a/pkg/apiserver/minionproxy_test.go b/pkg/apiserver/minionproxy_test.go index 1ceeedf3a6..26c6820dff 100644 --- a/pkg/apiserver/minionproxy_test.go +++ b/pkg/apiserver/minionproxy_test.go @@ -127,7 +127,7 @@ func TestApiServerMinionProxy(t *testing.T) { proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte(req.URL.Path)) })) - server := httptest.NewServer(Handle(nil, nil, "/prefix")) + server := httptest.NewServer(Handle(nil, nil, "/prefix", false)) proxy, _ := url.Parse(proxyServer.URL) resp, err := http.Get(fmt.Sprintf("%s/proxy/minion/%s%s", server.URL, proxy.Host, "/test")) if err != nil { diff --git a/pkg/apiserver/operation_test.go b/pkg/apiserver/operation_test.go index 856fe76438..203a415232 100644 --- a/pkg/apiserver/operation_test.go +++ b/pkg/apiserver/operation_test.go @@ -107,7 +107,7 @@ func TestOperationsList(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -163,7 +163,7 @@ func TestOpGet(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} diff --git a/pkg/apiserver/redirect_test.go b/pkg/apiserver/redirect_test.go index 6425f920d8..5730b75b95 100644 --- a/pkg/apiserver/redirect_test.go +++ b/pkg/apiserver/redirect_test.go @@ -30,7 +30,7 @@ func TestRedirect(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) dontFollow := errors.New("don't follow") diff --git a/pkg/apiserver/watch_test.go b/pkg/apiserver/watch_test.go index 70027c1d44..ff5bbc864c 100644 --- a/pkg/apiserver/watch_test.go +++ b/pkg/apiserver/watch_test.go @@ -44,7 +44,7 @@ func TestWatchWebsocket(t *testing.T) { _ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work. handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) @@ -90,7 +90,7 @@ func TestWatchHTTP(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) client := http.Client{} @@ -147,7 +147,7 @@ func TestWatchParamParsing(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version") + }, codec, "/prefix/version", false) server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) diff --git a/test/integration/client_test.go b/test/integration/client_test.go index 5d539d11e2..f8f3e35645 100644 --- a/test/integration/client_test.go +++ b/test/integration/client_test.go @@ -41,7 +41,7 @@ func TestClient(t *testing.T) { }) storage, codec := m.API_v1beta1() - s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/")) + s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/", false)) client := client.NewOrDie(s.URL, nil) From 8b4ca9c2a7913d2dc6c951cc8ba3a1a00c811e2a Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Wed, 3 Sep 2014 14:33:52 -0400 Subject: [PATCH 2/6] Move CORS handler wrapping into cmd/apiserver and switch config flag to a list of allowed origins --- cmd/apiserver/apiserver.go | 33 ++++++++------ cmd/integration/integration.go | 2 +- hack/local-up-cluster.sh | 5 ++- pkg/apiserver/apiserver.go | 6 +-- pkg/apiserver/apiserver_test.go | 71 ++++++++++++++++++++++--------- pkg/apiserver/handlers.go | 19 +++++++-- pkg/apiserver/minionproxy_test.go | 2 +- pkg/apiserver/operation_test.go | 4 +- pkg/apiserver/redirect_test.go | 2 +- pkg/apiserver/watch_test.go | 6 +-- test/integration/client_test.go | 2 +- 11 files changed, 102 insertions(+), 50 deletions(-) diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index bebc3f9169..96fa3e44d2 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -36,22 +36,25 @@ import ( ) var ( - port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") - address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") - apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") - enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") - cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") - cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") - minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") - minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") - healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") - minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") - etcdServerList, machineList util.StringList + port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") + address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") + apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") + enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") + cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") + cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") + minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") + minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") + healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") + minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") + etcdServerList util.StringList + machineList util.StringList + corsAllowedOriginList util.StringList ) func init() { flag.Var(&etcdServerList, "etcd_servers", "List of etcd servers to watch (http://ip:port), comma separated") flag.Var(&machineList, "machines", "List of machines to schedule onto, comma separated.") + flag.Var(&corsAllowedOriginList, "cors_allowed_origins", "List of allowed origins for CORS, comma separated. An allowed origin can be a regular expression to support subdomain matching. If this list is empty CORS will not be enabled.") } func verifyMinionFlags() { @@ -133,9 +136,15 @@ func main() { }) storage, codec := m.API_v1beta1() + + handler := apiserver.Handle(storage, codec, *apiPrefix) + if len(corsAllowedOriginList) > 0 { + handler = apiserver.CORS(handler, corsAllowedOriginList, nil, nil, "true") + } + s := &http.Server{ Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))), - Handler: apiserver.Handle(storage, codec, *apiPrefix, *enableCORS), + Handler: handler, ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, diff --git a/cmd/integration/integration.go b/cmd/integration/integration.go index 43c17dd7da..14fe09fdd3 100644 --- a/cmd/integration/integration.go +++ b/cmd/integration/integration.go @@ -116,7 +116,7 @@ func startComponents(manifestURL string) (apiServerURL string) { PodInfoGetter: fakePodInfoGetter{}, }) storage, codec := m.API_v1beta1() - handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1", false) + handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1") // Scheduler scheduler.New((&factory.ConfigFactory{cl}).Create()).Run() diff --git a/hack/local-up-cluster.sh b/hack/local-up-cluster.sh index ac564515c9..875da775b6 100755 --- a/hack/local-up-cluster.sh +++ b/hack/local-up-cluster.sh @@ -39,7 +39,8 @@ set +e API_PORT=${API_PORT:-8080} API_HOST=${API_HOST:-127.0.0.1} -API_ENABLE_CORS=${API_ENABLE_CORS:-false} +# By default only allow CORS for requests on localhost +API_CORS_ALLOWED_ORIGINS=${API_CORS_ALLOWED_ORIGINS:-127.0.0.1:.*,localhost:.*} KUBELET_PORT=${KUBELET_PORT:-10250} GO_OUT=$(dirname $0)/../_output/go/bin @@ -50,7 +51,7 @@ APISERVER_LOG=/tmp/apiserver.log --port="${API_PORT}" \ --etcd_servers="http://127.0.0.1:4001" \ --machines="127.0.0.1" \ - --enable_cors="${API_ENABLE_CORS}" >"${APISERVER_LOG}" 2>&1 & + --cors_allowed_origins="${API_CORS_ALLOWED_ORIGINS}" >"${APISERVER_LOG}" 2>&1 & APISERVER_PID=$! CTLRMGR_LOG=/tmp/controller-manager.log diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 4d6b57d9d1..908774cd3c 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -55,11 +55,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string) mux := http.NewServeMux() group.InstallREST(mux, prefix) InstallSupport(mux) - handler := RecoverPanics(mux) - if enableCORS { - handler = CORS(handler, []string{".*"}, nil, nil, "true") - } - return &defaultAPIServer{handler, group} + return &defaultAPIServer{RecoverPanics(mux), group} } // APIGroup is a http.Handler that exposes multiple RESTStorage objects diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index e0f8806959..6f970c2ecc 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -191,7 +191,7 @@ func TestNotFound(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &SimpleRESTStorage{}, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} for k, v := range cases { @@ -212,7 +212,7 @@ func TestNotFound(t *testing.T) { } func TestVersion(t *testing.T) { - handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -241,7 +241,7 @@ func TestSimpleList(t *testing.T) { storage := map[string]RESTStorage{} simpleStorage := SimpleRESTStorage{} storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -260,7 +260,7 @@ func TestErrorList(t *testing.T) { errors: map[string]error{"list": fmt.Errorf("test Error")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -284,7 +284,7 @@ func TestNonEmptyList(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -319,7 +319,7 @@ func TestGet(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -340,7 +340,7 @@ func TestGetMissing(t *testing.T) { errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -358,7 +358,7 @@ func TestDelete(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -380,7 +380,7 @@ func TestDeleteMissing(t *testing.T) { errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -400,7 +400,7 @@ func TestUpdate(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) item := &Simple{ @@ -430,7 +430,7 @@ func TestUpdateMissing(t *testing.T) { errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) item := &Simple{ @@ -457,7 +457,7 @@ func TestCreate(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -498,7 +498,7 @@ func TestCreateNotFound(t *testing.T) { // See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092. errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")}, }, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -540,7 +540,7 @@ func TestSyncCreate(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -609,7 +609,7 @@ func TestAsyncDelayReturnsError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = time.Millisecond / 2 server := httptest.NewServer(handler) @@ -627,7 +627,7 @@ func TestAsyncCreateError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) @@ -721,7 +721,7 @@ func TestSyncCreateTimeout(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) simple := &Simple{Name: "foo"} @@ -732,8 +732,8 @@ func TestSyncCreateTimeout(t *testing.T) { } } -func TestEnableCORS(t *testing.T) { - handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", true) +func TestCORSAllowedOrigin(t *testing.T) { + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") server := httptest.NewServer(handler) client := http.Client{} @@ -764,3 +764,36 @@ func TestEnableCORS(t *testing.T) { t.Errorf("Expected Access-Control-Allow-Methods header to be set") } } + +func TestCORSUnallowedOrigin(t *testing.T) { + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") + server := httptest.NewServer(handler) + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", "not-allowed.com") + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } +} diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index ac79f4d84d..4554e15f25 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/golang/glog" ) @@ -58,13 +59,25 @@ func RecoverPanics(handler http.Handler) http.Handler { // For a more detailed implementation use https://github.com/martini-contrib/cors // or implement CORS at your proxy layer // Pass nil for allowedMethods and allowedHeaders to use the defaults -func CORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { +func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { + // Compile the regular expressions once upfront + allowedOriginRegexps := []*regexp.Regexp{} + for _, allowedOrigin := range allowedOriginPatterns { + allowedOriginRegexp, err := regexp.Compile(allowedOrigin) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin regexp: %v", err) + } + allowedOriginRegexps = append(allowedOriginRegexps, allowedOriginRegexp) + } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") if origin != "" { allowed := false - for _, pattern := range allowedOriginPatterns { - allowed, _ = regexp.MatchString(pattern, origin) + for _, pattern := range allowedOriginRegexps { + if allowed = pattern.MatchString(origin); allowed { + break + } } if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) diff --git a/pkg/apiserver/minionproxy_test.go b/pkg/apiserver/minionproxy_test.go index 26c6820dff..1ceeedf3a6 100644 --- a/pkg/apiserver/minionproxy_test.go +++ b/pkg/apiserver/minionproxy_test.go @@ -127,7 +127,7 @@ func TestApiServerMinionProxy(t *testing.T) { proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte(req.URL.Path)) })) - server := httptest.NewServer(Handle(nil, nil, "/prefix", false)) + server := httptest.NewServer(Handle(nil, nil, "/prefix")) proxy, _ := url.Parse(proxyServer.URL) resp, err := http.Get(fmt.Sprintf("%s/proxy/minion/%s%s", server.URL, proxy.Host, "/test")) if err != nil { diff --git a/pkg/apiserver/operation_test.go b/pkg/apiserver/operation_test.go index 203a415232..856fe76438 100644 --- a/pkg/apiserver/operation_test.go +++ b/pkg/apiserver/operation_test.go @@ -107,7 +107,7 @@ func TestOperationsList(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -163,7 +163,7 @@ func TestOpGet(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} diff --git a/pkg/apiserver/redirect_test.go b/pkg/apiserver/redirect_test.go index 5730b75b95..6425f920d8 100644 --- a/pkg/apiserver/redirect_test.go +++ b/pkg/apiserver/redirect_test.go @@ -30,7 +30,7 @@ func TestRedirect(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dontFollow := errors.New("don't follow") diff --git a/pkg/apiserver/watch_test.go b/pkg/apiserver/watch_test.go index ff5bbc864c..70027c1d44 100644 --- a/pkg/apiserver/watch_test.go +++ b/pkg/apiserver/watch_test.go @@ -44,7 +44,7 @@ func TestWatchWebsocket(t *testing.T) { _ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work. handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) @@ -90,7 +90,7 @@ func TestWatchHTTP(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -147,7 +147,7 @@ func TestWatchParamParsing(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) diff --git a/test/integration/client_test.go b/test/integration/client_test.go index f8f3e35645..5d539d11e2 100644 --- a/test/integration/client_test.go +++ b/test/integration/client_test.go @@ -41,7 +41,7 @@ func TestClient(t *testing.T) { }) storage, codec := m.API_v1beta1() - s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/", false)) + s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/")) client := client.NewOrDie(s.URL, nil) From becf6ca4e7145078331b503644c3320fe0f111ca Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Thu, 4 Sep 2014 13:55:30 -0400 Subject: [PATCH 3/6] Move RecoverPanics to be the top level wrapped handler. Add new method to be sure a logger has been generated instead of assuming one has. Move regexp list compilation into a utility and pass regexp list into CORS. --- cmd/apiserver/apiserver.go | 8 ++++++-- pkg/apiserver/apiserver.go | 2 +- pkg/apiserver/apiserver_test.go | 5 +++-- pkg/apiserver/handlers.go | 17 +++-------------- pkg/apiserver/resthandler.go | 2 +- pkg/apiserver/watch.go | 2 +- pkg/httplog/log.go | 9 +++++++++ pkg/util/util.go | 14 ++++++++++++++ 8 files changed, 38 insertions(+), 21 deletions(-) diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index 96fa3e44d2..43553e8880 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -139,12 +139,16 @@ func main() { handler := apiserver.Handle(storage, codec, *apiPrefix) if len(corsAllowedOriginList) > 0 { - handler = apiserver.CORS(handler, corsAllowedOriginList, nil, nil, "true") + allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin: %v", err) + } + handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") } s := &http.Server{ Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))), - Handler: handler, + Handler: apiserver.RecoverPanics(handler), ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 908774cd3c..2593688efb 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -55,7 +55,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string) mux := http.NewServeMux() group.InstallREST(mux, prefix) InstallSupport(mux) - return &defaultAPIServer{RecoverPanics(mux), group} + return &defaultAPIServer{mux, group} } // APIGroup is a http.Handler that exposes multiple RESTStorage objects diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 6f970c2ecc..19e412bb3a 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -25,6 +25,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "regexp" "strings" "sync" "testing" @@ -733,7 +734,7 @@ func TestSyncCreateTimeout(t *testing.T) { } func TestCORSAllowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") server := httptest.NewServer(handler) client := http.Client{} @@ -766,7 +767,7 @@ func TestCORSAllowedOrigin(t *testing.T) { } func TestCORSUnallowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") server := httptest.NewServer(handler) client := http.Client{} diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index 4554e15f25..a2f70a34f2 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" - "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/golang/glog" ) @@ -59,22 +58,12 @@ func RecoverPanics(handler http.Handler) http.Handler { // For a more detailed implementation use https://github.com/martini-contrib/cors // or implement CORS at your proxy layer // Pass nil for allowedMethods and allowedHeaders to use the defaults -func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { - // Compile the regular expressions once upfront - allowedOriginRegexps := []*regexp.Regexp{} - for _, allowedOrigin := range allowedOriginPatterns { - allowedOriginRegexp, err := regexp.Compile(allowedOrigin) - if err != nil { - glog.Fatalf("Invalid CORS allowed origin regexp: %v", err) - } - allowedOriginRegexps = append(allowedOriginRegexps, allowedOriginRegexp) - } - +func CORS(handler http.Handler, allowedOriginPatterns []*regexp.Regexp, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") if origin != "" { allowed := false - for _, pattern := range allowedOriginRegexps { + for _, pattern := range allowedOriginPatterns { if allowed = pattern.MatchString(origin); allowed { break } @@ -86,7 +75,7 @@ func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMe allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} } if allowedHeaders == nil { - allowedHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"} + allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} } w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) diff --git a/pkg/apiserver/resthandler.go b/pkg/apiserver/resthandler.go index 09251b5bc4..8f5657a5d8 100644 --- a/pkg/apiserver/resthandler.go +++ b/pkg/apiserver/resthandler.go @@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage := h.storage[parts[0]] if storage == nil { - httplog.LogOf(w).Addf("'%v' has no storage object", parts[0]) + httplog.FindOrCreateLogOf(req, &w).Addf("'%v' has no storage object", parts[0]) notFound(w, req) return } diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 7cd047093b..60ff30c489 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -127,7 +127,7 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) { // ServeHTTP serves a series of JSON encoded events via straight HTTP with // Transfer-Encoding: chunked. func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - loggedW := httplog.LogOf(w) + loggedW := httplog.FindOrCreateLogOf(req, &w) w = httplog.Unlogged(w) cn, ok := w.(http.CloseNotifier) diff --git a/pkg/httplog/log.go b/pkg/httplog/log.go index 1f95b504b3..d9d887b9de 100644 --- a/pkg/httplog/log.go +++ b/pkg/httplog/log.go @@ -95,6 +95,15 @@ func LogOf(w http.ResponseWriter) *respLogger { panic("Logger not installed yet!") } +// Returns the existing logger hiding in w. If there is not an existing logger +// then one will be created. +func FindOrCreateLogOf(req *http.Request, w *http.ResponseWriter) *respLogger { + if _, exists := (*w).(*respLogger); !exists { + NewLogged(req, w) + } + return LogOf(*w) +} + // Unlogged returns the original ResponseWriter, or w if it is not our inserted logger. func Unlogged(w http.ResponseWriter) http.ResponseWriter { if rl, ok := w.(*respLogger); ok { diff --git a/pkg/util/util.go b/pkg/util/util.go index 5e9cf463b5..69aaccfc3a 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -19,6 +19,7 @@ package util import ( "encoding/json" "fmt" + "regexp" "runtime" "time" @@ -162,3 +163,16 @@ func StringDiff(a, b string) string { out = append(out, []byte("\n\n")...) return string(out) } + +// Takes a list of strings and compiles them into a list of regular expressions +func CompileRegexps(regexpStrings StringList) ([]*regexp.Regexp, error) { + regexps := []*regexp.Regexp{} + for _, regexpStr := range regexpStrings { + r, err := regexp.Compile(regexpStr) + if err != nil { + return []*regexp.Regexp{}, err + } + regexps = append(regexps, r) + } + return regexps, nil +} From 0cac1c5f79417fad73018431c7eee1739d09dec2 Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Tue, 9 Sep 2014 17:05:18 -0400 Subject: [PATCH 4/6] Switch LogOf from panicking when logger is missing to creating logger with the defaults. Update CORS tests to a table-based test and cover more cases. --- cmd/apiserver/apiserver.go | 3 +- pkg/apiserver/apiserver_test.go | 122 +++++++++++++++++--------------- pkg/apiserver/redirect.go | 4 +- pkg/apiserver/resthandler.go | 10 +-- pkg/apiserver/watch.go | 2 +- pkg/httplog/log.go | 21 +++--- pkg/httplog/log_test.go | 15 ++-- 7 files changed, 85 insertions(+), 92 deletions(-) diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index 43553e8880..1becccec4e 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver" @@ -141,7 +142,7 @@ func main() { if len(corsAllowedOriginList) > 0 { allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList) if err != nil { - glog.Fatalf("Invalid CORS allowed origin: %v", err) + glog.Fatalf("Invalid CORS allowed origin, --cors_allowed_origins flag was set to %v - %v", strings.Join(corsAllowedOriginList, ","), err) } handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 19e412bb3a..59dceff11e 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -25,7 +25,6 @@ import ( "net/http" "net/http/httptest" "reflect" - "regexp" "strings" "sync" "testing" @@ -35,6 +34,7 @@ import ( apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" ) @@ -733,68 +733,72 @@ func TestSyncCreateTimeout(t *testing.T) { } } -func TestCORSAllowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") - server := httptest.NewServer(handler) - client := http.Client{} - - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", "example.com") - - response, err := client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) +func TestCORSAllowedOrigins(t *testing.T) { + table := []struct { + allowedOrigins util.StringList + origin string + allowed bool + }{ + {[]string{}, "example.com", false}, + {[]string{"example.com"}, "example.com", true}, + {[]string{"example.com"}, "not-allowed.com", false}, + {[]string{"not-matching.com", "example.com"}, "example.com", true}, + {[]string{".*"}, "example.com", true}, } - if !reflect.DeepEqual(response.Header.Get("Access-Control-Allow-Origin"), "example.com") { - t.Errorf("Expected %#v, Got %#v", response.Header.Get("Access-Control-Allow-Origin"), "example.com") - } + for _, item := range table { + allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - if response.Header.Get("Access-Control-Allow-Credentials") == "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to be set") - } + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), allowedOriginRegexps, nil, nil, "true") + server := httptest.NewServer(handler) + client := http.Client{} - if response.Header.Get("Access-Control-Allow-Headers") == "" { - t.Errorf("Expected Access-Control-Allow-Headers header to be set") - } + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", item.origin) - if response.Header.Get("Access-Control-Allow-Methods") == "" { - t.Errorf("Expected Access-Control-Allow-Methods header to be set") - } -} - -func TestCORSUnallowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") - server := httptest.NewServer(handler) - client := http.Client{} - - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", "not-allowed.com") - - response, err := client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if response.Header.Get("Access-Control-Allow-Origin") != "" { - t.Errorf("Expected Access-Control-Allow-Origin header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Credentials") != "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Headers") != "" { - t.Errorf("Expected Access-Control-Allow-Headers header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Methods") != "" { - t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if item.allowed { + if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + } } } diff --git a/pkg/apiserver/redirect.go b/pkg/apiserver/redirect.go index 57aff02d3d..d517586540 100644 --- a/pkg/apiserver/redirect.go +++ b/pkg/apiserver/redirect.go @@ -38,14 +38,14 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { id := parts[1] storage, ok := r.storage[resourceName] if !ok { - httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) + httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName) notFound(w, req) return } redirector, ok := storage.(Redirector) if !ok { - httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) + httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName) notFound(w, req) return } diff --git a/pkg/apiserver/resthandler.go b/pkg/apiserver/resthandler.go index 8f5657a5d8..4b6977b9e8 100644 --- a/pkg/apiserver/resthandler.go +++ b/pkg/apiserver/resthandler.go @@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage := h.storage[parts[0]] if storage == nil { - httplog.FindOrCreateLogOf(req, &w).Addf("'%v' has no storage object", parts[0]) + httplog.LogOf(req, w).Addf("'%v' has no storage object", parts[0]) notFound(w, req) return } @@ -114,7 +114,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "DELETE": if len(parts) != 2 { @@ -127,7 +127,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "PUT": if len(parts) != 2 { @@ -151,7 +151,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) default: notFound(w, req) @@ -171,7 +171,7 @@ func (h *RESTHandler) createOperation(out <-chan runtime.Object, sync bool, time // finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an // Operation to receive the result and returning its ID down the writer. -func (h *RESTHandler) finishReq(op *Operation, w http.ResponseWriter) { +func (h *RESTHandler) finishReq(op *Operation, req *http.Request, w http.ResponseWriter) { obj, complete := op.StatusOrResult() if complete { status := http.StatusOK diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 60ff30c489..771682087d 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -127,7 +127,7 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) { // ServeHTTP serves a series of JSON encoded events via straight HTTP with // Transfer-Encoding: chunked. func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - loggedW := httplog.FindOrCreateLogOf(req, &w) + loggedW := httplog.LogOf(req, w) w = httplog.Unlogged(w) cn, ok := w.(http.CloseNotifier) diff --git a/pkg/httplog/log.go b/pkg/httplog/log.go index d9d887b9de..912625b7eb 100644 --- a/pkg/httplog/log.go +++ b/pkg/httplog/log.go @@ -86,22 +86,17 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger { return rl } -// LogOf returns the logger hiding in w. Panics if there isn't such a logger, -// because NewLogged() must have been previously called for the log to work. -func LogOf(w http.ResponseWriter) *respLogger { +// LogOf returns the logger hiding in w. If there is not an existing logger +// then one will be created because NewLogged() must have been previously +// called for the log to work. +func LogOf(req *http.Request, w http.ResponseWriter) *respLogger { + if _, exists := w.(*respLogger); !exists { + NewLogged(req, &w) + } if rl, ok := w.(*respLogger); ok { return rl } - panic("Logger not installed yet!") -} - -// Returns the existing logger hiding in w. If there is not an existing logger -// then one will be created. -func FindOrCreateLogOf(req *http.Request, w *http.ResponseWriter) *respLogger { - if _, exists := (*w).(*respLogger); !exists { - NewLogged(req, w) - } - return LogOf(*w) + panic("Unable to find or create the logger!") } // Unlogged returns the original ResponseWriter, or w if it is not our inserted logger. diff --git a/pkg/httplog/log_test.go b/pkg/httplog/log_test.go index e287c8b11d..0ae8f53cde 100644 --- a/pkg/httplog/log_test.go +++ b/pkg/httplog/log_test.go @@ -91,19 +91,12 @@ func TestLogOf(t *testing.T) { t.Errorf("Unexpected error: %v", err) } handler := func(w http.ResponseWriter, r *http.Request) { - var want *respLogger if makeLogger { - want = NewLogged(req, &w) - } else { - defer func() { - if r := recover(); r == nil { - t.Errorf("Expected LogOf to panic") - } - }() + NewLogged(req, &w) } - got := LogOf(w) - if want != got { - t.Errorf("Expected %v, got %v", want, got) + got := reflect.TypeOf(*LogOf(r, w)).String() + if got != "httplog.respLogger" { + t.Errorf("Expected %v, got %v", "httplog.respLogger", got) } } w := httptest.NewRecorder() From d82cf7dd487730c5d5ff88c49a36c580ede31458 Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Tue, 9 Sep 2014 17:16:07 -0400 Subject: [PATCH 5/6] Rebased and fixed test cases --- pkg/apiserver/apiserver.go | 1 - pkg/apiserver/proxy.go | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 2593688efb..34845c2abc 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -24,7 +24,6 @@ import ( "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/healthz" - "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/golang/glog" diff --git a/pkg/apiserver/proxy.go b/pkg/apiserver/proxy.go index 98330530ec..caeeb0eaec 100644 --- a/pkg/apiserver/proxy.go +++ b/pkg/apiserver/proxy.go @@ -89,14 +89,14 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage, ok := r.storage[resourceName] if !ok { - httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) + httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName) notFound(w, req) return } redirector, ok := storage.(Redirector) if !ok { - httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) + httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName) notFound(w, req) return } From 0f2b8f4f9f68fe621bc8dd7b5df361f934aa9b79 Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Thu, 11 Sep 2014 16:48:06 -0400 Subject: [PATCH 6/6] Create passthroughLogger when LogOf is called and request has no logger --- pkg/apiserver/watch.go | 8 ++++---- pkg/httplog/log.go | 21 +++++++++++++++++---- pkg/httplog/log_test.go | 10 +++++++--- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 771682087d..38b3ace500 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -133,18 +133,18 @@ func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { cn, ok := w.(http.CloseNotifier) if !ok { loggedW.Addf("unable to get CloseNotifier") - http.NotFound(loggedW, req) + http.NotFound(w, req) return } flusher, ok := w.(http.Flusher) if !ok { loggedW.Addf("unable to get Flusher") - http.NotFound(loggedW, req) + http.NotFound(w, req) return } - loggedW.Header().Set("Transfer-Encoding", "chunked") - loggedW.WriteHeader(http.StatusOK) + w.Header().Set("Transfer-Encoding", "chunked") + w.WriteHeader(http.StatusOK) flusher.Flush() encoder := json.NewEncoder(w) diff --git a/pkg/httplog/log.go b/pkg/httplog/log.go index 912625b7eb..c964e3864c 100644 --- a/pkg/httplog/log.go +++ b/pkg/httplog/log.go @@ -40,6 +40,10 @@ func Handler(delegate http.Handler, pred StacktracePred) http.Handler { // StacktracePred returns true if a stacktrace should be logged for this status. type StacktracePred func(httpStatus int) (logStacktrace bool) +type logger interface { + Addf(format string, data ...interface{}) +} + // Add a layer on top of ResponseWriter, so we can track latency and error // message sources. type respLogger struct { @@ -54,6 +58,14 @@ type respLogger struct { logStacktracePred StacktracePred } +// Simple logger that logs immediately when Addf is called +type passthroughLogger struct{} + +// Addf logs info immediately. +func (passthroughLogger) Addf(format string, data ...interface{}) { + glog.Infof(format, data...) +} + // DefaultStacktracePred is the default implementation of StacktracePred. func DefaultStacktracePred(status int) bool { return status < http.StatusOK || status >= http.StatusBadRequest @@ -87,11 +99,12 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger { } // LogOf returns the logger hiding in w. If there is not an existing logger -// then one will be created because NewLogged() must have been previously -// called for the log to work. -func LogOf(req *http.Request, w http.ResponseWriter) *respLogger { +// then a passthroughLogger will be created which will log to stdout immediately +// when Addf is called. +func LogOf(req *http.Request, w http.ResponseWriter) logger { if _, exists := w.(*respLogger); !exists { - NewLogged(req, &w) + pl := &passthroughLogger{} + return pl } if rl, ok := w.(*respLogger); ok { return rl diff --git a/pkg/httplog/log_test.go b/pkg/httplog/log_test.go index 0ae8f53cde..aea79ab545 100644 --- a/pkg/httplog/log_test.go +++ b/pkg/httplog/log_test.go @@ -91,12 +91,16 @@ func TestLogOf(t *testing.T) { t.Errorf("Unexpected error: %v", err) } handler := func(w http.ResponseWriter, r *http.Request) { + var want string if makeLogger { NewLogged(req, &w) + want = "*httplog.respLogger" + } else { + want = "*httplog.passthroughLogger" } - got := reflect.TypeOf(*LogOf(r, w)).String() - if got != "httplog.respLogger" { - t.Errorf("Expected %v, got %v", "httplog.respLogger", got) + got := reflect.TypeOf(LogOf(r, w)).String() + if want != got { + t.Errorf("Expected %v, got %v", want, got) } } w := httptest.NewRecorder()