mirror of https://github.com/k3s-io/k3s
update long running to handle recommend watch mechanism
parent
8f3d7110d5
commit
357aebc89c
|
@ -72,13 +72,35 @@ func ReadOnly(handler http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
type LongRunningRequestCheck func(r *http.Request) bool
|
||||
|
||||
// BasicLongRunningRequestCheck pathRegex operates against the url path, the queryParams match is case insensitive.
|
||||
// Any one match flags the request.
|
||||
// TODO tighten this check to eliminate the abuse potential by malicious clients that start setting queryParameters
|
||||
// to bypass the rate limitter. This could be done using a full parse and special casing the bits we need.
|
||||
func BasicLongRunningRequestCheck(pathRegex *regexp.Regexp, queryParams map[string]string) LongRunningRequestCheck {
|
||||
return func(r *http.Request) bool {
|
||||
if pathRegex.MatchString(r.URL.Path) {
|
||||
return true
|
||||
}
|
||||
|
||||
for key, expectedValue := range queryParams {
|
||||
if strings.ToLower(expectedValue) == strings.ToLower(r.URL.Query().Get(key)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// MaxInFlight limits the number of in-flight requests to buffer size of the passed in channel.
|
||||
func MaxInFlightLimit(c chan bool, longRunningRequestRE *regexp.Regexp, handler http.Handler) http.Handler {
|
||||
func MaxInFlightLimit(c chan bool, longRunningRequestCheck LongRunningRequestCheck, handler http.Handler) http.Handler {
|
||||
if c == nil {
|
||||
return handler
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if longRunningRequestRE.MatchString(r.URL.Path) {
|
||||
if longRunningRequestCheck(r) {
|
||||
// Skip tracking long running events.
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
|
|
|
@ -77,6 +77,7 @@ func TestMaxInFlight(t *testing.T) {
|
|||
// notAccountedPathsRegexp specifies paths requests to which we don't account into
|
||||
// requests in flight.
|
||||
notAccountedPathsRegexp := regexp.MustCompile(".*\\/watch")
|
||||
longRunningRequestCheck := BasicLongRunningRequestCheck(notAccountedPathsRegexp, map[string]string{"watch": "true"})
|
||||
|
||||
// Calls is used to wait until all server calls are received. We are sending
|
||||
// AllowedInflightRequestsNo of 'long' not-accounted requests and the same number of
|
||||
|
@ -98,7 +99,7 @@ func TestMaxInFlight(t *testing.T) {
|
|||
server := httptest.NewServer(
|
||||
MaxInFlightLimit(
|
||||
inflightRequestsChannel,
|
||||
notAccountedPathsRegexp,
|
||||
longRunningRequestCheck,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// A short, accounted request that does not wait for block WaitGroup.
|
||||
if strings.Contains(r.URL.Path, "dontwait") {
|
||||
|
@ -114,11 +115,11 @@ func TestMaxInFlight(t *testing.T) {
|
|||
// TODO: Uncomment when fix #19254
|
||||
// defer server.Close()
|
||||
|
||||
// These should hang, but not affect accounting.
|
||||
// These should hang, but not affect accounting. use a query param match
|
||||
for i := 0; i < AllowedInflightRequestsNo; i++ {
|
||||
// These should hang waiting on block...
|
||||
go func() {
|
||||
if err := expectHTTP(server.URL+"/foo/bar/watch", http.StatusOK); err != nil {
|
||||
if err := expectHTTP(server.URL+"/foo/bar?watch=true", http.StatusOK); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
responses.Done()
|
||||
|
@ -150,7 +151,7 @@ func TestMaxInFlight(t *testing.T) {
|
|||
t.Error(err)
|
||||
}
|
||||
}
|
||||
// Validate that non-accounted URLs still work
|
||||
// Validate that non-accounted URLs still work. use a path regex match
|
||||
if err := expectHTTP(server.URL+"/dontwait/watch", http.StatusOK); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
|
@ -650,9 +650,10 @@ func (s *GenericAPIServer) Run(options *ServerRunOptions) {
|
|||
}
|
||||
|
||||
longRunningRE := regexp.MustCompile(options.LongRunningRequestRE)
|
||||
longRunningRequestCheck := apiserver.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"})
|
||||
longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) {
|
||||
// TODO unify this with apiserver.MaxInFlightLimit
|
||||
if longRunningRE.MatchString(req.URL.Path) || req.URL.Query().Get("watch") == "true" {
|
||||
if longRunningRequestCheck(req) {
|
||||
return nil, ""
|
||||
}
|
||||
return time.After(time.Minute), ""
|
||||
|
@ -662,7 +663,7 @@ func (s *GenericAPIServer) Run(options *ServerRunOptions) {
|
|||
handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout)
|
||||
secureServer := &http.Server{
|
||||
Addr: secureLocation,
|
||||
Handler: apiserver.MaxInFlightLimit(sem, longRunningRE, apiserver.RecoverPanics(handler)),
|
||||
Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, apiserver.RecoverPanics(handler)),
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
TLSConfig: &tls.Config{
|
||||
// Change default from SSLv3 to TLSv1.0 (because of POODLE vulnerability)
|
||||
|
|
Loading…
Reference in New Issue