From cd950364e5259659a771caf5b564de7a8319969b Mon Sep 17 00:00:00 2001 From: deads2k Date: Wed, 12 Apr 2017 13:26:00 -0400 Subject: [PATCH] add deregistration for paths --- .../apiserver/pkg/server/mux/pathrecorder.go | 73 +++++++++++++-- .../pkg/server/mux/pathrecorder_test.go | 37 ++++++++ .../pkg/apis/apiregistration/helpers.go | 9 ++ .../pkg/apiserver/apiserver.go | 30 +++--- .../pkg/apiserver/handler_apis.go | 23 +---- .../pkg/apiserver/handler_apis_test.go | 6 +- .../pkg/apiserver/handler_proxy.go | 11 ++- .../pkg/apiserver/handler_proxy_test.go | 91 ++++++++++--------- 8 files changed, 184 insertions(+), 96 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder.go b/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder.go index 40a9e75cfc..7a343b3697 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder.go +++ b/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder.go @@ -21,13 +21,21 @@ import ( "net/http" "runtime/debug" "sort" + "sync" + "sync/atomic" utilruntime "k8s.io/apimachinery/pkg/util/runtime" ) -// PathRecorderMux wraps a mux object and records the registered exposedPaths. It is _not_ go routine safe. +// PathRecorderMux wraps a mux object and records the registered exposedPaths. type PathRecorderMux struct { - mux *http.ServeMux + lock sync.Mutex + pathToHandler map[string]http.Handler + + // mux stores an *http.ServeMux and is used to handle the actual serving + mux atomic.Value + + // exposedPaths is the list of paths that should be shown at / exposedPaths []string // pathStacks holds the stacks of all registered paths. This allows us to show a more helpful message @@ -37,10 +45,15 @@ type PathRecorderMux struct { // NewPathRecorderMux creates a new PathRecorderMux with the given mux as the base mux. func NewPathRecorderMux() *PathRecorderMux { - return &PathRecorderMux{ - mux: http.NewServeMux(), - pathStacks: map[string]string{}, + ret := &PathRecorderMux{ + pathToHandler: map[string]http.Handler{}, + mux: atomic.Value{}, + exposedPaths: []string{}, + pathStacks: map[string]string{}, } + + ret.mux.Store(http.NewServeMux()) + return ret } // ListedPaths returns the registered handler exposedPaths. @@ -58,41 +71,81 @@ func (m *PathRecorderMux) trackCallers(path string) { m.pathStacks[path] = string(debug.Stack()) } +// refreshMuxLocked creates a new mux and must be called while locked. Otherwise the view of handlers may +// not be consistent +func (m *PathRecorderMux) refreshMuxLocked() { + mux := http.NewServeMux() + for path, handler := range m.pathToHandler { + mux.Handle(path, handler) + } + + m.mux.Store(mux) +} + +// Unregister removes a path from the mux. +func (m *PathRecorderMux) Unregister(path string) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.pathToHandler, path) + delete(m.pathStacks, path) + for i := range m.exposedPaths { + if m.exposedPaths[i] == path { + m.exposedPaths = append(m.exposedPaths[:i], m.exposedPaths[i+1:]...) + break + } + } + + m.refreshMuxLocked() +} + // Handle registers the handler for the given pattern. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) Handle(path string, handler http.Handler) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) m.exposedPaths = append(m.exposedPaths, path) - m.mux.Handle(path, handler) + m.pathToHandler[path] = handler + m.refreshMuxLocked() } // HandleFunc registers the handler function for the given pattern. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) HandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) m.exposedPaths = append(m.exposedPaths, path) - m.mux.HandleFunc(path, handler) + m.pathToHandler[path] = http.HandlerFunc(handler) + m.refreshMuxLocked() } // UnlistedHandle registers the handler for the given pattern, but doesn't list it. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) UnlistedHandle(path string, handler http.Handler) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) - m.mux.Handle(path, handler) + m.pathToHandler[path] = handler + m.refreshMuxLocked() } // UnlistedHandleFunc registers the handler function for the given pattern, but doesn't list it. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) UnlistedHandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) - m.mux.HandleFunc(path, handler) + m.pathToHandler[path] = http.HandlerFunc(handler) + m.refreshMuxLocked() } // ServeHTTP makes it an http.Handler func (m *PathRecorderMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - m.mux.ServeHTTP(w, r) + m.mux.Load().(*http.ServeMux).ServeHTTP(w, r) } diff --git a/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder_test.go b/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder_test.go index 3d7e6b6108..9bf64fe787 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/mux/pathrecorder_test.go @@ -18,6 +18,7 @@ package mux import ( "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -30,3 +31,39 @@ func TestSecretHandlers(t *testing.T) { assert.NotContains(t, c.ListedPaths(), "/secret") assert.Contains(t, c.ListedPaths(), "/nonswagger") } + +func TestUnregisterHandlers(t *testing.T) { + first := 0 + second := 0 + + c := NewPathRecorderMux() + s := httptest.NewServer(c) + defer s.Close() + + c.UnlistedHandleFunc("/secret", func(http.ResponseWriter, *http.Request) {}) + c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) { + first = first + 1 + }) + assert.NotContains(t, c.ListedPaths(), "/secret") + assert.Contains(t, c.ListedPaths(), "/nonswagger") + + resp, _ := http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, resp.StatusCode, http.StatusOK) + + c.Unregister("/nonswagger") + assert.NotContains(t, c.ListedPaths(), "/nonswagger") + + resp, _ = http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + + c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) { + second = second + 1 + }) + assert.Contains(t, c.ListedPaths(), "/nonswagger") + resp, _ = http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, second, 1) + assert.Equal(t, resp.StatusCode, http.StatusOK) +} diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apis/apiregistration/helpers.go b/staging/src/k8s.io/kube-aggregator/pkg/apis/apiregistration/helpers.go index 64527bdf67..655a66e774 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apis/apiregistration/helpers.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apis/apiregistration/helpers.go @@ -19,6 +19,8 @@ package apiregistration import ( "sort" "strings" + + "k8s.io/apimachinery/pkg/runtime/schema" ) func SortedByGroup(servers []*APIService) [][]*APIService { @@ -57,3 +59,10 @@ func (s ByPriority) Less(i, j int) bool { } return s[i].Spec.Priority < s[j].Spec.Priority } + +// APIServiceNameToGroupVersion returns the GroupVersion for a given apiServiceName. The name +// must be valid, but any object you get back from an informer will be valid. +func APIServiceNameToGroupVersion(apiServiceName string) schema.GroupVersion { + tokens := strings.SplitN(apiServiceName, ".", 2) + return schema.GroupVersion{Group: tokens[1], Version: tokens[0]} +} diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/apiserver.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/apiserver.go index c698deaebe..7378c5bb36 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/apiserver.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/apiserver.go @@ -220,8 +220,8 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de } proxyHandler.updateAPIService(apiService, destinationHost) s.proxyHandlers[apiService.Name] = proxyHandler - s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath, proxyHandler) - s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath+"/", proxyHandler) + s.GenericAPIServer.FallThroughHandler.Handle(proxyPath, proxyHandler) + s.GenericAPIServer.FallThroughHandler.UnlistedHandle(proxyPath+"/", proxyHandler) // if we're dealing with the legacy group, we're done here if apiService.Name == legacyAPIServiceName { @@ -241,20 +241,28 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de lister: s.lister, serviceLister: s.serviceLister, endpointsLister: s.endpointsLister, - delegate: s.GenericAPIServer.FallThroughHandler, + delegate: s.delegateHandler, } // aggregation is protected - s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath, groupDiscoveryHandler) - s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath+"/", groupDiscoveryHandler) + s.GenericAPIServer.FallThroughHandler.Handle(groupPath, groupDiscoveryHandler) + s.GenericAPIServer.FallThroughHandler.UnlistedHandle(groupPath+"/", groupDiscoveryHandler) s.handledGroups.Insert(apiService.Spec.Group) } -// RemoveAPIService removes the APIService from being handled. Later on it will disable the proxy endpoint. -// Right now it does nothing because our handler has to properly 404 itself since muxes don't unregister +// RemoveAPIService removes the APIService from being handled. It is not thread-safe, so only call it on one thread at a time please. +// It's a slow moving API, so its ok to run the controller on a single thread. func (s *APIAggregator) RemoveAPIService(apiServiceName string) { - proxyHandler, exists := s.proxyHandlers[apiServiceName] - if !exists { - return + version := apiregistration.APIServiceNameToGroupVersion(apiServiceName) + + proxyPath := "/apis/" + version.Group + "/" + version.Version + // v1. is a special case for the legacy API. It proxies to a wider set of endpoints. + if apiServiceName == legacyAPIServiceName { + proxyPath = "/api" } - proxyHandler.removeAPIService() + s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath) + s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath + "/") + delete(s.proxyHandlers, apiServiceName) + + // TODO unregister group level discovery when there are no more versions for the group + // We don't need this right away because the handler properly delegates when no versions are present } diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis.go index 7d3c60abda..5eafa6f4bf 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis.go @@ -18,7 +18,6 @@ package apiserver import ( "net/http" - "strings" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -30,24 +29,9 @@ import ( apiregistrationapi "k8s.io/kube-aggregator/pkg/apis/apiregistration" apiregistrationv1alpha1api "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1alpha1" - informers "k8s.io/kube-aggregator/pkg/client/informers/internalversion/apiregistration/internalversion" listers "k8s.io/kube-aggregator/pkg/client/listers/apiregistration/internalversion" ) -// WithAPIs adds the handling for /apis and /apis/. -func WithAPIs(handler http.Handler, codecs serializer.CodecFactory, informer informers.APIServiceInformer, serviceLister v1listers.ServiceLister, endpointsLister v1listers.EndpointsLister) http.Handler { - apisHandler := &apisHandler{ - codecs: codecs, - lister: informer.Lister(), - delegate: handler, - serviceLister: serviceLister, - endpointsLister: endpointsLister, - } - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - apisHandler.ServeHTTP(w, req) - }) -} - // apisHandler serves the `/apis` endpoint. // This is registered as a filter so that it never collides with any explictly registered endpoints type apisHandler struct { @@ -75,11 +59,6 @@ var discoveryGroup = metav1.APIGroup{ } func (r *apisHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // if the URL is for OUR api group, serve it normally - if strings.HasPrefix(req.URL.Path+"/", "/apis/"+apiregistrationapi.GroupName+"/") { - r.delegate.ServeHTTP(w, req) - return - } // don't handle URLs that aren't /apis if req.URL.Path != "/apis" && req.URL.Path != "/apis/" { r.delegate.ServeHTTP(w, req) @@ -210,7 +189,7 @@ func (r *apiGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if len(apiServicesForGroup) == 0 { - http.Error(w, "", http.StatusNotFound) + r.delegate.ServeHTTP(w, req) return } diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis_test.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis_test.go index f955fbf0ef..f5d6ffebed 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis_test.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_apis_test.go @@ -338,13 +338,13 @@ func TestAPIGroupMissing(t *testing.T) { t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode) } - // groupName still has no api services for it (like it was deleted), it should 404 + // groupName still has no api services for it (like it was deleted), it should delegate resp, err = http.Get(server.URL + "/apis/groupName/") if err != nil { t.Fatal(err) } - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("expected %v, got %v", http.StatusNotFound, resp.StatusCode) + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode) } // missing group should delegate still has no api services for it (like it was deleted) diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go index d3fdc92a50..b925a4ef0a 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go @@ -64,7 +64,12 @@ type proxyHandlingInfo struct { } func (r *proxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - handlingInfo := r.handlingInfo.Load().(proxyHandlingInfo) + value := r.handlingInfo.Load() + if value == nil { + r.localDelegate.ServeHTTP(w, req) + return + } + handlingInfo := value.(proxyHandlingInfo) if handlingInfo.local { r.localDelegate.ServeHTTP(w, req) return @@ -197,7 +202,3 @@ func (r *proxyHandler) updateAPIService(apiService *apiregistrationapi.APIServic newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig) r.handlingInfo.Store(newInfo) } - -func (r *proxyHandler) removeAPIService() { - r.handlingInfo.Store(proxyHandlingInfo{}) -} diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go index e819246d53..9c24dd4b02 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go @@ -83,13 +83,6 @@ func TestProxyHandler(t *testing.T) { targetServer := httptest.NewTLSServer(target) defer targetServer.Close() - handler := &proxyHandler{ - localDelegate: http.NewServeMux(), - } - - server := httptest.NewServer(handler) - defer server.Close() - tests := map[string]struct { user user.Info path string @@ -161,45 +154,53 @@ func TestProxyHandler(t *testing.T) { for name, tc := range tests { target.Reset() - handler.contextMapper = &fakeRequestContextMapper{user: tc.user} - handler.removeAPIService() - if tc.apiService != nil { - handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc") - curr := handler.handlingInfo.Load().(proxyHandlingInfo) - curr.destinationHost = targetServer.Listener.Addr().String() - handler.handlingInfo.Store(curr) - } - resp, err := http.Get(server.URL + tc.path) - if err != nil { - t.Errorf("%s: %v", name, err) - continue - } - if e, a := tc.expectedStatusCode, resp.StatusCode; e != a { - body, _ := httputil.DumpResponse(resp, true) - t.Logf("%s: %v", name, string(body)) - t.Errorf("%s: expected %v, got %v", name, e, a) - continue - } - bytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("%s: %v", name, err) - continue - } - if !strings.Contains(string(bytes), tc.expectedBody) { - t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes)) - continue - } + func() { + handler := &proxyHandler{ + localDelegate: http.NewServeMux(), + } + handler.contextMapper = &fakeRequestContextMapper{user: tc.user} + server := httptest.NewServer(handler) + defer server.Close() - if e, a := tc.expectedCalled, target.called; e != a { - t.Errorf("%s: expected %v, got %v", name, e, a) - continue - } - // this varies every test - delete(target.headers, "X-Forwarded-Host") - if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) { - t.Errorf("%s: expected %v, got %v", name, e, a) - continue - } + if tc.apiService != nil { + handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc") + curr := handler.handlingInfo.Load().(proxyHandlingInfo) + curr.destinationHost = targetServer.Listener.Addr().String() + handler.handlingInfo.Store(curr) + } + + resp, err := http.Get(server.URL + tc.path) + if err != nil { + t.Errorf("%s: %v", name, err) + return + } + if e, a := tc.expectedStatusCode, resp.StatusCode; e != a { + body, _ := httputil.DumpResponse(resp, true) + t.Logf("%s: %v", name, string(body)) + t.Errorf("%s: expected %v, got %v", name, e, a) + return + } + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("%s: %v", name, err) + return + } + if !strings.Contains(string(bytes), tc.expectedBody) { + t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes)) + return + } + + if e, a := tc.expectedCalled, target.called; e != a { + t.Errorf("%s: expected %v, got %v", name, e, a) + return + } + // this varies every test + delete(target.headers, "X-Forwarded-Host") + if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) { + t.Errorf("%s: expected %v, got %v", name, e, a) + return + } + }() } }