mirror of https://github.com/k3s-io/k3s
commit
7c801f217c
|
@ -72,13 +72,35 @@ func ReadOnly(handler http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LongRunningRequestCheck func(r *http.Request) bool
|
||||||
|
|
||||||
|
// BasicLongRunningRequestCheck pathRegex operates against the url path, the queryParams match is case insensitive.
|
||||||
|
// Any one match flags the request.
|
||||||
|
// TODO tighten this check to eliminate the abuse potential by malicious clients that start setting queryParameters
|
||||||
|
// to bypass the rate limitter. This could be done using a full parse and special casing the bits we need.
|
||||||
|
func BasicLongRunningRequestCheck(pathRegex *regexp.Regexp, queryParams map[string]string) LongRunningRequestCheck {
|
||||||
|
return func(r *http.Request) bool {
|
||||||
|
if pathRegex.MatchString(r.URL.Path) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, expectedValue := range queryParams {
|
||||||
|
if strings.ToLower(expectedValue) == strings.ToLower(r.URL.Query().Get(key)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MaxInFlight limits the number of in-flight requests to buffer size of the passed in channel.
|
// MaxInFlight limits the number of in-flight requests to buffer size of the passed in channel.
|
||||||
func MaxInFlightLimit(c chan bool, longRunningRequestRE *regexp.Regexp, handler http.Handler) http.Handler {
|
func MaxInFlightLimit(c chan bool, longRunningRequestCheck LongRunningRequestCheck, handler http.Handler) http.Handler {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if longRunningRequestRE.MatchString(r.URL.Path) {
|
if longRunningRequestCheck(r) {
|
||||||
// Skip tracking long running events.
|
// Skip tracking long running events.
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
|
|
|
@ -77,6 +77,7 @@ func TestMaxInFlight(t *testing.T) {
|
||||||
// notAccountedPathsRegexp specifies paths requests to which we don't account into
|
// notAccountedPathsRegexp specifies paths requests to which we don't account into
|
||||||
// requests in flight.
|
// requests in flight.
|
||||||
notAccountedPathsRegexp := regexp.MustCompile(".*\\/watch")
|
notAccountedPathsRegexp := regexp.MustCompile(".*\\/watch")
|
||||||
|
longRunningRequestCheck := BasicLongRunningRequestCheck(notAccountedPathsRegexp, map[string]string{"watch": "true"})
|
||||||
|
|
||||||
// Calls is used to wait until all server calls are received. We are sending
|
// Calls is used to wait until all server calls are received. We are sending
|
||||||
// AllowedInflightRequestsNo of 'long' not-accounted requests and the same number of
|
// AllowedInflightRequestsNo of 'long' not-accounted requests and the same number of
|
||||||
|
@ -98,7 +99,7 @@ func TestMaxInFlight(t *testing.T) {
|
||||||
server := httptest.NewServer(
|
server := httptest.NewServer(
|
||||||
MaxInFlightLimit(
|
MaxInFlightLimit(
|
||||||
inflightRequestsChannel,
|
inflightRequestsChannel,
|
||||||
notAccountedPathsRegexp,
|
longRunningRequestCheck,
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// A short, accounted request that does not wait for block WaitGroup.
|
// A short, accounted request that does not wait for block WaitGroup.
|
||||||
if strings.Contains(r.URL.Path, "dontwait") {
|
if strings.Contains(r.URL.Path, "dontwait") {
|
||||||
|
@ -114,11 +115,11 @@ func TestMaxInFlight(t *testing.T) {
|
||||||
// TODO: Uncomment when fix #19254
|
// TODO: Uncomment when fix #19254
|
||||||
// defer server.Close()
|
// defer server.Close()
|
||||||
|
|
||||||
// These should hang, but not affect accounting.
|
// These should hang, but not affect accounting. use a query param match
|
||||||
for i := 0; i < AllowedInflightRequestsNo; i++ {
|
for i := 0; i < AllowedInflightRequestsNo; i++ {
|
||||||
// These should hang waiting on block...
|
// These should hang waiting on block...
|
||||||
go func() {
|
go func() {
|
||||||
if err := expectHTTP(server.URL+"/foo/bar/watch", http.StatusOK); err != nil {
|
if err := expectHTTP(server.URL+"/foo/bar?watch=true", http.StatusOK); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
responses.Done()
|
responses.Done()
|
||||||
|
@ -150,7 +151,7 @@ func TestMaxInFlight(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Validate that non-accounted URLs still work
|
// Validate that non-accounted URLs still work. use a path regex match
|
||||||
if err := expectHTTP(server.URL+"/dontwait/watch", http.StatusOK); err != nil {
|
if err := expectHTTP(server.URL+"/dontwait/watch", http.StatusOK); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -650,9 +650,10 @@ func (s *GenericAPIServer) Run(options *ServerRunOptions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
longRunningRE := regexp.MustCompile(options.LongRunningRequestRE)
|
longRunningRE := regexp.MustCompile(options.LongRunningRequestRE)
|
||||||
|
longRunningRequestCheck := apiserver.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"})
|
||||||
longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) {
|
longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) {
|
||||||
// TODO unify this with apiserver.MaxInFlightLimit
|
// TODO unify this with apiserver.MaxInFlightLimit
|
||||||
if longRunningRE.MatchString(req.URL.Path) || req.URL.Query().Get("watch") == "true" {
|
if longRunningRequestCheck(req) {
|
||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
return time.After(time.Minute), ""
|
return time.After(time.Minute), ""
|
||||||
|
@ -662,7 +663,7 @@ func (s *GenericAPIServer) Run(options *ServerRunOptions) {
|
||||||
handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout)
|
handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout)
|
||||||
secureServer := &http.Server{
|
secureServer := &http.Server{
|
||||||
Addr: secureLocation,
|
Addr: secureLocation,
|
||||||
Handler: apiserver.MaxInFlightLimit(sem, longRunningRE, apiserver.RecoverPanics(handler)),
|
Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, apiserver.RecoverPanics(handler)),
|
||||||
MaxHeaderBytes: 1 << 20,
|
MaxHeaderBytes: 1 << 20,
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
// Change default from SSLv3 to TLSv1.0 (because of POODLE vulnerability)
|
// Change default from SSLv3 to TLSv1.0 (because of POODLE vulnerability)
|
||||||
|
|
Loading…
Reference in New Issue