add deregistration for paths

pull/6/head
deads2k 2017-04-12 13:26:00 -04:00
parent 5ba21e83b9
commit cd950364e5
8 changed files with 184 additions and 96 deletions

View File

@ -21,13 +21,21 @@ import (
"net/http"
"runtime/debug"
"sort"
"sync"
"sync/atomic"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
)
// PathRecorderMux wraps a mux object and records the registered exposedPaths. It is _not_ go routine safe.
// PathRecorderMux wraps a mux object and records the registered exposedPaths.
type PathRecorderMux struct {
mux *http.ServeMux
lock sync.Mutex
pathToHandler map[string]http.Handler
// mux stores an *http.ServeMux and is used to handle the actual serving
mux atomic.Value
// exposedPaths is the list of paths that should be shown at /
exposedPaths []string
// pathStacks holds the stacks of all registered paths. This allows us to show a more helpful message
@ -37,10 +45,15 @@ type PathRecorderMux struct {
// NewPathRecorderMux creates a new PathRecorderMux with the given mux as the base mux.
func NewPathRecorderMux() *PathRecorderMux {
return &PathRecorderMux{
mux: http.NewServeMux(),
pathStacks: map[string]string{},
ret := &PathRecorderMux{
pathToHandler: map[string]http.Handler{},
mux: atomic.Value{},
exposedPaths: []string{},
pathStacks: map[string]string{},
}
ret.mux.Store(http.NewServeMux())
return ret
}
// ListedPaths returns the registered handler exposedPaths.
@ -58,41 +71,81 @@ func (m *PathRecorderMux) trackCallers(path string) {
m.pathStacks[path] = string(debug.Stack())
}
// refreshMuxLocked creates a new mux and must be called while locked. Otherwise the view of handlers may
// not be consistent
func (m *PathRecorderMux) refreshMuxLocked() {
mux := http.NewServeMux()
for path, handler := range m.pathToHandler {
mux.Handle(path, handler)
}
m.mux.Store(mux)
}
// Unregister removes a path from the mux.
func (m *PathRecorderMux) Unregister(path string) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.pathToHandler, path)
delete(m.pathStacks, path)
for i := range m.exposedPaths {
if m.exposedPaths[i] == path {
m.exposedPaths = append(m.exposedPaths[:i], m.exposedPaths[i+1:]...)
break
}
}
m.refreshMuxLocked()
}
// Handle registers the handler for the given pattern.
// If a handler already exists for pattern, Handle panics.
func (m *PathRecorderMux) Handle(path string, handler http.Handler) {
m.lock.Lock()
defer m.lock.Unlock()
m.trackCallers(path)
m.exposedPaths = append(m.exposedPaths, path)
m.mux.Handle(path, handler)
m.pathToHandler[path] = handler
m.refreshMuxLocked()
}
// HandleFunc registers the handler function for the given pattern.
// If a handler already exists for pattern, Handle panics.
func (m *PathRecorderMux) HandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) {
m.lock.Lock()
defer m.lock.Unlock()
m.trackCallers(path)
m.exposedPaths = append(m.exposedPaths, path)
m.mux.HandleFunc(path, handler)
m.pathToHandler[path] = http.HandlerFunc(handler)
m.refreshMuxLocked()
}
// UnlistedHandle registers the handler for the given pattern, but doesn't list it.
// If a handler already exists for pattern, Handle panics.
func (m *PathRecorderMux) UnlistedHandle(path string, handler http.Handler) {
m.lock.Lock()
defer m.lock.Unlock()
m.trackCallers(path)
m.mux.Handle(path, handler)
m.pathToHandler[path] = handler
m.refreshMuxLocked()
}
// UnlistedHandleFunc registers the handler function for the given pattern, but doesn't list it.
// If a handler already exists for pattern, Handle panics.
func (m *PathRecorderMux) UnlistedHandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) {
m.lock.Lock()
defer m.lock.Unlock()
m.trackCallers(path)
m.mux.HandleFunc(path, handler)
m.pathToHandler[path] = http.HandlerFunc(handler)
m.refreshMuxLocked()
}
// ServeHTTP makes it an http.Handler
func (m *PathRecorderMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.mux.ServeHTTP(w, r)
m.mux.Load().(*http.ServeMux).ServeHTTP(w, r)
}

View File

@ -18,6 +18,7 @@ package mux
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
@ -30,3 +31,39 @@ func TestSecretHandlers(t *testing.T) {
assert.NotContains(t, c.ListedPaths(), "/secret")
assert.Contains(t, c.ListedPaths(), "/nonswagger")
}
func TestUnregisterHandlers(t *testing.T) {
first := 0
second := 0
c := NewPathRecorderMux()
s := httptest.NewServer(c)
defer s.Close()
c.UnlistedHandleFunc("/secret", func(http.ResponseWriter, *http.Request) {})
c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) {
first = first + 1
})
assert.NotContains(t, c.ListedPaths(), "/secret")
assert.Contains(t, c.ListedPaths(), "/nonswagger")
resp, _ := http.Get(s.URL + "/nonswagger")
assert.Equal(t, first, 1)
assert.Equal(t, resp.StatusCode, http.StatusOK)
c.Unregister("/nonswagger")
assert.NotContains(t, c.ListedPaths(), "/nonswagger")
resp, _ = http.Get(s.URL + "/nonswagger")
assert.Equal(t, first, 1)
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) {
second = second + 1
})
assert.Contains(t, c.ListedPaths(), "/nonswagger")
resp, _ = http.Get(s.URL + "/nonswagger")
assert.Equal(t, first, 1)
assert.Equal(t, second, 1)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}

View File

@ -19,6 +19,8 @@ package apiregistration
import (
"sort"
"strings"
"k8s.io/apimachinery/pkg/runtime/schema"
)
func SortedByGroup(servers []*APIService) [][]*APIService {
@ -57,3 +59,10 @@ func (s ByPriority) Less(i, j int) bool {
}
return s[i].Spec.Priority < s[j].Spec.Priority
}
// APIServiceNameToGroupVersion returns the GroupVersion for a given apiServiceName. The name
// must be valid, but any object you get back from an informer will be valid.
func APIServiceNameToGroupVersion(apiServiceName string) schema.GroupVersion {
tokens := strings.SplitN(apiServiceName, ".", 2)
return schema.GroupVersion{Group: tokens[1], Version: tokens[0]}
}

View File

@ -220,8 +220,8 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de
}
proxyHandler.updateAPIService(apiService, destinationHost)
s.proxyHandlers[apiService.Name] = proxyHandler
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath, proxyHandler)
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath+"/", proxyHandler)
s.GenericAPIServer.FallThroughHandler.Handle(proxyPath, proxyHandler)
s.GenericAPIServer.FallThroughHandler.UnlistedHandle(proxyPath+"/", proxyHandler)
// if we're dealing with the legacy group, we're done here
if apiService.Name == legacyAPIServiceName {
@ -241,20 +241,28 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de
lister: s.lister,
serviceLister: s.serviceLister,
endpointsLister: s.endpointsLister,
delegate: s.GenericAPIServer.FallThroughHandler,
delegate: s.delegateHandler,
}
// aggregation is protected
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath, groupDiscoveryHandler)
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath+"/", groupDiscoveryHandler)
s.GenericAPIServer.FallThroughHandler.Handle(groupPath, groupDiscoveryHandler)
s.GenericAPIServer.FallThroughHandler.UnlistedHandle(groupPath+"/", groupDiscoveryHandler)
s.handledGroups.Insert(apiService.Spec.Group)
}
// RemoveAPIService removes the APIService from being handled. Later on it will disable the proxy endpoint.
// Right now it does nothing because our handler has to properly 404 itself since muxes don't unregister
// RemoveAPIService removes the APIService from being handled. It is not thread-safe, so only call it on one thread at a time please.
// It's a slow moving API, so its ok to run the controller on a single thread.
func (s *APIAggregator) RemoveAPIService(apiServiceName string) {
proxyHandler, exists := s.proxyHandlers[apiServiceName]
if !exists {
return
version := apiregistration.APIServiceNameToGroupVersion(apiServiceName)
proxyPath := "/apis/" + version.Group + "/" + version.Version
// v1. is a special case for the legacy API. It proxies to a wider set of endpoints.
if apiServiceName == legacyAPIServiceName {
proxyPath = "/api"
}
proxyHandler.removeAPIService()
s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath)
s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath + "/")
delete(s.proxyHandlers, apiServiceName)
// TODO unregister group level discovery when there are no more versions for the group
// We don't need this right away because the handler properly delegates when no versions are present
}

View File

@ -18,7 +18,6 @@ package apiserver
import (
"net/http"
"strings"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -30,24 +29,9 @@ import (
apiregistrationapi "k8s.io/kube-aggregator/pkg/apis/apiregistration"
apiregistrationv1alpha1api "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1alpha1"
informers "k8s.io/kube-aggregator/pkg/client/informers/internalversion/apiregistration/internalversion"
listers "k8s.io/kube-aggregator/pkg/client/listers/apiregistration/internalversion"
)
// WithAPIs adds the handling for /apis and /apis/<group: -apiregistration.k8s.io>.
func WithAPIs(handler http.Handler, codecs serializer.CodecFactory, informer informers.APIServiceInformer, serviceLister v1listers.ServiceLister, endpointsLister v1listers.EndpointsLister) http.Handler {
apisHandler := &apisHandler{
codecs: codecs,
lister: informer.Lister(),
delegate: handler,
serviceLister: serviceLister,
endpointsLister: endpointsLister,
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
apisHandler.ServeHTTP(w, req)
})
}
// apisHandler serves the `/apis` endpoint.
// This is registered as a filter so that it never collides with any explictly registered endpoints
type apisHandler struct {
@ -75,11 +59,6 @@ var discoveryGroup = metav1.APIGroup{
}
func (r *apisHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// if the URL is for OUR api group, serve it normally
if strings.HasPrefix(req.URL.Path+"/", "/apis/"+apiregistrationapi.GroupName+"/") {
r.delegate.ServeHTTP(w, req)
return
}
// don't handle URLs that aren't /apis
if req.URL.Path != "/apis" && req.URL.Path != "/apis/" {
r.delegate.ServeHTTP(w, req)
@ -210,7 +189,7 @@ func (r *apiGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
if len(apiServicesForGroup) == 0 {
http.Error(w, "", http.StatusNotFound)
r.delegate.ServeHTTP(w, req)
return
}

View File

@ -338,13 +338,13 @@ func TestAPIGroupMissing(t *testing.T) {
t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode)
}
// groupName still has no api services for it (like it was deleted), it should 404
// groupName still has no api services for it (like it was deleted), it should delegate
resp, err = http.Get(server.URL + "/apis/groupName/")
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected %v, got %v", http.StatusNotFound, resp.StatusCode)
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode)
}
// missing group should delegate still has no api services for it (like it was deleted)

View File

@ -64,7 +64,12 @@ type proxyHandlingInfo struct {
}
func (r *proxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handlingInfo := r.handlingInfo.Load().(proxyHandlingInfo)
value := r.handlingInfo.Load()
if value == nil {
r.localDelegate.ServeHTTP(w, req)
return
}
handlingInfo := value.(proxyHandlingInfo)
if handlingInfo.local {
r.localDelegate.ServeHTTP(w, req)
return
@ -197,7 +202,3 @@ func (r *proxyHandler) updateAPIService(apiService *apiregistrationapi.APIServic
newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig)
r.handlingInfo.Store(newInfo)
}
func (r *proxyHandler) removeAPIService() {
r.handlingInfo.Store(proxyHandlingInfo{})
}

View File

@ -83,13 +83,6 @@ func TestProxyHandler(t *testing.T) {
targetServer := httptest.NewTLSServer(target)
defer targetServer.Close()
handler := &proxyHandler{
localDelegate: http.NewServeMux(),
}
server := httptest.NewServer(handler)
defer server.Close()
tests := map[string]struct {
user user.Info
path string
@ -161,45 +154,53 @@ func TestProxyHandler(t *testing.T) {
for name, tc := range tests {
target.Reset()
handler.contextMapper = &fakeRequestContextMapper{user: tc.user}
handler.removeAPIService()
if tc.apiService != nil {
handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc")
curr := handler.handlingInfo.Load().(proxyHandlingInfo)
curr.destinationHost = targetServer.Listener.Addr().String()
handler.handlingInfo.Store(curr)
}
resp, err := http.Get(server.URL + tc.path)
if err != nil {
t.Errorf("%s: %v", name, err)
continue
}
if e, a := tc.expectedStatusCode, resp.StatusCode; e != a {
body, _ := httputil.DumpResponse(resp, true)
t.Logf("%s: %v", name, string(body))
t.Errorf("%s: expected %v, got %v", name, e, a)
continue
}
bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%s: %v", name, err)
continue
}
if !strings.Contains(string(bytes), tc.expectedBody) {
t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes))
continue
}
func() {
handler := &proxyHandler{
localDelegate: http.NewServeMux(),
}
handler.contextMapper = &fakeRequestContextMapper{user: tc.user}
server := httptest.NewServer(handler)
defer server.Close()
if e, a := tc.expectedCalled, target.called; e != a {
t.Errorf("%s: expected %v, got %v", name, e, a)
continue
}
// this varies every test
delete(target.headers, "X-Forwarded-Host")
if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) {
t.Errorf("%s: expected %v, got %v", name, e, a)
continue
}
if tc.apiService != nil {
handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc")
curr := handler.handlingInfo.Load().(proxyHandlingInfo)
curr.destinationHost = targetServer.Listener.Addr().String()
handler.handlingInfo.Store(curr)
}
resp, err := http.Get(server.URL + tc.path)
if err != nil {
t.Errorf("%s: %v", name, err)
return
}
if e, a := tc.expectedStatusCode, resp.StatusCode; e != a {
body, _ := httputil.DumpResponse(resp, true)
t.Logf("%s: %v", name, string(body))
t.Errorf("%s: expected %v, got %v", name, e, a)
return
}
bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%s: %v", name, err)
return
}
if !strings.Contains(string(bytes), tc.expectedBody) {
t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes))
return
}
if e, a := tc.expectedCalled, target.called; e != a {
t.Errorf("%s: expected %v, got %v", name, e, a)
return
}
// this varies every test
delete(target.headers, "X-Forwarded-Host")
if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) {
t.Errorf("%s: expected %v, got %v", name, e, a)
return
}
}()
}
}