mirror of https://github.com/k3s-io/k3s
add deregistration for paths
parent
5ba21e83b9
commit
cd950364e5
|
@ -21,13 +21,21 @@ import (
|
|||
"net/http"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// PathRecorderMux wraps a mux object and records the registered exposedPaths. It is _not_ go routine safe.
|
||||
// PathRecorderMux wraps a mux object and records the registered exposedPaths.
|
||||
type PathRecorderMux struct {
|
||||
mux *http.ServeMux
|
||||
lock sync.Mutex
|
||||
pathToHandler map[string]http.Handler
|
||||
|
||||
// mux stores an *http.ServeMux and is used to handle the actual serving
|
||||
mux atomic.Value
|
||||
|
||||
// exposedPaths is the list of paths that should be shown at /
|
||||
exposedPaths []string
|
||||
|
||||
// pathStacks holds the stacks of all registered paths. This allows us to show a more helpful message
|
||||
|
@ -37,10 +45,15 @@ type PathRecorderMux struct {
|
|||
|
||||
// NewPathRecorderMux creates a new PathRecorderMux with the given mux as the base mux.
|
||||
func NewPathRecorderMux() *PathRecorderMux {
|
||||
return &PathRecorderMux{
|
||||
mux: http.NewServeMux(),
|
||||
pathStacks: map[string]string{},
|
||||
ret := &PathRecorderMux{
|
||||
pathToHandler: map[string]http.Handler{},
|
||||
mux: atomic.Value{},
|
||||
exposedPaths: []string{},
|
||||
pathStacks: map[string]string{},
|
||||
}
|
||||
|
||||
ret.mux.Store(http.NewServeMux())
|
||||
return ret
|
||||
}
|
||||
|
||||
// ListedPaths returns the registered handler exposedPaths.
|
||||
|
@ -58,41 +71,81 @@ func (m *PathRecorderMux) trackCallers(path string) {
|
|||
m.pathStacks[path] = string(debug.Stack())
|
||||
}
|
||||
|
||||
// refreshMuxLocked creates a new mux and must be called while locked. Otherwise the view of handlers may
|
||||
// not be consistent
|
||||
func (m *PathRecorderMux) refreshMuxLocked() {
|
||||
mux := http.NewServeMux()
|
||||
for path, handler := range m.pathToHandler {
|
||||
mux.Handle(path, handler)
|
||||
}
|
||||
|
||||
m.mux.Store(mux)
|
||||
}
|
||||
|
||||
// Unregister removes a path from the mux.
|
||||
func (m *PathRecorderMux) Unregister(path string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
delete(m.pathToHandler, path)
|
||||
delete(m.pathStacks, path)
|
||||
for i := range m.exposedPaths {
|
||||
if m.exposedPaths[i] == path {
|
||||
m.exposedPaths = append(m.exposedPaths[:i], m.exposedPaths[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
m.refreshMuxLocked()
|
||||
}
|
||||
|
||||
// Handle registers the handler for the given pattern.
|
||||
// If a handler already exists for pattern, Handle panics.
|
||||
func (m *PathRecorderMux) Handle(path string, handler http.Handler) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.trackCallers(path)
|
||||
|
||||
m.exposedPaths = append(m.exposedPaths, path)
|
||||
m.mux.Handle(path, handler)
|
||||
m.pathToHandler[path] = handler
|
||||
m.refreshMuxLocked()
|
||||
}
|
||||
|
||||
// HandleFunc registers the handler function for the given pattern.
|
||||
// If a handler already exists for pattern, Handle panics.
|
||||
func (m *PathRecorderMux) HandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.trackCallers(path)
|
||||
|
||||
m.exposedPaths = append(m.exposedPaths, path)
|
||||
m.mux.HandleFunc(path, handler)
|
||||
m.pathToHandler[path] = http.HandlerFunc(handler)
|
||||
m.refreshMuxLocked()
|
||||
}
|
||||
|
||||
// UnlistedHandle registers the handler for the given pattern, but doesn't list it.
|
||||
// If a handler already exists for pattern, Handle panics.
|
||||
func (m *PathRecorderMux) UnlistedHandle(path string, handler http.Handler) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.trackCallers(path)
|
||||
|
||||
m.mux.Handle(path, handler)
|
||||
m.pathToHandler[path] = handler
|
||||
m.refreshMuxLocked()
|
||||
}
|
||||
|
||||
// UnlistedHandleFunc registers the handler function for the given pattern, but doesn't list it.
|
||||
// If a handler already exists for pattern, Handle panics.
|
||||
func (m *PathRecorderMux) UnlistedHandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.trackCallers(path)
|
||||
|
||||
m.mux.HandleFunc(path, handler)
|
||||
m.pathToHandler[path] = http.HandlerFunc(handler)
|
||||
m.refreshMuxLocked()
|
||||
}
|
||||
|
||||
// ServeHTTP makes it an http.Handler
|
||||
func (m *PathRecorderMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.mux.ServeHTTP(w, r)
|
||||
m.mux.Load().(*http.ServeMux).ServeHTTP(w, r)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package mux
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -30,3 +31,39 @@ func TestSecretHandlers(t *testing.T) {
|
|||
assert.NotContains(t, c.ListedPaths(), "/secret")
|
||||
assert.Contains(t, c.ListedPaths(), "/nonswagger")
|
||||
}
|
||||
|
||||
func TestUnregisterHandlers(t *testing.T) {
|
||||
first := 0
|
||||
second := 0
|
||||
|
||||
c := NewPathRecorderMux()
|
||||
s := httptest.NewServer(c)
|
||||
defer s.Close()
|
||||
|
||||
c.UnlistedHandleFunc("/secret", func(http.ResponseWriter, *http.Request) {})
|
||||
c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) {
|
||||
first = first + 1
|
||||
})
|
||||
assert.NotContains(t, c.ListedPaths(), "/secret")
|
||||
assert.Contains(t, c.ListedPaths(), "/nonswagger")
|
||||
|
||||
resp, _ := http.Get(s.URL + "/nonswagger")
|
||||
assert.Equal(t, first, 1)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
|
||||
c.Unregister("/nonswagger")
|
||||
assert.NotContains(t, c.ListedPaths(), "/nonswagger")
|
||||
|
||||
resp, _ = http.Get(s.URL + "/nonswagger")
|
||||
assert.Equal(t, first, 1)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
|
||||
|
||||
c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) {
|
||||
second = second + 1
|
||||
})
|
||||
assert.Contains(t, c.ListedPaths(), "/nonswagger")
|
||||
resp, _ = http.Get(s.URL + "/nonswagger")
|
||||
assert.Equal(t, first, 1)
|
||||
assert.Equal(t, second, 1)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ package apiregistration
|
|||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
)
|
||||
|
||||
func SortedByGroup(servers []*APIService) [][]*APIService {
|
||||
|
@ -57,3 +59,10 @@ func (s ByPriority) Less(i, j int) bool {
|
|||
}
|
||||
return s[i].Spec.Priority < s[j].Spec.Priority
|
||||
}
|
||||
|
||||
// APIServiceNameToGroupVersion returns the GroupVersion for a given apiServiceName. The name
|
||||
// must be valid, but any object you get back from an informer will be valid.
|
||||
func APIServiceNameToGroupVersion(apiServiceName string) schema.GroupVersion {
|
||||
tokens := strings.SplitN(apiServiceName, ".", 2)
|
||||
return schema.GroupVersion{Group: tokens[1], Version: tokens[0]}
|
||||
}
|
||||
|
|
|
@ -220,8 +220,8 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de
|
|||
}
|
||||
proxyHandler.updateAPIService(apiService, destinationHost)
|
||||
s.proxyHandlers[apiService.Name] = proxyHandler
|
||||
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath, proxyHandler)
|
||||
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(proxyPath+"/", proxyHandler)
|
||||
s.GenericAPIServer.FallThroughHandler.Handle(proxyPath, proxyHandler)
|
||||
s.GenericAPIServer.FallThroughHandler.UnlistedHandle(proxyPath+"/", proxyHandler)
|
||||
|
||||
// if we're dealing with the legacy group, we're done here
|
||||
if apiService.Name == legacyAPIServiceName {
|
||||
|
@ -241,20 +241,28 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService, de
|
|||
lister: s.lister,
|
||||
serviceLister: s.serviceLister,
|
||||
endpointsLister: s.endpointsLister,
|
||||
delegate: s.GenericAPIServer.FallThroughHandler,
|
||||
delegate: s.delegateHandler,
|
||||
}
|
||||
// aggregation is protected
|
||||
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath, groupDiscoveryHandler)
|
||||
s.GenericAPIServer.HandlerContainer.ServeMux.Handle(groupPath+"/", groupDiscoveryHandler)
|
||||
s.GenericAPIServer.FallThroughHandler.Handle(groupPath, groupDiscoveryHandler)
|
||||
s.GenericAPIServer.FallThroughHandler.UnlistedHandle(groupPath+"/", groupDiscoveryHandler)
|
||||
s.handledGroups.Insert(apiService.Spec.Group)
|
||||
}
|
||||
|
||||
// RemoveAPIService removes the APIService from being handled. Later on it will disable the proxy endpoint.
|
||||
// Right now it does nothing because our handler has to properly 404 itself since muxes don't unregister
|
||||
// RemoveAPIService removes the APIService from being handled. It is not thread-safe, so only call it on one thread at a time please.
|
||||
// It's a slow moving API, so its ok to run the controller on a single thread.
|
||||
func (s *APIAggregator) RemoveAPIService(apiServiceName string) {
|
||||
proxyHandler, exists := s.proxyHandlers[apiServiceName]
|
||||
if !exists {
|
||||
return
|
||||
version := apiregistration.APIServiceNameToGroupVersion(apiServiceName)
|
||||
|
||||
proxyPath := "/apis/" + version.Group + "/" + version.Version
|
||||
// v1. is a special case for the legacy API. It proxies to a wider set of endpoints.
|
||||
if apiServiceName == legacyAPIServiceName {
|
||||
proxyPath = "/api"
|
||||
}
|
||||
proxyHandler.removeAPIService()
|
||||
s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath)
|
||||
s.GenericAPIServer.FallThroughHandler.Unregister(proxyPath + "/")
|
||||
delete(s.proxyHandlers, apiServiceName)
|
||||
|
||||
// TODO unregister group level discovery when there are no more versions for the group
|
||||
// We don't need this right away because the handler properly delegates when no versions are present
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ package apiserver
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
apierrors "k8s.io/apimachinery/pkg/api/errors"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
@ -30,24 +29,9 @@ import (
|
|||
|
||||
apiregistrationapi "k8s.io/kube-aggregator/pkg/apis/apiregistration"
|
||||
apiregistrationv1alpha1api "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1alpha1"
|
||||
informers "k8s.io/kube-aggregator/pkg/client/informers/internalversion/apiregistration/internalversion"
|
||||
listers "k8s.io/kube-aggregator/pkg/client/listers/apiregistration/internalversion"
|
||||
)
|
||||
|
||||
// WithAPIs adds the handling for /apis and /apis/<group: -apiregistration.k8s.io>.
|
||||
func WithAPIs(handler http.Handler, codecs serializer.CodecFactory, informer informers.APIServiceInformer, serviceLister v1listers.ServiceLister, endpointsLister v1listers.EndpointsLister) http.Handler {
|
||||
apisHandler := &apisHandler{
|
||||
codecs: codecs,
|
||||
lister: informer.Lister(),
|
||||
delegate: handler,
|
||||
serviceLister: serviceLister,
|
||||
endpointsLister: endpointsLister,
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
apisHandler.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
// apisHandler serves the `/apis` endpoint.
|
||||
// This is registered as a filter so that it never collides with any explictly registered endpoints
|
||||
type apisHandler struct {
|
||||
|
@ -75,11 +59,6 @@ var discoveryGroup = metav1.APIGroup{
|
|||
}
|
||||
|
||||
func (r *apisHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// if the URL is for OUR api group, serve it normally
|
||||
if strings.HasPrefix(req.URL.Path+"/", "/apis/"+apiregistrationapi.GroupName+"/") {
|
||||
r.delegate.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
// don't handle URLs that aren't /apis
|
||||
if req.URL.Path != "/apis" && req.URL.Path != "/apis/" {
|
||||
r.delegate.ServeHTTP(w, req)
|
||||
|
@ -210,7 +189,7 @@ func (r *apiGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
|
||||
if len(apiServicesForGroup) == 0 {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
r.delegate.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -338,13 +338,13 @@ func TestAPIGroupMissing(t *testing.T) {
|
|||
t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
// groupName still has no api services for it (like it was deleted), it should 404
|
||||
// groupName still has no api services for it (like it was deleted), it should delegate
|
||||
resp, err = http.Get(server.URL + "/apis/groupName/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("expected %v, got %v", http.StatusNotFound, resp.StatusCode)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected %v, got %v", http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
// missing group should delegate still has no api services for it (like it was deleted)
|
||||
|
|
|
@ -64,7 +64,12 @@ type proxyHandlingInfo struct {
|
|||
}
|
||||
|
||||
func (r *proxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
handlingInfo := r.handlingInfo.Load().(proxyHandlingInfo)
|
||||
value := r.handlingInfo.Load()
|
||||
if value == nil {
|
||||
r.localDelegate.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
handlingInfo := value.(proxyHandlingInfo)
|
||||
if handlingInfo.local {
|
||||
r.localDelegate.ServeHTTP(w, req)
|
||||
return
|
||||
|
@ -197,7 +202,3 @@ func (r *proxyHandler) updateAPIService(apiService *apiregistrationapi.APIServic
|
|||
newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig)
|
||||
r.handlingInfo.Store(newInfo)
|
||||
}
|
||||
|
||||
func (r *proxyHandler) removeAPIService() {
|
||||
r.handlingInfo.Store(proxyHandlingInfo{})
|
||||
}
|
||||
|
|
|
@ -83,13 +83,6 @@ func TestProxyHandler(t *testing.T) {
|
|||
targetServer := httptest.NewTLSServer(target)
|
||||
defer targetServer.Close()
|
||||
|
||||
handler := &proxyHandler{
|
||||
localDelegate: http.NewServeMux(),
|
||||
}
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
tests := map[string]struct {
|
||||
user user.Info
|
||||
path string
|
||||
|
@ -161,45 +154,53 @@ func TestProxyHandler(t *testing.T) {
|
|||
|
||||
for name, tc := range tests {
|
||||
target.Reset()
|
||||
handler.contextMapper = &fakeRequestContextMapper{user: tc.user}
|
||||
handler.removeAPIService()
|
||||
if tc.apiService != nil {
|
||||
handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc")
|
||||
curr := handler.handlingInfo.Load().(proxyHandlingInfo)
|
||||
curr.destinationHost = targetServer.Listener.Addr().String()
|
||||
handler.handlingInfo.Store(curr)
|
||||
}
|
||||
|
||||
resp, err := http.Get(server.URL + tc.path)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
if e, a := tc.expectedStatusCode, resp.StatusCode; e != a {
|
||||
body, _ := httputil.DumpResponse(resp, true)
|
||||
t.Logf("%s: %v", name, string(body))
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
continue
|
||||
}
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(bytes), tc.expectedBody) {
|
||||
t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes))
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
handler := &proxyHandler{
|
||||
localDelegate: http.NewServeMux(),
|
||||
}
|
||||
handler.contextMapper = &fakeRequestContextMapper{user: tc.user}
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
if e, a := tc.expectedCalled, target.called; e != a {
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
continue
|
||||
}
|
||||
// this varies every test
|
||||
delete(target.headers, "X-Forwarded-Host")
|
||||
if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
continue
|
||||
}
|
||||
if tc.apiService != nil {
|
||||
handler.updateAPIService(tc.apiService, tc.apiService.Spec.Service.Name+"."+tc.apiService.Spec.Service.Namespace+".svc")
|
||||
curr := handler.handlingInfo.Load().(proxyHandlingInfo)
|
||||
curr.destinationHost = targetServer.Listener.Addr().String()
|
||||
handler.handlingInfo.Store(curr)
|
||||
}
|
||||
|
||||
resp, err := http.Get(server.URL + tc.path)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
return
|
||||
}
|
||||
if e, a := tc.expectedStatusCode, resp.StatusCode; e != a {
|
||||
body, _ := httputil.DumpResponse(resp, true)
|
||||
t.Logf("%s: %v", name, string(body))
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
return
|
||||
}
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(string(bytes), tc.expectedBody) {
|
||||
t.Errorf("%s: expected %q, got %q", name, tc.expectedBody, string(bytes))
|
||||
return
|
||||
}
|
||||
|
||||
if e, a := tc.expectedCalled, target.called; e != a {
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
return
|
||||
}
|
||||
// this varies every test
|
||||
delete(target.headers, "X-Forwarded-Host")
|
||||
if e, a := tc.expectedHeaders, target.headers; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("%s: expected %v, got %v", name, e, a)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue