Merge pull request #1133 from jwforres/enable_cors

Add option to enable a simple CORS implementation for the api server
pull/6/head
Daniel Smith 2014-09-11 13:59:14 -07:00
commit 6757b402d5
12 changed files with 250 additions and 70 deletions

View File

@ -24,6 +24,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver"
@ -36,21 +37,25 @@ import (
)
var (
port = flag.Uint("port", 8080, "The port to listen on. Default 8080.")
address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1")
apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'")
cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.")
cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.")
minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs")
minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.")
healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]")
minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]")
etcdServerList, machineList util.StringList
port = flag.Uint("port", 8080, "The port to listen on. Default 8080.")
address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1")
apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'")
enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]")
cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.")
cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.")
minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs")
minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.")
healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]")
minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]")
etcdServerList util.StringList
machineList util.StringList
corsAllowedOriginList util.StringList
)
func init() {
flag.Var(&etcdServerList, "etcd_servers", "List of etcd servers to watch (http://ip:port), comma separated")
flag.Var(&machineList, "machines", "List of machines to schedule onto, comma separated.")
flag.Var(&corsAllowedOriginList, "cors_allowed_origins", "List of allowed origins for CORS, comma separated. An allowed origin can be a regular expression to support subdomain matching. If this list is empty CORS will not be enabled.")
}
func verifyMinionFlags() {
@ -132,9 +137,19 @@ func main() {
})
storage, codec := m.API_v1beta1()
handler := apiserver.Handle(storage, codec, *apiPrefix)
if len(corsAllowedOriginList) > 0 {
allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList)
if err != nil {
glog.Fatalf("Invalid CORS allowed origin, --cors_allowed_origins flag was set to %v - %v", strings.Join(corsAllowedOriginList, ","), err)
}
handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true")
}
s := &http.Server{
Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))),
Handler: apiserver.Handle(storage, codec, *apiPrefix),
Handler: apiserver.RecoverPanics(handler),
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
MaxHeaderBytes: 1 << 20,

View File

@ -39,6 +39,8 @@ set +e
API_PORT=${API_PORT:-8080}
API_HOST=${API_HOST:-127.0.0.1}
# By default only allow CORS for requests on localhost
API_CORS_ALLOWED_ORIGINS=${API_CORS_ALLOWED_ORIGINS:-127.0.0.1:.*,localhost:.*}
KUBELET_PORT=${KUBELET_PORT:-10250}
GO_OUT=$(dirname $0)/../_output/go/bin
@ -48,7 +50,8 @@ APISERVER_LOG=/tmp/apiserver.log
--address="${API_HOST}" \
--port="${API_PORT}" \
--etcd_servers="http://127.0.0.1:4001" \
--machines="127.0.0.1" >"${APISERVER_LOG}" 2>&1 &
--machines="127.0.0.1" \
--cors_allowed_origins="${API_CORS_ALLOWED_ORIGINS}" >"${APISERVER_LOG}" 2>&1 &
APISERVER_PID=$!
CTLRMGR_LOG=/tmp/controller-manager.log

View File

@ -18,15 +18,12 @@ package apiserver
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/healthz"
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/version"
"github.com/golang/glog"
@ -57,8 +54,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string)
mux := http.NewServeMux()
group.InstallREST(mux, prefix)
InstallSupport(mux)
return &defaultAPIServer{RecoverPanics(mux), group}
return &defaultAPIServer{mux, group}
}
// APIGroup is a http.Handler that exposes multiple RESTStorage objects
@ -116,33 +112,6 @@ func InstallSupport(mux mux) {
mux.HandleFunc("/", handleIndex)
}
// RecoverPanics wraps an http Handler to recover and log panics.
func RecoverPanics(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if x := recover(); x != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "apis panic. Look in log for details.")
glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack())
}
}()
defer httplog.NewLogged(req, &w).StacktraceWhen(
httplog.StatusIsNot(
http.StatusOK,
http.StatusAccepted,
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusConflict,
http.StatusNotFound,
StatusUnprocessableEntity,
),
).Log()
// Dispatch to the internal handler
handler.ServeHTTP(w, req)
})
}
// handleVersion writes the server's version information.
func handleVersion(w http.ResponseWriter, req *http.Request) {
writeRawJSON(http.StatusOK, version.Get(), w)

View File

@ -34,6 +34,7 @@ import (
apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors"
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/version"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
)
@ -731,3 +732,73 @@ func TestSyncCreateTimeout(t *testing.T) {
t.Errorf("Unexpected status %#v", itemOut)
}
}
func TestCORSAllowedOrigins(t *testing.T) {
table := []struct {
allowedOrigins util.StringList
origin string
allowed bool
}{
{[]string{}, "example.com", false},
{[]string{"example.com"}, "example.com", true},
{[]string{"example.com"}, "not-allowed.com", false},
{[]string{"not-matching.com", "example.com"}, "example.com", true},
{[]string{".*"}, "example.com", true},
}
for _, item := range table {
allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), allowedOriginRegexps, nil, nil, "true")
server := httptest.NewServer(handler)
client := http.Client{}
request, err := http.NewRequest("GET", server.URL+"/version", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
request.Header.Set("Origin", item.origin)
response, err := client.Do(request)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if item.allowed {
if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) {
t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin"))
}
if response.Header.Get("Access-Control-Allow-Credentials") == "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to be set")
}
if response.Header.Get("Access-Control-Allow-Headers") == "" {
t.Errorf("Expected Access-Control-Allow-Headers header to be set")
}
if response.Header.Get("Access-Control-Allow-Methods") == "" {
t.Errorf("Expected Access-Control-Allow-Methods header to be set")
}
} else {
if response.Header.Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin header to not be set")
}
if response.Header.Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to not be set")
}
if response.Header.Get("Access-Control-Allow-Headers") != "" {
t.Errorf("Expected Access-Control-Allow-Headers header to not be set")
}
if response.Header.Get("Access-Control-Allow-Methods") != "" {
t.Errorf("Expected Access-Control-Allow-Methods header to not be set")
}
}
}
}

94
pkg/apiserver/handlers.go Normal file
View File

@ -0,0 +1,94 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package apiserver
import (
"fmt"
"net/http"
"regexp"
"runtime/debug"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/golang/glog"
)
// RecoverPanics wraps an http Handler to recover and log panics.
func RecoverPanics(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if x := recover(); x != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "apis panic. Look in log for details.")
glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack())
}
}()
defer httplog.NewLogged(req, &w).StacktraceWhen(
httplog.StatusIsNot(
http.StatusOK,
http.StatusAccepted,
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusConflict,
http.StatusNotFound,
StatusUnprocessableEntity,
),
).Log()
// Dispatch to the internal handler
handler.ServeHTTP(w, req)
})
}
// Simple CORS implementation that wraps an http Handler
// For a more detailed implementation use https://github.com/martini-contrib/cors
// or implement CORS at your proxy layer
// Pass nil for allowedMethods and allowedHeaders to use the defaults
func CORS(handler http.Handler, allowedOriginPatterns []*regexp.Regexp, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin != "" {
allowed := false
for _, pattern := range allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
// Set defaults for methods and headers if nothing was passed
if allowedMethods == nil {
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"}
}
if allowedHeaders == nil {
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
// Stop here if its a preflight OPTIONS request
if req.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
}
}
// Dispatch to the next handler
handler.ServeHTTP(w, req)
})
}

View File

@ -89,14 +89,14 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
storage, ok := r.storage[resourceName]
if !ok {
httplog.LogOf(w).Addf("'%v' has no storage object", resourceName)
httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName)
notFound(w, req)
return
}
redirector, ok := storage.(Redirector)
if !ok {
httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName)
httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName)
notFound(w, req)
return
}

View File

@ -38,14 +38,14 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
id := parts[1]
storage, ok := r.storage[resourceName]
if !ok {
httplog.LogOf(w).Addf("'%v' has no storage object", resourceName)
httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName)
notFound(w, req)
return
}
redirector, ok := storage.(Redirector)
if !ok {
httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName)
httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName)
notFound(w, req)
return
}

View File

@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
storage := h.storage[parts[0]]
if storage == nil {
httplog.LogOf(w).Addf("'%v' has no storage object", parts[0])
httplog.LogOf(req, w).Addf("'%v' has no storage object", parts[0])
notFound(w, req)
return
}
@ -114,7 +114,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return
}
op := h.createOperation(out, sync, timeout)
h.finishReq(op, w)
h.finishReq(op, req, w)
case "DELETE":
if len(parts) != 2 {
@ -127,7 +127,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return
}
op := h.createOperation(out, sync, timeout)
h.finishReq(op, w)
h.finishReq(op, req, w)
case "PUT":
if len(parts) != 2 {
@ -151,7 +151,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return
}
op := h.createOperation(out, sync, timeout)
h.finishReq(op, w)
h.finishReq(op, req, w)
default:
notFound(w, req)
@ -171,7 +171,7 @@ func (h *RESTHandler) createOperation(out <-chan runtime.Object, sync bool, time
// finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an
// Operation to receive the result and returning its ID down the writer.
func (h *RESTHandler) finishReq(op *Operation, w http.ResponseWriter) {
func (h *RESTHandler) finishReq(op *Operation, req *http.Request, w http.ResponseWriter) {
obj, complete := op.StatusOrResult()
if complete {
status := http.StatusOK

View File

@ -127,24 +127,24 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) {
// ServeHTTP serves a series of JSON encoded events via straight HTTP with
// Transfer-Encoding: chunked.
func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
loggedW := httplog.LogOf(w)
loggedW := httplog.LogOf(req, w)
w = httplog.Unlogged(w)
cn, ok := w.(http.CloseNotifier)
if !ok {
loggedW.Addf("unable to get CloseNotifier")
http.NotFound(loggedW, req)
http.NotFound(w, req)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
loggedW.Addf("unable to get Flusher")
http.NotFound(loggedW, req)
http.NotFound(w, req)
return
}
loggedW.Header().Set("Transfer-Encoding", "chunked")
loggedW.WriteHeader(http.StatusOK)
w.Header().Set("Transfer-Encoding", "chunked")
w.WriteHeader(http.StatusOK)
flusher.Flush()
encoder := json.NewEncoder(w)

View File

@ -40,6 +40,10 @@ func Handler(delegate http.Handler, pred StacktracePred) http.Handler {
// StacktracePred returns true if a stacktrace should be logged for this status.
type StacktracePred func(httpStatus int) (logStacktrace bool)
type logger interface {
Addf(format string, data ...interface{})
}
// Add a layer on top of ResponseWriter, so we can track latency and error
// message sources.
type respLogger struct {
@ -54,6 +58,14 @@ type respLogger struct {
logStacktracePred StacktracePred
}
// Simple logger that logs immediately when Addf is called
type passthroughLogger struct{}
// Addf logs info immediately.
func (passthroughLogger) Addf(format string, data ...interface{}) {
glog.Infof(format, data...)
}
// DefaultStacktracePred is the default implementation of StacktracePred.
func DefaultStacktracePred(status int) bool {
return status < http.StatusOK || status >= http.StatusBadRequest
@ -86,13 +98,18 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger {
return rl
}
// LogOf returns the logger hiding in w. Panics if there isn't such a logger,
// because NewLogged() must have been previously called for the log to work.
func LogOf(w http.ResponseWriter) *respLogger {
// LogOf returns the logger hiding in w. If there is not an existing logger
// then a passthroughLogger will be created which will log to stdout immediately
// when Addf is called.
func LogOf(req *http.Request, w http.ResponseWriter) logger {
if _, exists := w.(*respLogger); !exists {
pl := &passthroughLogger{}
return pl
}
if rl, ok := w.(*respLogger); ok {
return rl
}
panic("Logger not installed yet!")
panic("Unable to find or create the logger!")
}
// Unlogged returns the original ResponseWriter, or w if it is not our inserted logger.

View File

@ -91,17 +91,14 @@ func TestLogOf(t *testing.T) {
t.Errorf("Unexpected error: %v", err)
}
handler := func(w http.ResponseWriter, r *http.Request) {
var want *respLogger
var want string
if makeLogger {
want = NewLogged(req, &w)
NewLogged(req, &w)
want = "*httplog.respLogger"
} else {
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected LogOf to panic")
}
}()
want = "*httplog.passthroughLogger"
}
got := LogOf(w)
got := reflect.TypeOf(LogOf(r, w)).String()
if want != got {
t.Errorf("Expected %v, got %v", want, got)
}

View File

@ -19,6 +19,7 @@ package util
import (
"encoding/json"
"fmt"
"regexp"
"runtime"
"time"
@ -162,3 +163,16 @@ func StringDiff(a, b string) string {
out = append(out, []byte("\n\n")...)
return string(out)
}
// Takes a list of strings and compiles them into a list of regular expressions
func CompileRegexps(regexpStrings StringList) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return []*regexp.Regexp{}, err
}
regexps = append(regexps, r)
}
return regexps, nil
}