diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index fd66438f8a..1becccec4e 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver" @@ -36,21 +37,25 @@ import ( ) var ( - port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") - address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") - apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") - cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") - cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") - minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") - minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") - healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") - minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") - etcdServerList, machineList util.StringList + port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") + address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") + apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") + enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") + cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") + cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") + minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") + minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") + healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") + minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") + etcdServerList util.StringList + machineList util.StringList + corsAllowedOriginList util.StringList ) func init() { flag.Var(&etcdServerList, "etcd_servers", "List of etcd servers to watch (http://ip:port), comma separated") flag.Var(&machineList, "machines", "List of machines to schedule onto, comma separated.") + flag.Var(&corsAllowedOriginList, "cors_allowed_origins", "List of allowed origins for CORS, comma separated. An allowed origin can be a regular expression to support subdomain matching. If this list is empty CORS will not be enabled.") } func verifyMinionFlags() { @@ -132,9 +137,19 @@ func main() { }) storage, codec := m.API_v1beta1() + + handler := apiserver.Handle(storage, codec, *apiPrefix) + if len(corsAllowedOriginList) > 0 { + allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin, --cors_allowed_origins flag was set to %v - %v", strings.Join(corsAllowedOriginList, ","), err) + } + handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") + } + s := &http.Server{ Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))), - Handler: apiserver.Handle(storage, codec, *apiPrefix), + Handler: apiserver.RecoverPanics(handler), ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, diff --git a/hack/local-up-cluster.sh b/hack/local-up-cluster.sh index 65a794b1b5..875da775b6 100755 --- a/hack/local-up-cluster.sh +++ b/hack/local-up-cluster.sh @@ -39,6 +39,8 @@ set +e API_PORT=${API_PORT:-8080} API_HOST=${API_HOST:-127.0.0.1} +# By default only allow CORS for requests on localhost +API_CORS_ALLOWED_ORIGINS=${API_CORS_ALLOWED_ORIGINS:-127.0.0.1:.*,localhost:.*} KUBELET_PORT=${KUBELET_PORT:-10250} GO_OUT=$(dirname $0)/../_output/go/bin @@ -48,7 +50,8 @@ APISERVER_LOG=/tmp/apiserver.log --address="${API_HOST}" \ --port="${API_PORT}" \ --etcd_servers="http://127.0.0.1:4001" \ - --machines="127.0.0.1" >"${APISERVER_LOG}" 2>&1 & + --machines="127.0.0.1" \ + --cors_allowed_origins="${API_CORS_ALLOWED_ORIGINS}" >"${APISERVER_LOG}" 2>&1 & APISERVER_PID=$! CTLRMGR_LOG=/tmp/controller-manager.log diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index ebc722b861..34845c2abc 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -18,15 +18,12 @@ package apiserver import ( "encoding/json" - "fmt" "io/ioutil" "net/http" - "runtime/debug" "strings" "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/healthz" - "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/golang/glog" @@ -57,8 +54,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string) mux := http.NewServeMux() group.InstallREST(mux, prefix) InstallSupport(mux) - - return &defaultAPIServer{RecoverPanics(mux), group} + return &defaultAPIServer{mux, group} } // APIGroup is a http.Handler that exposes multiple RESTStorage objects @@ -116,33 +112,6 @@ func InstallSupport(mux mux) { mux.HandleFunc("/", handleIndex) } -// RecoverPanics wraps an http Handler to recover and log panics. -func RecoverPanics(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - defer func() { - if x := recover(); x != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprint(w, "apis panic. Look in log for details.") - glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack()) - } - }() - defer httplog.NewLogged(req, &w).StacktraceWhen( - httplog.StatusIsNot( - http.StatusOK, - http.StatusAccepted, - http.StatusMovedPermanently, - http.StatusTemporaryRedirect, - http.StatusConflict, - http.StatusNotFound, - StatusUnprocessableEntity, - ), - ).Log() - - // Dispatch to the internal handler - handler.ServeHTTP(w, req) - }) -} - // handleVersion writes the server's version information. func handleVersion(w http.ResponseWriter, req *http.Request) { writeRawJSON(http.StatusOK, version.Get(), w) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index f08dc999f9..59dceff11e 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -34,6 +34,7 @@ import ( apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" ) @@ -731,3 +732,73 @@ func TestSyncCreateTimeout(t *testing.T) { t.Errorf("Unexpected status %#v", itemOut) } } + +func TestCORSAllowedOrigins(t *testing.T) { + table := []struct { + allowedOrigins util.StringList + origin string + allowed bool + }{ + {[]string{}, "example.com", false}, + {[]string{"example.com"}, "example.com", true}, + {[]string{"example.com"}, "not-allowed.com", false}, + {[]string{"not-matching.com", "example.com"}, "example.com", true}, + {[]string{".*"}, "example.com", true}, + } + + for _, item := range table { + allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), allowedOriginRegexps, nil, nil, "true") + server := httptest.NewServer(handler) + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", item.origin) + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if item.allowed { + if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + } + } +} diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go new file mode 100644 index 0000000000..a2f70a34f2 --- /dev/null +++ b/pkg/apiserver/handlers.go @@ -0,0 +1,94 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package apiserver + +import ( + "fmt" + "net/http" + "regexp" + "runtime/debug" + "strings" + + "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" + "github.com/golang/glog" +) + +// RecoverPanics wraps an http Handler to recover and log panics. +func RecoverPanics(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + defer func() { + if x := recover(); x != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, "apis panic. Look in log for details.") + glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack()) + } + }() + defer httplog.NewLogged(req, &w).StacktraceWhen( + httplog.StatusIsNot( + http.StatusOK, + http.StatusAccepted, + http.StatusMovedPermanently, + http.StatusTemporaryRedirect, + http.StatusConflict, + http.StatusNotFound, + StatusUnprocessableEntity, + ), + ).Log() + + // Dispatch to the internal handler + handler.ServeHTTP(w, req) + }) +} + +// Simple CORS implementation that wraps an http Handler +// For a more detailed implementation use https://github.com/martini-contrib/cors +// or implement CORS at your proxy layer +// Pass nil for allowedMethods and allowedHeaders to use the defaults +func CORS(handler http.Handler, allowedOriginPatterns []*regexp.Regexp, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + origin := req.Header.Get("Origin") + if origin != "" { + allowed := false + for _, pattern := range allowedOriginPatterns { + if allowed = pattern.MatchString(origin); allowed { + break + } + } + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + // Set defaults for methods and headers if nothing was passed + if allowedMethods == nil { + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} + } + if allowedHeaders == nil { + allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} + } + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) + w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) + + // Stop here if its a preflight OPTIONS request + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + } + } + // Dispatch to the next handler + handler.ServeHTTP(w, req) + }) +} diff --git a/pkg/apiserver/proxy.go b/pkg/apiserver/proxy.go index 98330530ec..caeeb0eaec 100644 --- a/pkg/apiserver/proxy.go +++ b/pkg/apiserver/proxy.go @@ -89,14 +89,14 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage, ok := r.storage[resourceName] if !ok { - httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) + httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName) notFound(w, req) return } redirector, ok := storage.(Redirector) if !ok { - httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) + httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName) notFound(w, req) return } diff --git a/pkg/apiserver/redirect.go b/pkg/apiserver/redirect.go index 57aff02d3d..d517586540 100644 --- a/pkg/apiserver/redirect.go +++ b/pkg/apiserver/redirect.go @@ -38,14 +38,14 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { id := parts[1] storage, ok := r.storage[resourceName] if !ok { - httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) + httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName) notFound(w, req) return } redirector, ok := storage.(Redirector) if !ok { - httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) + httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName) notFound(w, req) return } diff --git a/pkg/apiserver/resthandler.go b/pkg/apiserver/resthandler.go index 09251b5bc4..4b6977b9e8 100644 --- a/pkg/apiserver/resthandler.go +++ b/pkg/apiserver/resthandler.go @@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage := h.storage[parts[0]] if storage == nil { - httplog.LogOf(w).Addf("'%v' has no storage object", parts[0]) + httplog.LogOf(req, w).Addf("'%v' has no storage object", parts[0]) notFound(w, req) return } @@ -114,7 +114,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "DELETE": if len(parts) != 2 { @@ -127,7 +127,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "PUT": if len(parts) != 2 { @@ -151,7 +151,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) default: notFound(w, req) @@ -171,7 +171,7 @@ func (h *RESTHandler) createOperation(out <-chan runtime.Object, sync bool, time // finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an // Operation to receive the result and returning its ID down the writer. -func (h *RESTHandler) finishReq(op *Operation, w http.ResponseWriter) { +func (h *RESTHandler) finishReq(op *Operation, req *http.Request, w http.ResponseWriter) { obj, complete := op.StatusOrResult() if complete { status := http.StatusOK diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 7cd047093b..38b3ace500 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -127,24 +127,24 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) { // ServeHTTP serves a series of JSON encoded events via straight HTTP with // Transfer-Encoding: chunked. func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - loggedW := httplog.LogOf(w) + loggedW := httplog.LogOf(req, w) w = httplog.Unlogged(w) cn, ok := w.(http.CloseNotifier) if !ok { loggedW.Addf("unable to get CloseNotifier") - http.NotFound(loggedW, req) + http.NotFound(w, req) return } flusher, ok := w.(http.Flusher) if !ok { loggedW.Addf("unable to get Flusher") - http.NotFound(loggedW, req) + http.NotFound(w, req) return } - loggedW.Header().Set("Transfer-Encoding", "chunked") - loggedW.WriteHeader(http.StatusOK) + w.Header().Set("Transfer-Encoding", "chunked") + w.WriteHeader(http.StatusOK) flusher.Flush() encoder := json.NewEncoder(w) diff --git a/pkg/httplog/log.go b/pkg/httplog/log.go index 1f95b504b3..c964e3864c 100644 --- a/pkg/httplog/log.go +++ b/pkg/httplog/log.go @@ -40,6 +40,10 @@ func Handler(delegate http.Handler, pred StacktracePred) http.Handler { // StacktracePred returns true if a stacktrace should be logged for this status. type StacktracePred func(httpStatus int) (logStacktrace bool) +type logger interface { + Addf(format string, data ...interface{}) +} + // Add a layer on top of ResponseWriter, so we can track latency and error // message sources. type respLogger struct { @@ -54,6 +58,14 @@ type respLogger struct { logStacktracePred StacktracePred } +// Simple logger that logs immediately when Addf is called +type passthroughLogger struct{} + +// Addf logs info immediately. +func (passthroughLogger) Addf(format string, data ...interface{}) { + glog.Infof(format, data...) +} + // DefaultStacktracePred is the default implementation of StacktracePred. func DefaultStacktracePred(status int) bool { return status < http.StatusOK || status >= http.StatusBadRequest @@ -86,13 +98,18 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger { return rl } -// LogOf returns the logger hiding in w. Panics if there isn't such a logger, -// because NewLogged() must have been previously called for the log to work. -func LogOf(w http.ResponseWriter) *respLogger { +// LogOf returns the logger hiding in w. If there is not an existing logger +// then a passthroughLogger will be created which will log to stdout immediately +// when Addf is called. +func LogOf(req *http.Request, w http.ResponseWriter) logger { + if _, exists := w.(*respLogger); !exists { + pl := &passthroughLogger{} + return pl + } if rl, ok := w.(*respLogger); ok { return rl } - panic("Logger not installed yet!") + panic("Unable to find or create the logger!") } // Unlogged returns the original ResponseWriter, or w if it is not our inserted logger. diff --git a/pkg/httplog/log_test.go b/pkg/httplog/log_test.go index e287c8b11d..aea79ab545 100644 --- a/pkg/httplog/log_test.go +++ b/pkg/httplog/log_test.go @@ -91,17 +91,14 @@ func TestLogOf(t *testing.T) { t.Errorf("Unexpected error: %v", err) } handler := func(w http.ResponseWriter, r *http.Request) { - var want *respLogger + var want string if makeLogger { - want = NewLogged(req, &w) + NewLogged(req, &w) + want = "*httplog.respLogger" } else { - defer func() { - if r := recover(); r == nil { - t.Errorf("Expected LogOf to panic") - } - }() + want = "*httplog.passthroughLogger" } - got := LogOf(w) + got := reflect.TypeOf(LogOf(r, w)).String() if want != got { t.Errorf("Expected %v, got %v", want, got) } diff --git a/pkg/util/util.go b/pkg/util/util.go index 5e9cf463b5..69aaccfc3a 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -19,6 +19,7 @@ package util import ( "encoding/json" "fmt" + "regexp" "runtime" "time" @@ -162,3 +163,16 @@ func StringDiff(a, b string) string { out = append(out, []byte("\n\n")...) return string(out) } + +// Takes a list of strings and compiles them into a list of regular expressions +func CompileRegexps(regexpStrings StringList) ([]*regexp.Regexp, error) { + regexps := []*regexp.Regexp{} + for _, regexpStr := range regexpStrings { + r, err := regexp.Compile(regexpStr) + if err != nil { + return []*regexp.Regexp{}, err + } + regexps = append(regexps, r) + } + return regexps, nil +}