Add option to enable a simple CORS implementation for the api server

pull/6/head
Jessica Forrester 2014-09-02 12:28:24 -04:00
parent a4e3a4e351
commit 8723eece49
11 changed files with 161 additions and 59 deletions

View File

@ -39,6 +39,7 @@ var (
port = flag.Uint("port", 8080, "The port to listen on. Default 8080.")
address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1")
apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'")
enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]")
cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.")
cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.")
minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs")
@ -134,7 +135,7 @@ func main() {
storage, codec := m.API_v1beta1()
s := &http.Server{
Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))),
Handler: apiserver.Handle(storage, codec, *apiPrefix),
Handler: apiserver.Handle(storage, codec, *apiPrefix, *enableCORS),
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
MaxHeaderBytes: 1 << 20,

View File

@ -116,7 +116,7 @@ func startComponents(manifestURL string) (apiServerURL string) {
PodInfoGetter: fakePodInfoGetter{},
})
storage, codec := m.API_v1beta1()
handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1")
handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1", false)
// Scheduler
scheduler.New((&factory.ConfigFactory{cl}).Create()).Run()

View File

@ -39,6 +39,7 @@ set +e
API_PORT=${API_PORT:-8080}
API_HOST=${API_HOST:-127.0.0.1}
API_ENABLE_CORS=${API_ENABLE_CORS:-false}
KUBELET_PORT=${KUBELET_PORT:-10250}
GO_OUT=$(dirname $0)/../_output/go/bin
@ -48,7 +49,8 @@ APISERVER_LOG=/tmp/apiserver.log
--address="${API_HOST}" \
--port="${API_PORT}" \
--etcd_servers="http://127.0.0.1:4001" \
--machines="127.0.0.1" >"${APISERVER_LOG}" 2>&1 &
--machines="127.0.0.1" \
--enable_cors="${API_ENABLE_CORS}" >"${APISERVER_LOG}" 2>&1 &
APISERVER_PID=$!
CTLRMGR_LOG=/tmp/controller-manager.log

View File

@ -18,10 +18,8 @@ package apiserver
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"runtime/debug"
"strings"
"time"
@ -57,8 +55,11 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string)
mux := http.NewServeMux()
group.InstallREST(mux, prefix)
InstallSupport(mux)
return &defaultAPIServer{RecoverPanics(mux), group}
handler := RecoverPanics(mux)
if enableCORS {
handler = CORS(handler, []string{".*"}, nil, nil, "true")
}
return &defaultAPIServer{handler, group}
}
// APIGroup is a http.Handler that exposes multiple RESTStorage objects
@ -116,33 +117,6 @@ func InstallSupport(mux mux) {
mux.HandleFunc("/", handleIndex)
}
// RecoverPanics wraps an http Handler to recover and log panics.
func RecoverPanics(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if x := recover(); x != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "apis panic. Look in log for details.")
glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack())
}
}()
defer httplog.NewLogged(req, &w).StacktraceWhen(
httplog.StatusIsNot(
http.StatusOK,
http.StatusAccepted,
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusConflict,
http.StatusNotFound,
StatusUnprocessableEntity,
),
).Log()
// Dispatch to the internal handler
handler.ServeHTTP(w, req)
})
}
// handleVersion writes the server's version information.
func handleVersion(w http.ResponseWriter, req *http.Request) {
writeRawJSON(http.StatusOK, version.Get(), w)

View File

@ -191,7 +191,7 @@ func TestNotFound(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &SimpleRESTStorage{},
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
for k, v := range cases {
@ -212,7 +212,7 @@ func TestNotFound(t *testing.T) {
}
func TestVersion(t *testing.T) {
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version")
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -241,7 +241,7 @@ func TestSimpleList(t *testing.T) {
storage := map[string]RESTStorage{}
simpleStorage := SimpleRESTStorage{}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -260,7 +260,7 @@ func TestErrorList(t *testing.T) {
errors: map[string]error{"list": fmt.Errorf("test Error")},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -284,7 +284,7 @@ func TestNonEmptyList(t *testing.T) {
},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -319,7 +319,7 @@ func TestGet(t *testing.T) {
},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple/id")
@ -340,7 +340,7 @@ func TestGetMissing(t *testing.T) {
errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple/id")
@ -358,7 +358,7 @@ func TestDelete(t *testing.T) {
simpleStorage := SimpleRESTStorage{}
ID := "id"
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -380,7 +380,7 @@ func TestDeleteMissing(t *testing.T) {
errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -400,7 +400,7 @@ func TestUpdate(t *testing.T) {
simpleStorage := SimpleRESTStorage{}
ID := "id"
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
item := &Simple{
@ -430,7 +430,7 @@ func TestUpdateMissing(t *testing.T) {
errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version")
handler := Handle(storage, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
item := &Simple{
@ -457,7 +457,7 @@ func TestCreate(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}
@ -498,7 +498,7 @@ func TestCreateNotFound(t *testing.T) {
// See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092.
errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")},
},
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -540,7 +540,7 @@ func TestSyncCreate(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &storage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -609,7 +609,7 @@ func TestAsyncDelayReturnsError(t *testing.T) {
return nil, apierrs.NewAlreadyExists("foo", "bar")
},
}
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version")
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false)
handler.(*defaultAPIServer).group.handler.asyncOpWait = time.Millisecond / 2
server := httptest.NewServer(handler)
@ -627,7 +627,7 @@ func TestAsyncCreateError(t *testing.T) {
return nil, apierrs.NewAlreadyExists("foo", "bar")
},
}
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version")
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false)
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
@ -721,7 +721,7 @@ func TestSyncCreateTimeout(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &storage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
simple := &Simple{Name: "foo"}
@ -731,3 +731,36 @@ func TestSyncCreateTimeout(t *testing.T) {
t.Errorf("Unexpected status %#v", itemOut)
}
}
func TestEnableCORS(t *testing.T) {
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", true)
server := httptest.NewServer(handler)
client := http.Client{}
request, err := http.NewRequest("GET", server.URL+"/version", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
request.Header.Set("Origin", "example.com")
response, err := client.Do(request)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual(response.Header.Get("Access-Control-Allow-Origin"), "example.com") {
t.Errorf("Expected %#v, Got %#v", response.Header.Get("Access-Control-Allow-Origin"), "example.com")
}
if response.Header.Get("Access-Control-Allow-Credentials") == "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to be set")
}
if response.Header.Get("Access-Control-Allow-Headers") == "" {
t.Errorf("Expected Access-Control-Allow-Headers header to be set")
}
if response.Header.Get("Access-Control-Allow-Methods") == "" {
t.Errorf("Expected Access-Control-Allow-Methods header to be set")
}
}

92
pkg/apiserver/handlers.go Normal file
View File

@ -0,0 +1,92 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package apiserver
import (
"fmt"
"net/http"
"regexp"
"runtime/debug"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/golang/glog"
)
// RecoverPanics wraps an http Handler to recover and log panics.
func RecoverPanics(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if x := recover(); x != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "apis panic. Look in log for details.")
glog.Infof("APIServer panic'd on %v %v: %#v\n%s\n", req.Method, req.RequestURI, x, debug.Stack())
}
}()
defer httplog.NewLogged(req, &w).StacktraceWhen(
httplog.StatusIsNot(
http.StatusOK,
http.StatusAccepted,
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusConflict,
http.StatusNotFound,
StatusUnprocessableEntity,
),
).Log()
// Dispatch to the internal handler
handler.ServeHTTP(w, req)
})
}
// Simple CORS implementation that wraps an http Handler
// For a more detailed implementation use https://github.com/martini-contrib/cors
// or implement CORS at your proxy layer
// Pass nil for allowedMethods and allowedHeaders to use the defaults
func CORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin != "" {
allowed := false
for _, pattern := range allowedOriginPatterns {
allowed, _ = regexp.MatchString(pattern, origin)
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
// Set defaults for methods and headers if nothing was passed
if allowedMethods == nil {
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"}
}
if allowedHeaders == nil {
allowedHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"}
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
// Stop here if its a preflight OPTIONS request
if req.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
}
}
// Dispatch to the next handler
handler.ServeHTTP(w, req)
})
}

View File

@ -127,7 +127,7 @@ func TestApiServerMinionProxy(t *testing.T) {
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.URL.Path))
}))
server := httptest.NewServer(Handle(nil, nil, "/prefix"))
server := httptest.NewServer(Handle(nil, nil, "/prefix", false))
proxy, _ := url.Parse(proxyServer.URL)
resp, err := http.Get(fmt.Sprintf("%s/proxy/minion/%s%s", server.URL, proxy.Host, "/test"))
if err != nil {

View File

@ -107,7 +107,7 @@ func TestOperationsList(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}
@ -163,7 +163,7 @@ func TestOpGet(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}

View File

@ -30,7 +30,7 @@ func TestRedirect(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
dontFollow := errors.New("don't follow")

View File

@ -44,7 +44,7 @@ func TestWatchWebsocket(t *testing.T) {
_ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work.
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
dest, _ := url.Parse(server.URL)
@ -90,7 +90,7 @@ func TestWatchHTTP(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
client := http.Client{}
@ -147,7 +147,7 @@ func TestWatchParamParsing(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version")
}, codec, "/prefix/version", false)
server := httptest.NewServer(handler)
dest, _ := url.Parse(server.URL)

View File

@ -41,7 +41,7 @@ func TestClient(t *testing.T) {
})
storage, codec := m.API_v1beta1()
s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/"))
s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/", false))
client := client.NewOrDie(s.URL, nil)