mirror of https://github.com/k3s-io/k3s
Move RecoverPanics to be the top level wrapped handler. Add new method to be sure a logger has been generated instead of assuming one has. Move regexp list compilation into a utility and pass regexp list into CORS.
parent
8b4ca9c2a7
commit
becf6ca4e7
|
@ -139,12 +139,16 @@ func main() {
|
||||||
|
|
||||||
handler := apiserver.Handle(storage, codec, *apiPrefix)
|
handler := apiserver.Handle(storage, codec, *apiPrefix)
|
||||||
if len(corsAllowedOriginList) > 0 {
|
if len(corsAllowedOriginList) > 0 {
|
||||||
handler = apiserver.CORS(handler, corsAllowedOriginList, nil, nil, "true")
|
allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList)
|
||||||
|
if err != nil {
|
||||||
|
glog.Fatalf("Invalid CORS allowed origin: %v", err)
|
||||||
|
}
|
||||||
|
handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))),
|
Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))),
|
||||||
Handler: handler,
|
Handler: apiserver.RecoverPanics(handler),
|
||||||
ReadTimeout: 5 * time.Minute,
|
ReadTimeout: 5 * time.Minute,
|
||||||
WriteTimeout: 5 * time.Minute,
|
WriteTimeout: 5 * time.Minute,
|
||||||
MaxHeaderBytes: 1 << 20,
|
MaxHeaderBytes: 1 << 20,
|
||||||
|
|
|
@ -55,7 +55,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
group.InstallREST(mux, prefix)
|
group.InstallREST(mux, prefix)
|
||||||
InstallSupport(mux)
|
InstallSupport(mux)
|
||||||
return &defaultAPIServer{RecoverPanics(mux), group}
|
return &defaultAPIServer{mux, group}
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIGroup is a http.Handler that exposes multiple RESTStorage objects
|
// APIGroup is a http.Handler that exposes multiple RESTStorage objects
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -733,7 +734,7 @@ func TestSyncCreateTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSAllowedOrigin(t *testing.T) {
|
func TestCORSAllowedOrigin(t *testing.T) {
|
||||||
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true")
|
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true")
|
||||||
server := httptest.NewServer(handler)
|
server := httptest.NewServer(handler)
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
|
||||||
|
@ -766,7 +767,7 @@ func TestCORSAllowedOrigin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSUnallowedOrigin(t *testing.T) {
|
func TestCORSUnallowedOrigin(t *testing.T) {
|
||||||
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true")
|
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true")
|
||||||
server := httptest.NewServer(handler)
|
server := httptest.NewServer(handler)
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
|
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
|
||||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,22 +58,12 @@ func RecoverPanics(handler http.Handler) http.Handler {
|
||||||
// For a more detailed implementation use https://github.com/martini-contrib/cors
|
// For a more detailed implementation use https://github.com/martini-contrib/cors
|
||||||
// or implement CORS at your proxy layer
|
// or implement CORS at your proxy layer
|
||||||
// Pass nil for allowedMethods and allowedHeaders to use the defaults
|
// Pass nil for allowedMethods and allowedHeaders to use the defaults
|
||||||
func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
|
func CORS(handler http.Handler, allowedOriginPatterns []*regexp.Regexp, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
|
||||||
// Compile the regular expressions once upfront
|
|
||||||
allowedOriginRegexps := []*regexp.Regexp{}
|
|
||||||
for _, allowedOrigin := range allowedOriginPatterns {
|
|
||||||
allowedOriginRegexp, err := regexp.Compile(allowedOrigin)
|
|
||||||
if err != nil {
|
|
||||||
glog.Fatalf("Invalid CORS allowed origin regexp: %v", err)
|
|
||||||
}
|
|
||||||
allowedOriginRegexps = append(allowedOriginRegexps, allowedOriginRegexp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
origin := req.Header.Get("Origin")
|
origin := req.Header.Get("Origin")
|
||||||
if origin != "" {
|
if origin != "" {
|
||||||
allowed := false
|
allowed := false
|
||||||
for _, pattern := range allowedOriginRegexps {
|
for _, pattern := range allowedOriginPatterns {
|
||||||
if allowed = pattern.MatchString(origin); allowed {
|
if allowed = pattern.MatchString(origin); allowed {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -86,7 +75,7 @@ func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMe
|
||||||
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"}
|
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"}
|
||||||
}
|
}
|
||||||
if allowedHeaders == nil {
|
if allowedHeaders == nil {
|
||||||
allowedHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"}
|
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
|
||||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
|
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
|
||||||
|
|
|
@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
storage := h.storage[parts[0]]
|
storage := h.storage[parts[0]]
|
||||||
if storage == nil {
|
if storage == nil {
|
||||||
httplog.LogOf(w).Addf("'%v' has no storage object", parts[0])
|
httplog.FindOrCreateLogOf(req, &w).Addf("'%v' has no storage object", parts[0])
|
||||||
notFound(w, req)
|
notFound(w, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,7 +127,7 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) {
|
||||||
// ServeHTTP serves a series of JSON encoded events via straight HTTP with
|
// ServeHTTP serves a series of JSON encoded events via straight HTTP with
|
||||||
// Transfer-Encoding: chunked.
|
// Transfer-Encoding: chunked.
|
||||||
func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
loggedW := httplog.LogOf(w)
|
loggedW := httplog.FindOrCreateLogOf(req, &w)
|
||||||
w = httplog.Unlogged(w)
|
w = httplog.Unlogged(w)
|
||||||
|
|
||||||
cn, ok := w.(http.CloseNotifier)
|
cn, ok := w.(http.CloseNotifier)
|
||||||
|
|
|
@ -95,6 +95,15 @@ func LogOf(w http.ResponseWriter) *respLogger {
|
||||||
panic("Logger not installed yet!")
|
panic("Logger not installed yet!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the existing logger hiding in w. If there is not an existing logger
|
||||||
|
// then one will be created.
|
||||||
|
func FindOrCreateLogOf(req *http.Request, w *http.ResponseWriter) *respLogger {
|
||||||
|
if _, exists := (*w).(*respLogger); !exists {
|
||||||
|
NewLogged(req, w)
|
||||||
|
}
|
||||||
|
return LogOf(*w)
|
||||||
|
}
|
||||||
|
|
||||||
// Unlogged returns the original ResponseWriter, or w if it is not our inserted logger.
|
// Unlogged returns the original ResponseWriter, or w if it is not our inserted logger.
|
||||||
func Unlogged(w http.ResponseWriter) http.ResponseWriter {
|
func Unlogged(w http.ResponseWriter) http.ResponseWriter {
|
||||||
if rl, ok := w.(*respLogger); ok {
|
if rl, ok := w.(*respLogger); ok {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package util
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -162,3 +163,16 @@ func StringDiff(a, b string) string {
|
||||||
out = append(out, []byte("\n\n")...)
|
out = append(out, []byte("\n\n")...)
|
||||||
return string(out)
|
return string(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Takes a list of strings and compiles them into a list of regular expressions
|
||||||
|
func CompileRegexps(regexpStrings StringList) ([]*regexp.Regexp, error) {
|
||||||
|
regexps := []*regexp.Regexp{}
|
||||||
|
for _, regexpStr := range regexpStrings {
|
||||||
|
r, err := regexp.Compile(regexpStr)
|
||||||
|
if err != nil {
|
||||||
|
return []*regexp.Regexp{}, err
|
||||||
|
}
|
||||||
|
regexps = append(regexps, r)
|
||||||
|
}
|
||||||
|
return regexps, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue