Remove request context mapper

pull/8/head
Jordan Liggitt 2018-04-18 11:12:15 -04:00
parent 1ee2ac07c1
commit 8ea88a5092
No known key found for this signature in database
GPG Key ID: 39928704103C7229
70 changed files with 293 additions and 927 deletions

View File

@ -34,14 +34,12 @@ import (
// BuildHandlerChain builds a handler chain with a base handler and CompletedConfig.
func BuildHandlerChain(apiHandler http.Handler, c *CompletedConfig) http.Handler {
requestContextMapper := apirequest.NewRequestContextMapper()
requestInfoResolver := &apirequest.RequestInfoFactory{}
failedHandler := genericapifilters.Unauthorized(requestContextMapper, legacyscheme.Codecs, false)
failedHandler := genericapifilters.Unauthorized(legacyscheme.Codecs, false)
handler := genericapifilters.WithAuthorization(apiHandler, requestContextMapper, c.Authorization.Authorizer, legacyscheme.Codecs)
handler = genericapifilters.WithAuthentication(handler, requestContextMapper, c.Authentication.Authenticator, failedHandler)
handler = genericapifilters.WithRequestInfo(handler, requestInfoResolver, requestContextMapper)
handler = apirequest.WithRequestContext(handler, requestContextMapper)
handler := genericapifilters.WithAuthorization(apiHandler, c.Authorization.Authorizer, legacyscheme.Codecs)
handler = genericapifilters.WithAuthentication(handler, c.Authentication.Authenticator, failedHandler)
handler = genericapifilters.WithRequestInfo(handler, requestInfoResolver)
handler = genericfilters.WithPanicRecovery(handler)
return handler

View File

@ -183,7 +183,7 @@ func CreateServerChain(completedOptions completedServerRunOptions, stopCh <-chan
// just start the API server as is because clients don't get built correctly when you do this
if len(os.Getenv("KUBE_API_VERSIONS")) > 0 {
if insecureServingOptions != nil {
insecureHandlerChain := kubeserver.BuildInsecureHandlerChain(kubeAPIServer.GenericAPIServer.UnprotectedHandler(), kubeAPIServerConfig.GenericConfig, kubeAPIServer.GenericAPIServer.RequestContextMapper())
insecureHandlerChain := kubeserver.BuildInsecureHandlerChain(kubeAPIServer.GenericAPIServer.UnprotectedHandler(), kubeAPIServerConfig.GenericConfig)
if err := kubeserver.NonBlockingRun(insecureServingOptions, insecureHandlerChain, kubeAPIServerConfig.GenericConfig.RequestTimeout, stopCh); err != nil {
return nil, err
}
@ -211,7 +211,7 @@ func CreateServerChain(completedOptions completedServerRunOptions, stopCh <-chan
}
if insecureServingOptions != nil {
insecureHandlerChain := kubeserver.BuildInsecureHandlerChain(aggregatorServer.GenericAPIServer.UnprotectedHandler(), kubeAPIServerConfig.GenericConfig, aggregatorServer.GenericAPIServer.RequestContextMapper())
insecureHandlerChain := kubeserver.BuildInsecureHandlerChain(aggregatorServer.GenericAPIServer.UnprotectedHandler(), kubeAPIServerConfig.GenericConfig)
if err := kubeserver.NonBlockingRun(insecureServingOptions, insecureHandlerChain, kubeAPIServerConfig.GenericConfig.RequestTimeout, stopCh); err != nil {
return nil, err
}

View File

@ -13,7 +13,6 @@ go_library(
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/k8s.io/apiserver/pkg/authentication/user:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/filters:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//vendor/k8s.io/apiserver/pkg/features:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server/filters:go_default_library",

View File

@ -25,7 +25,6 @@ import (
"k8s.io/apiserver/pkg/authentication/user"
genericapifilters "k8s.io/apiserver/pkg/endpoints/filters"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/server"
genericfilters "k8s.io/apiserver/pkg/server/filters"
@ -38,20 +37,19 @@ import (
// You shouldn't be using this. It makes sig-auth sad.
// InsecureServingInfo *ServingInfo
func BuildInsecureHandlerChain(apiHandler http.Handler, c *server.Config, contextMapper apirequest.RequestContextMapper) http.Handler {
func BuildInsecureHandlerChain(apiHandler http.Handler, c *server.Config) http.Handler {
handler := apiHandler
if utilfeature.DefaultFeatureGate.Enabled(features.AdvancedAuditing) {
handler = genericapifilters.WithAudit(handler, contextMapper, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
handler = genericapifilters.WithAudit(handler, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
} else {
handler = genericapifilters.WithLegacyAudit(handler, contextMapper, c.LegacyAuditWriter)
handler = genericapifilters.WithLegacyAudit(handler, c.LegacyAuditWriter)
}
handler = genericapifilters.WithAuthentication(handler, contextMapper, insecureSuperuser{}, nil)
handler = genericapifilters.WithAuthentication(handler, insecureSuperuser{}, nil)
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, contextMapper, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithMaxInFlightLimit(handler, c.MaxRequestsInFlight, c.MaxMutatingRequestsInFlight, contextMapper, c.LongRunningFunc)
handler = genericfilters.WithWaitGroup(handler, contextMapper, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, server.NewRequestInfoResolver(c), contextMapper)
handler = apirequest.WithRequestContext(handler, contextMapper)
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithMaxInFlightLimit(handler, c.MaxRequestsInFlight, c.MaxMutatingRequestsInFlight, c.LongRunningFunc)
handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, server.NewRequestInfoResolver(c))
handler = genericfilters.WithPanicRecovery(handler)
return handler

View File

@ -159,7 +159,6 @@ go_test(
"//vendor/k8s.io/apimachinery/pkg/util/net:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/sets:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/version:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server/options:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server/storage:go_default_library",

View File

@ -27,7 +27,6 @@ import (
"net/http/httptest"
"testing"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
genericapiserver "k8s.io/apiserver/pkg/server"
"k8s.io/kubernetes/pkg/api/legacyscheme"
openapigen "k8s.io/kubernetes/pkg/generated/openapi"
@ -60,7 +59,7 @@ func TestValidOpenAPISpec(t *testing.T) {
}
// make sure swagger.json is not registered before calling PrepareRun.
server := httptest.NewServer(apirequest.WithRequestContext(master.GenericAPIServer.Handler.Director, master.GenericAPIServer.RequestContextMapper()))
server := httptest.NewServer(master.GenericAPIServer.Handler.Director)
defer server.Close()
resp, err := http.Get(server.URL + "/swagger.json")
if !assert.NoError(err) {

View File

@ -35,7 +35,6 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/version"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
genericapiserver "k8s.io/apiserver/pkg/server"
"k8s.io/apiserver/pkg/server/options"
serverstorage "k8s.io/apiserver/pkg/server/storage"
@ -277,7 +276,7 @@ func TestAPIVersionOfDiscoveryEndpoints(t *testing.T) {
master, etcdserver, _, assert := newMaster(t)
defer etcdserver.Terminate(t)
server := httptest.NewServer(genericapirequest.WithRequestContext(master.GenericAPIServer.Handler.GoRestfulContainer.ServeMux, master.GenericAPIServer.RequestContextMapper()))
server := httptest.NewServer(master.GenericAPIServer.Handler.GoRestfulContainer.ServeMux)
// /api exists in release-1.1
resp, err := http.Get(server.URL + "/api")

View File

@ -183,7 +183,6 @@ func (c completedConfig) New(delegationTarget genericapiserver.DelegationTarget)
crdHandler := NewCustomResourceDefinitionHandler(
versionDiscoveryHandler,
groupDiscoveryHandler,
s.GenericAPIServer.RequestContextMapper(),
s.Informers.Apiextensions().InternalVersion().CustomResourceDefinitions(),
delegateHandler,
c.ExtraConfig.CRDRESTOptionsGetter,
@ -197,7 +196,7 @@ func (c completedConfig) New(delegationTarget genericapiserver.DelegationTarget)
return s, nil
}
crdController := NewDiscoveryController(s.Informers.Apiextensions().InternalVersion().CustomResourceDefinitions(), versionDiscoveryHandler, groupDiscoveryHandler, delegationTarget.RequestContextMapper())
crdController := NewDiscoveryController(s.Informers.Apiextensions().InternalVersion().CustomResourceDefinitions(), versionDiscoveryHandler, groupDiscoveryHandler)
namingController := status.NewNamingConditionController(s.Informers.Apiextensions().InternalVersion().CustomResourceDefinitions(), crdClient.Apiextensions())
finalizingController := finalizer.NewCRDFinalizer(
s.Informers.Apiextensions().InternalVersion().CustomResourceDefinitions(),

View File

@ -29,7 +29,6 @@ import (
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apiserver/pkg/endpoints/discovery"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/workqueue"
@ -41,7 +40,6 @@ import (
type DiscoveryController struct {
versionHandler *versionDiscoveryHandler
groupHandler *groupDiscoveryHandler
contextMapper request.RequestContextMapper
crdLister listers.CustomResourceDefinitionLister
crdsSynced cache.InformerSynced
@ -52,13 +50,12 @@ type DiscoveryController struct {
queue workqueue.RateLimitingInterface
}
func NewDiscoveryController(crdInformer informers.CustomResourceDefinitionInformer, versionHandler *versionDiscoveryHandler, groupHandler *groupDiscoveryHandler, contextMapper request.RequestContextMapper) *DiscoveryController {
func NewDiscoveryController(crdInformer informers.CustomResourceDefinitionInformer, versionHandler *versionDiscoveryHandler, groupHandler *groupDiscoveryHandler) *DiscoveryController {
c := &DiscoveryController{
versionHandler: versionHandler,
groupHandler: groupHandler,
crdLister: crdInformer.Lister(),
crdsSynced: crdInformer.Informer().HasSynced,
contextMapper: contextMapper,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "DiscoveryController"),
}
@ -153,7 +150,7 @@ func (c *DiscoveryController) sync(version schema.GroupVersion) error {
// the preferred versions for a group is arbitrary since there cannot be duplicate resources
PreferredVersion: apiVersionsForDiscovery[0],
}
c.groupHandler.setDiscovery(version.Group, discovery.NewAPIGroupHandler(Codecs, apiGroup, c.contextMapper))
c.groupHandler.setDiscovery(version.Group, discovery.NewAPIGroupHandler(Codecs, apiGroup))
if !foundVersion {
c.versionHandler.unsetDiscovery(version)
@ -161,7 +158,7 @@ func (c *DiscoveryController) sync(version schema.GroupVersion) error {
}
c.versionHandler.setDiscovery(version, discovery.NewAPIVersionHandler(Codecs, version, discovery.APIResourceListerFunc(func() []metav1.APIResource {
return apiResourcesForDiscovery
}), c.contextMapper))
})))
return nil
}

View File

@ -80,8 +80,6 @@ type crdHandler struct {
// which is suited for most read and rarely write cases
customStorage atomic.Value
requestContextMapper apirequest.RequestContextMapper
crdLister listers.CustomResourceDefinitionLister
delegate http.Handler
@ -109,7 +107,6 @@ type crdStorageMap map[types.UID]*crdInfo
func NewCustomResourceDefinitionHandler(
versionDiscoveryHandler *versionDiscoveryHandler,
groupDiscoveryHandler *groupDiscoveryHandler,
requestContextMapper apirequest.RequestContextMapper,
crdInformer informers.CustomResourceDefinitionInformer,
delegate http.Handler,
restOptionsGetter generic.RESTOptionsGetter,
@ -118,7 +115,6 @@ func NewCustomResourceDefinitionHandler(
versionDiscoveryHandler: versionDiscoveryHandler,
groupDiscoveryHandler: groupDiscoveryHandler,
customStorage: atomic.Value{},
requestContextMapper: requestContextMapper,
crdLister: crdInformer.Lister(),
delegate: delegate,
restOptionsGetter: restOptionsGetter,
@ -138,11 +134,7 @@ func NewCustomResourceDefinitionHandler(
}
func (r *crdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, ok := r.requestContextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, fmt.Errorf("no context found for request"))
return
}
ctx := req.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
responsewriters.InternalError(w, req, fmt.Errorf("no RequestInfo found in the context"))
@ -457,23 +449,12 @@ func (r *crdHandler) getOrCreateServingInfoFor(crd *apiextensions.CustomResource
clusterScoped := crd.Spec.Scope == apiextensions.ClusterScoped
var ctxFn handlers.ContextFunc
ctxFn = func(req *http.Request) apirequest.Context {
ret, _ := r.requestContextMapper.Get(req)
return ret
}
requestScope := handlers.RequestScope{
Namer: handlers.ContextBasedNaming{
GetContext: ctxFn,
SelfLinker: meta.NewAccessor(),
ClusterScoped: clusterScoped,
SelfLinkPathPrefix: selfLinkPrefix,
},
ContextFunc: func(req *http.Request) apirequest.Context {
ret, _ := r.requestContextMapper.Get(req)
return ret
},
Serializer: unstructuredNegotiatedSerializer{typer: typer, creator: creator},
ParameterCodec: parameterCodec,
@ -511,7 +492,6 @@ func (r *crdHandler) getOrCreateServingInfoFor(crd *apiextensions.CustomResource
ret.scaleRequestScope.Serializer = serializer.NewCodecFactory(scaleConverter.Scheme())
ret.scaleRequestScope.Kind = autoscalingv1.SchemeGroupVersion.WithKind("Scale")
ret.scaleRequestScope.Namer = handlers.ContextBasedNaming{
GetContext: ctxFn,
SelfLinker: meta.NewAccessor(),
ClusterScoped: clusterScoped,
SelfLinkPathPrefix: selfLinkPrefix,
@ -521,7 +501,6 @@ func (r *crdHandler) getOrCreateServingInfoFor(crd *apiextensions.CustomResource
// override status subresource values
ret.statusRequestScope.Subresource = "status"
ret.statusRequestScope.Namer = handlers.ContextBasedNaming{
GetContext: ctxFn,
SelfLinker: meta.NewAccessor(),
ClusterScoped: clusterScoped,
SelfLinkPathPrefix: selfLinkPrefix,

View File

@ -76,7 +76,6 @@ go_library(
"//vendor/k8s.io/apiserver/pkg/endpoints/handlers:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/handlers/negotiation:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/metrics:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//vendor/k8s.io/apiserver/pkg/registry/rest:go_default_library",
"//vendor/k8s.io/apiserver/pkg/server/filters:go_default_library",
],

View File

@ -130,7 +130,6 @@ var accessor = meta.NewAccessor()
var selfLinker runtime.SelfLinker = accessor
var mapper, namespaceMapper meta.RESTMapper // The mappers with namespace and with legacy namespace scopes.
var admissionControl admission.Interface
var requestContextMapper request.RequestContextMapper
func init() {
metav1.AddToGroupVersion(scheme, metav1.SchemeGroupVersion)
@ -238,7 +237,6 @@ func init() {
mapper = nsMapper
namespaceMapper = nsMapper
admissionControl = alwaysAdmit{}
requestContextMapper = request.NewRequestContextMapper()
scheme.AddFieldLabelConversionFunc(grouplessGroupVersion.String(), "Simple",
func(label, value string) (string, string, error) {
@ -295,8 +293,7 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
ParameterCodec: parameterCodec,
Admit: admissionControl,
Context: requestContextMapper,
Admit: admissionControl,
}
// groupless v1 version
@ -334,13 +331,11 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
panic(fmt.Sprintf("unable to install container %s: %v", group.GroupVersion, err))
}
}
handler := genericapifilters.WithAudit(mux, requestContextMapper, auditSink, auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil), func(r *http.Request, requestInfo *request.RequestInfo) bool {
handler := genericapifilters.WithAudit(mux, auditSink, auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil), func(r *http.Request, requestInfo *request.RequestInfo) bool {
// simplified long-running check
return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy"
})
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())
return &defaultAPIServer{handler, container}
}
@ -1225,11 +1220,8 @@ func TestListCompression(t *testing.T) {
}
var handler = handleInternal(storage, admissionControl, selfLinker, nil)
requestContextMapper = request.NewRequestContextMapper()
handler = filters.WithCompression(handler, requestContextMapper)
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler = filters.WithCompression(handler)
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver())
server := httptest.NewServer(handler)
@ -1635,13 +1627,10 @@ func TestGetCompression(t *testing.T) {
namespace: "default",
}
requestContextMapper = request.NewRequestContextMapper()
storage["simple"] = &simpleStorage
handler := handleLinker(storage, selfLinker)
handler = filters.WithCompression(handler, requestContextMapper)
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler = filters.WithCompression(handler)
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver())
server := httptest.NewServer(handler)
defer server.Close()
@ -3297,9 +3286,8 @@ func TestParentResourceIsRequired(t *testing.T) {
Typer: scheme,
Linker: selfLinker,
Admit: admissionControl,
Context: requestContextMapper,
Mapper: namespaceMapper,
Admit: admissionControl,
Mapper: namespaceMapper,
GroupVersion: newGroupVersion,
OptionsExternalVersion: &newGroupVersion,
@ -3328,9 +3316,8 @@ func TestParentResourceIsRequired(t *testing.T) {
Typer: scheme,
Linker: selfLinker,
Admit: admissionControl,
Context: requestContextMapper,
Mapper: namespaceMapper,
Admit: admissionControl,
Mapper: namespaceMapper,
GroupVersion: newGroupVersion,
OptionsExternalVersion: &newGroupVersion,
@ -3343,8 +3330,7 @@ func TestParentResourceIsRequired(t *testing.T) {
t.Fatal(err)
}
handler := genericapifilters.WithRequestInfo(container, newTestRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler := genericapifilters.WithRequestInfo(container, newTestRequestInfoResolver())
// resource is NOT registered in the root scope
w := httptest.NewRecorder()
@ -3744,7 +3730,7 @@ func (obj *UnregisteredAPIObject) DeepCopyObject() runtime.Object {
func TestWriteJSONDecodeError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
responsewriters.WriteObjectNegotiated(request.NewContext(), codecs, newGroupVersion, w, req, http.StatusOK, &UnregisteredAPIObject{"Undecodable"})
responsewriters.WriteObjectNegotiated(codecs, newGroupVersion, w, req, http.StatusOK, &UnregisteredAPIObject{"Undecodable"})
}))
defer server.Close()
// We send a 200 status code before we encode the object, so we expect OK, but there will
@ -3954,8 +3940,7 @@ func TestXGSubresource(t *testing.T) {
ParameterCodec: parameterCodec,
Admit: admissionControl,
Context: requestContextMapper,
Admit: admissionControl,
Root: "/" + prefix,
GroupVersion: testGroupVersion,
@ -4058,8 +4043,7 @@ func BenchmarkUpdateProtobuf(b *testing.B) {
}
func newTestServer(handler http.Handler) *httptest.Server {
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler = genericapifilters.WithRequestInfo(handler, newTestRequestInfoResolver())
return httptest.NewServer(handler)
}

View File

@ -22,6 +22,7 @@ go_test(
"//vendor/k8s.io/apimachinery/pkg/runtime/schema:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/runtime/serializer:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/net:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/sets:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
],
)
@ -45,7 +46,6 @@ go_library(
"//vendor/k8s.io/apimachinery/pkg/util/net:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/handlers/negotiation:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/handlers/responsewriters:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
],
)

View File

@ -17,7 +17,6 @@ limitations under the License.
package discovery
import (
"errors"
"net/http"
"github.com/emicklei/go-restful"
@ -27,18 +26,16 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
// APIGroupHandler creates a webservice serving the supported versions, preferred version, and name
// of a group. E.g., such a web service will be registered at /apis/extensions.
type APIGroupHandler struct {
serializer runtime.NegotiatedSerializer
contextMapper request.RequestContextMapper
group metav1.APIGroup
serializer runtime.NegotiatedSerializer
group metav1.APIGroup
}
func NewAPIGroupHandler(serializer runtime.NegotiatedSerializer, group metav1.APIGroup, contextMapper request.RequestContextMapper) *APIGroupHandler {
func NewAPIGroupHandler(serializer runtime.NegotiatedSerializer, group metav1.APIGroup) *APIGroupHandler {
if keepUnversioned(group.Name) {
// Because in release 1.1, /apis/extensions returns response with empty
// APIVersion, we use stripVersionNegotiatedSerializer to keep the
@ -47,9 +44,8 @@ func NewAPIGroupHandler(serializer runtime.NegotiatedSerializer, group metav1.AP
}
return &APIGroupHandler{
serializer: serializer,
contextMapper: contextMapper,
group: group,
serializer: serializer,
group: group,
}
}
@ -73,10 +69,5 @@ func (s *APIGroupHandler) handle(req *restful.Request, resp *restful.Response) {
}
func (s *APIGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, ok := s.contextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
responsewriters.WriteObjectNegotiated(ctx, s.serializer, schema.GroupVersion{}, w, req, http.StatusOK, &s.group)
responsewriters.WriteObjectNegotiated(s.serializer, schema.GroupVersion{}, w, req, http.StatusOK, &s.group)
}

View File

@ -17,7 +17,6 @@ limitations under the License.
package discovery
import (
"errors"
"net/http"
"github.com/emicklei/go-restful"
@ -28,31 +27,28 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
// legacyRootAPIHandler creates a webservice serving api group discovery.
type legacyRootAPIHandler struct {
// addresses is used to build cluster IPs for discovery.
addresses Addresses
apiPrefix string
serializer runtime.NegotiatedSerializer
apiVersions []string
contextMapper request.RequestContextMapper
addresses Addresses
apiPrefix string
serializer runtime.NegotiatedSerializer
apiVersions []string
}
func NewLegacyRootAPIHandler(addresses Addresses, serializer runtime.NegotiatedSerializer, apiPrefix string, apiVersions []string, contextMapper request.RequestContextMapper) *legacyRootAPIHandler {
func NewLegacyRootAPIHandler(addresses Addresses, serializer runtime.NegotiatedSerializer, apiPrefix string, apiVersions []string) *legacyRootAPIHandler {
// Because in release 1.1, /apis returns response with empty APIVersion, we
// use stripVersionNegotiatedSerializer to keep the response backwards
// compatible.
serializer = stripVersionNegotiatedSerializer{serializer}
return &legacyRootAPIHandler{
addresses: addresses,
apiPrefix: apiPrefix,
serializer: serializer,
apiVersions: apiVersions,
contextMapper: contextMapper,
addresses: addresses,
apiPrefix: apiPrefix,
serializer: serializer,
apiVersions: apiVersions,
}
}
@ -72,17 +68,11 @@ func (s *legacyRootAPIHandler) WebService() *restful.WebService {
}
func (s *legacyRootAPIHandler) handle(req *restful.Request, resp *restful.Response) {
ctx, ok := s.contextMapper.Get(req.Request)
if !ok {
responsewriters.InternalError(resp.ResponseWriter, req.Request, errors.New("no context found for request"))
return
}
clientIP := utilnet.GetClientIP(req.Request)
apiVersions := &metav1.APIVersions{
ServerAddressByClientCIDRs: s.addresses.ServerAddressByClientCIDRs(clientIP),
Versions: s.apiVersions,
}
responsewriters.WriteObjectNegotiated(ctx, s.serializer, schema.GroupVersion{}, resp.ResponseWriter, req.Request, http.StatusOK, apiVersions)
responsewriters.WriteObjectNegotiated(s.serializer, schema.GroupVersion{}, resp.ResponseWriter, req.Request, http.StatusOK, apiVersions)
}

View File

@ -17,7 +17,6 @@ limitations under the License.
package discovery
import (
"errors"
"net/http"
"sync"
@ -29,7 +28,6 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
// GroupManager is an interface that allows dynamic mutation of the existing webservice to handle
@ -48,8 +46,7 @@ type rootAPIsHandler struct {
// addresses is used to build cluster IPs for discovery.
addresses Addresses
serializer runtime.NegotiatedSerializer
contextMapper request.RequestContextMapper
serializer runtime.NegotiatedSerializer
// Map storing information about all groups to be exposed in discovery response.
// The map is from name to the group.
@ -59,17 +56,16 @@ type rootAPIsHandler struct {
apiGroupNames []string
}
func NewRootAPIsHandler(addresses Addresses, serializer runtime.NegotiatedSerializer, contextMapper request.RequestContextMapper) *rootAPIsHandler {
func NewRootAPIsHandler(addresses Addresses, serializer runtime.NegotiatedSerializer) *rootAPIsHandler {
// Because in release 1.1, /apis returns response with empty APIVersion, we
// use stripVersionNegotiatedSerializer to keep the response backwards
// compatible.
serializer = stripVersionNegotiatedSerializer{serializer}
return &rootAPIsHandler{
addresses: addresses,
serializer: serializer,
apiGroups: map[string]metav1.APIGroup{},
contextMapper: contextMapper,
addresses: addresses,
serializer: serializer,
apiGroups: map[string]metav1.APIGroup{},
}
}
@ -99,12 +95,6 @@ func (s *rootAPIsHandler) RemoveGroup(groupName string) {
}
func (s *rootAPIsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
ctx, ok := s.contextMapper.Get(req)
if !ok {
responsewriters.InternalError(resp, req, errors.New("no context found for request"))
return
}
s.lock.RLock()
defer s.lock.RUnlock()
@ -121,7 +111,7 @@ func (s *rootAPIsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request)
groups[i].ServerAddressByClientCIDRs = serverCIDR
}
responsewriters.WriteObjectNegotiated(ctx, s.serializer, schema.GroupVersion{}, resp, req, http.StatusOK, &metav1.APIGroupList{Groups: groups})
responsewriters.WriteObjectNegotiated(s.serializer, schema.GroupVersion{}, resp, req, http.StatusOK, &metav1.APIGroupList{Groups: groups})
}
func (s *rootAPIsHandler) restfulHandle(req *restful.Request, resp *restful.Response) {

View File

@ -33,6 +33,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/endpoints/request"
)
@ -83,11 +84,27 @@ func getGroupList(t *testing.T, server *httptest.Server) (*metav1.APIGroupList,
return &groupList, err
}
func TestDiscoveryAtAPIS(t *testing.T) {
mapper := request.NewRequestContextMapper()
handler := NewRootAPIsHandler(DefaultAddresses{DefaultAddress: "192.168.1.1"}, codecs, mapper)
func contextHandler(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
resolver := &request.RequestInfoFactory{
APIPrefixes: sets.NewString("api", "apis"),
GrouplessAPIPrefixes: sets.NewString("api"),
}
info, err := resolver.NewRequestInfo(req)
if err == nil {
ctx = request.WithRequestInfo(ctx, info)
}
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
func TestDiscoveryAtAPIS(t *testing.T) {
handler := NewRootAPIsHandler(DefaultAddresses{DefaultAddress: "192.168.1.1"}, codecs)
server := httptest.NewServer(contextHandler(handler))
server := httptest.NewServer(request.WithRequestContext(handler, mapper))
groupList, err := getGroupList(t, server)
if err != nil {
t.Fatalf("unexpected error: %v", err)
@ -135,10 +152,9 @@ func TestDiscoveryAtAPIS(t *testing.T) {
}
func TestDiscoveryOrdering(t *testing.T) {
mapper := request.NewRequestContextMapper()
handler := NewRootAPIsHandler(DefaultAddresses{DefaultAddress: "192.168.1.1"}, codecs, mapper)
handler := NewRootAPIsHandler(DefaultAddresses{DefaultAddress: "192.168.1.1"}, codecs)
server := httptest.NewServer(request.WithRequestContext(handler, mapper))
server := httptest.NewServer(handler)
groupList, err := getGroupList(t, server)
if err != nil {
t.Fatalf("unexpected error: %v", err)

View File

@ -17,7 +17,6 @@ limitations under the License.
package discovery
import (
"errors"
"net/http"
restful "github.com/emicklei/go-restful"
@ -27,7 +26,6 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
type APIResourceLister interface {
@ -43,14 +41,13 @@ func (f APIResourceListerFunc) ListAPIResources() []metav1.APIResource {
// APIVersionHandler creates a webservice serving the supported resources for the version
// E.g., such a web service will be registered at /apis/extensions/v1beta1.
type APIVersionHandler struct {
serializer runtime.NegotiatedSerializer
contextMapper request.RequestContextMapper
serializer runtime.NegotiatedSerializer
groupVersion schema.GroupVersion
apiResourceLister APIResourceLister
}
func NewAPIVersionHandler(serializer runtime.NegotiatedSerializer, groupVersion schema.GroupVersion, apiResourceLister APIResourceLister, contextMapper request.RequestContextMapper) *APIVersionHandler {
func NewAPIVersionHandler(serializer runtime.NegotiatedSerializer, groupVersion schema.GroupVersion, apiResourceLister APIResourceLister) *APIVersionHandler {
if keepUnversioned(groupVersion.Group) {
// Because in release 1.1, /apis/extensions returns response with empty
// APIVersion, we use stripVersionNegotiatedSerializer to keep the
@ -62,7 +59,6 @@ func NewAPIVersionHandler(serializer runtime.NegotiatedSerializer, groupVersion
serializer: serializer,
groupVersion: groupVersion,
apiResourceLister: apiResourceLister,
contextMapper: contextMapper,
}
}
@ -82,12 +78,6 @@ func (s *APIVersionHandler) handle(req *restful.Request, resp *restful.Response)
}
func (s *APIVersionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, ok := s.contextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
responsewriters.WriteObjectNegotiated(ctx, s.serializer, schema.GroupVersion{}, w, req, http.StatusOK,
responsewriters.WriteObjectNegotiated(s.serializer, schema.GroupVersion{}, w, req, http.StatusOK,
&metav1.APIResourceList{GroupVersion: s.groupVersion.String(), APIResources: s.apiResourceLister.ListAPIResources()})
}

View File

@ -34,7 +34,6 @@ go_test(
"//vendor/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
"//vendor/k8s.io/apiserver/pkg/authentication/user:go_default_library",
"//vendor/k8s.io/apiserver/pkg/authorization/authorizer:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/handlers/responsewriters:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
],
)

View File

@ -38,17 +38,18 @@ import (
// requests coming to the server. Audit level is decided according to requests'
// attributes and audit policy. Logs are emitted to the audit sink to
// process events. If sink or audit policy is nil, no decoration takes place.
func WithAudit(handler http.Handler, requestContextMapper request.RequestContextMapper, sink audit.Sink, policy policy.Checker, longRunningCheck request.LongRunningRequestCheck) http.Handler {
func WithAudit(handler http.Handler, sink audit.Sink, policy policy.Checker, longRunningCheck request.LongRunningRequestCheck) http.Handler {
if sink == nil || policy == nil {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ev, omitStages, err := createAuditEventAndAttachToContext(requestContextMapper, req, policy)
req, ev, omitStages, err := createAuditEventAndAttachToContext(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return
}
ctx := req.Context()
if ev == nil || ctx == nil {
handler.ServeHTTP(w, req)
return
@ -111,35 +112,29 @@ func WithAudit(handler http.Handler, requestContextMapper request.RequestContext
// - context with audit event attached to it
// - created audit event
// - error if anything bad happened
func createAuditEventAndAttachToContext(requestContextMapper request.RequestContextMapper, req *http.Request, policy policy.Checker) (request.Context, *auditinternal.Event, []auditinternal.Stage, error) {
ctx, ok := requestContextMapper.Get(req)
if !ok {
return nil, nil, nil, fmt.Errorf("no context found for request")
}
func createAuditEventAndAttachToContext(req *http.Request, policy policy.Checker) (*http.Request, *auditinternal.Event, []auditinternal.Stage, error) {
ctx := req.Context()
attribs, err := GetAuthorizerAttributes(ctx)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err)
return req, nil, nil, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err)
}
level, omitStages := policy.LevelAndStages(attribs)
audit.ObservePolicyLevel(level)
if level == auditinternal.LevelNone {
// Don't audit.
return nil, nil, nil, nil
return req, nil, nil, nil
}
ev, err := audit.NewEventFromRequest(req, level, attribs)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to complete audit event from request: %v", err)
return req, nil, nil, fmt.Errorf("failed to complete audit event from request: %v", err)
}
ctx = request.WithAuditEvent(ctx, ev)
if err := requestContextMapper.Update(req, ctx); err != nil {
return nil, nil, nil, fmt.Errorf("failed to attach audit event to context: %v", err)
}
req = req.WithContext(request.WithAuditEvent(ctx, ev))
return ctx, ev, omitStages, nil
return req, ev, omitStages, nil
}
func processAuditEvent(sink audit.Sink, ev *auditinternal.Event, omitStages []auditinternal.Stage) {

View File

@ -667,14 +667,13 @@ func TestAudit(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
sink := &fakeAuditSink{}
policyChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, test.omitStages)
handler := WithAudit(http.HandlerFunc(test.handler), &fakeRequestContextMapper{
user: &user.DefaultInfo{Name: "admin"},
}, sink, policyChecker, func(r *http.Request, ri *request.RequestInfo) bool {
handler := WithAudit(http.HandlerFunc(test.handler), sink, policyChecker, func(r *http.Request, ri *request.RequestInfo) bool {
// simplified long-running check
return ri.Verb == "watch"
})
req, _ := http.NewRequest(test.verb, test.path, nil)
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
if test.auditID != "" {
req.Header.Add("Audit-ID", test.auditID)
}
@ -735,37 +734,11 @@ func TestAudit(t *testing.T) {
}
}
type fakeRequestContextMapper struct {
user *user.DefaultInfo
audit *auditinternal.Event
}
func (m *fakeRequestContextMapper) Get(req *http.Request) (request.Context, bool) {
ctx := request.NewContext()
if m.user != nil {
ctx = request.WithUser(ctx, m.user)
}
if m.audit != nil {
ctx = request.WithAuditEvent(ctx, m.audit)
}
resolver := newTestRequestInfoResolver()
info, err := resolver.NewRequestInfo(req)
if err == nil {
ctx = request.WithRequestInfo(ctx, info)
}
return ctx, true
}
func (*fakeRequestContextMapper) Update(req *http.Request, context request.Context) error {
return nil
}
func TestAuditNoPanicOnNilUser(t *testing.T) {
policyChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
handler := WithAudit(&fakeHTTPHandler{}, &fakeRequestContextMapper{}, &fakeAuditSink{}, policyChecker, nil)
handler := WithAudit(&fakeHTTPHandler{}, &fakeAuditSink{}, policyChecker, nil)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req = withTestContext(req, nil, nil)
req.RemoteAddr = "127.0.0.1"
handler.ServeHTTP(httptest.NewRecorder(), req)
}
@ -777,12 +750,11 @@ func TestAuditLevelNone(t *testing.T) {
w.WriteHeader(200)
})
policyChecker := policy.FakeChecker(auditinternal.LevelNone, nil)
handler = WithAudit(handler, &fakeRequestContextMapper{
user: &user.DefaultInfo{Name: "admin"},
}, sink, policyChecker, nil)
handler = WithAudit(handler, sink, policyChecker, nil)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
if len(sink.events) > 0 {
@ -828,12 +800,11 @@ func TestAuditIDHttpHeader(t *testing.T) {
w.WriteHeader(200)
})
policyChecker := policy.FakeChecker(test.level, nil)
handler = WithAudit(handler, &fakeRequestContextMapper{
user: &user.DefaultInfo{Name: "admin"},
}, sink, policyChecker, nil)
handler = WithAudit(handler, sink, policyChecker, nil)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
if test.requestHeader != "" {
req.Header.Add("Audit-ID", test.requestHeader)
}
@ -857,3 +828,17 @@ func TestAuditIDHttpHeader(t *testing.T) {
}
}
}
func withTestContext(req *http.Request, user user.Info, audit *auditinternal.Event) *http.Request {
ctx := req.Context()
if user != nil {
ctx = request.WithUser(ctx, user)
}
if audit != nil {
ctx = request.WithAuditEvent(ctx, audit)
}
if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil {
ctx = request.WithRequestInfo(ctx, info)
}
return req.WithContext(ctx)
}

View File

@ -50,47 +50,38 @@ func init() {
// stores any such user found onto the provided context for the request. If authentication fails or returns an error
// the failed handler is used. On success, "Authorization" header is removed from the request and handler
// is invoked to serve the request.
func WithAuthentication(handler http.Handler, mapper genericapirequest.RequestContextMapper, auth authenticator.Request, failed http.Handler) http.Handler {
func WithAuthentication(handler http.Handler, auth authenticator.Request, failed http.Handler) http.Handler {
if auth == nil {
glog.Warningf("Authentication is disabled")
return handler
}
return genericapirequest.WithRequestContext(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
user, ok, err := auth.AuthenticateRequest(req)
if err != nil || !ok {
if err != nil {
glog.Errorf("Unable to authenticate the request due to an error: %v", err)
}
failed.ServeHTTP(w, req)
return
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
user, ok, err := auth.AuthenticateRequest(req)
if err != nil || !ok {
if err != nil {
glog.Errorf("Unable to authenticate the request due to an error: %v", err)
}
failed.ServeHTTP(w, req)
return
}
// authorization header is not required anymore in case of a successful authentication.
req.Header.Del("Authorization")
// authorization header is not required anymore in case of a successful authentication.
req.Header.Del("Authorization")
if ctx, ok := mapper.Get(req); ok {
mapper.Update(req, genericapirequest.WithUser(ctx, user))
}
req = req.WithContext(genericapirequest.WithUser(req.Context(), user))
authenticatedUserCounter.WithLabelValues(compressUsername(user.GetName())).Inc()
authenticatedUserCounter.WithLabelValues(compressUsername(user.GetName())).Inc()
handler.ServeHTTP(w, req)
}),
mapper,
)
handler.ServeHTTP(w, req)
})
}
func Unauthorized(requestContextMapper genericapirequest.RequestContextMapper, s runtime.NegotiatedSerializer, supportsBasicAuth bool) http.Handler {
func Unauthorized(s runtime.NegotiatedSerializer, supportsBasicAuth bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if supportsBasicAuth {
w.Header().Set("WWW-Authenticate", `Basic realm="kubernetes-master"`)
}
ctx, ok := requestContextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
requestInfo, found := genericapirequest.RequestInfoFrom(ctx)
if !found {
responsewriters.InternalError(w, req, errors.New("no RequestInfo found in the context"))
@ -98,7 +89,7 @@ func Unauthorized(requestContextMapper genericapirequest.RequestContextMapper, s
}
gv := schema.GroupVersion{Group: requestInfo.APIGroup, Version: requestInfo.APIVersion}
responsewriters.ErrorNegotiated(ctx, apierrors.NewUnauthorized("Unauthorized"), s, gv, w, req)
responsewriters.ErrorNegotiated(apierrors.NewUnauthorized("Unauthorized"), s, gv, w, req)
})
}

View File

@ -29,13 +29,9 @@ import (
func TestAuthenticateRequest(t *testing.T) {
success := make(chan struct{})
contextMapper := genericapirequest.NewRequestContextMapper()
auth := WithAuthentication(
http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
ctx, ok := contextMapper.Get(req)
if ctx == nil || !ok {
t.Errorf("no context stored on contextMapper: %#v", contextMapper)
}
ctx := req.Context()
user, ok := genericapirequest.UserFrom(ctx)
if user == nil || !ok {
t.Errorf("no user stored in context: %#v", ctx)
@ -45,7 +41,6 @@ func TestAuthenticateRequest(t *testing.T) {
}
close(success)
}),
contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
if req.Header.Get("Authorization") == "Something" {
return &user.DefaultInfo{Name: "user"}, true, nil
@ -60,23 +55,14 @@ func TestAuthenticateRequest(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{Header: map[string][]string{"Authorization": {"Something"}}})
<-success
empty, err := genericapirequest.IsEmpty(contextMapper)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
}
}
func TestAuthenticateRequestFailed(t *testing.T) {
failed := make(chan struct{})
contextMapper := genericapirequest.NewRequestContextMapper()
auth := WithAuthentication(
http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
t.Errorf("unexpected call to handler")
}),
contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
return nil, false, nil
}),
@ -88,23 +74,14 @@ func TestAuthenticateRequestFailed(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{})
<-failed
empty, err := genericapirequest.IsEmpty(contextMapper)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
}
}
func TestAuthenticateRequestError(t *testing.T) {
failed := make(chan struct{})
contextMapper := genericapirequest.NewRequestContextMapper()
auth := WithAuthentication(
http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
t.Errorf("unexpected call to handler")
}),
contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
return nil, false, errors.New("failure")
}),
@ -116,11 +93,4 @@ func TestAuthenticateRequestError(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{})
<-failed
empty, err := genericapirequest.IsEmpty(contextMapper)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
}
}

View File

@ -28,17 +28,16 @@ import (
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/audit/policy"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
// WithFailedAuthenticationAudit decorates a failed http.Handler used in WithAuthentication handler.
// It is meant to log only failed authentication requests.
func WithFailedAuthenticationAudit(failedHandler http.Handler, requestContextMapper request.RequestContextMapper, sink audit.Sink, policy policy.Checker) http.Handler {
func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink, policy policy.Checker) http.Handler {
if sink == nil || policy == nil {
return failedHandler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, ev, omitStages, err := createAuditEventAndAttachToContext(requestContextMapper, req, policy)
req, ev, omitStages, err := createAuditEventAndAttachToContext(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))

View File

@ -35,9 +35,10 @@ func TestFailedAuthnAudit(t *testing.T) {
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}),
&fakeRequestContextMapper{}, sink, policyChecker)
sink, policyChecker)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, nil, nil)
req.SetBasicAuth("username", "password")
handler.ServeHTTP(httptest.NewRecorder(), req)
@ -66,9 +67,10 @@ func TestFailedMultipleAuthnAudit(t *testing.T) {
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}),
&fakeRequestContextMapper{}, sink, policyChecker)
sink, policyChecker)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, nil, nil)
req.SetBasicAuth("username", "password")
req.TLS = &tls.ConnectionState{PeerCertificates: []*x509.Certificate{{}}}
handler.ServeHTTP(httptest.NewRecorder(), req)
@ -98,9 +100,10 @@ func TestFailedAuthnAuditWithoutAuthorization(t *testing.T) {
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}),
&fakeRequestContextMapper{}, sink, policyChecker)
sink, policyChecker)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, nil, nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
if len(sink.events) != 1 {
@ -128,9 +131,10 @@ func TestFailedAuthnAuditOmitted(t *testing.T) {
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}),
&fakeRequestContextMapper{}, sink, policyChecker)
sink, policyChecker)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, nil, nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
if len(sink.events) != 0 {

View File

@ -41,17 +41,13 @@ const (
)
// WithAuthorizationCheck passes all authorized requests on to handler, and returns a forbidden error otherwise.
func WithAuthorization(handler http.Handler, requestContextMapper request.RequestContextMapper, a authorizer.Authorizer, s runtime.NegotiatedSerializer) http.Handler {
func WithAuthorization(handler http.Handler, a authorizer.Authorizer, s runtime.NegotiatedSerializer) http.Handler {
if a == nil {
glog.Warningf("Authorization is disabled")
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ok := requestContextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
ae := request.AuditEventFrom(ctx)
attributes, err := GetAuthorizerAttributes(ctx)

View File

@ -29,13 +29,9 @@ import (
"k8s.io/apimachinery/pkg/runtime/serializer"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/authorization/authorizer"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
func TestGetAuthorizerAttributes(t *testing.T) {
mapper := request.NewRequestContextMapper()
testcases := map[string]struct {
Verb string
Path string
@ -113,15 +109,10 @@ func TestGetAuthorizerAttributes(t *testing.T) {
var attribs authorizer.Attributes
var err error
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ok := mapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
attribs, err = GetAuthorizerAttributes(ctx)
})
handler = WithRequestInfo(handler, newTestRequestInfoResolver(), mapper)
handler = request.WithRequestContext(handler, mapper)
handler = WithRequestInfo(handler, newTestRequestInfoResolver())
handler.ServeHTTP(httptest.NewRecorder(), req)
if err != nil {
@ -181,11 +172,11 @@ func TestAuditAnnotation(t *testing.T) {
negotiatedSerializer := serializer.DirectCodecFactory{CodecFactory: serializer.NewCodecFactory(scheme)}
for k, tc := range testcases {
audit := &auditinternal.Event{Level: auditinternal.LevelMetadata}
handler := WithAuthorization(&fakeHTTPHandler{}, &fakeRequestContextMapper{
audit: audit,
}, tc.authorizer, negotiatedSerializer)
handler := WithAuthorization(&fakeHTTPHandler{}, tc.authorizer, negotiatedSerializer)
// TODO: fake audit injector
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req = withTestContext(req, nil, audit)
req.RemoteAddr = "127.0.0.1"
handler.ServeHTTP(httptest.NewRecorder(), req)
assert.Equal(t, tc.decisionAnnotation, audit.Annotations[decisionAnnotationKey], k+": unexpected decision annotation")

View File

@ -37,7 +37,7 @@ import (
)
// WithImpersonation is a filter that will inspect and check requests that attempt to change the user.Info for their requests
func WithImpersonation(handler http.Handler, requestContextMapper request.RequestContextMapper, a authorizer.Authorizer, s runtime.NegotiatedSerializer) http.Handler {
func WithImpersonation(handler http.Handler, a authorizer.Authorizer, s runtime.NegotiatedSerializer) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
impersonationRequests, err := buildImpersonationRequests(req.Header)
if err != nil {
@ -50,11 +50,7 @@ func WithImpersonation(handler http.Handler, requestContextMapper request.Reques
return
}
ctx, exists := requestContextMapper.Get(req)
if !exists {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
requestor, exists := request.UserFrom(ctx)
if !exists {
responsewriters.InternalError(w, req, errors.New("no user found for request"))
@ -129,7 +125,7 @@ func WithImpersonation(handler http.Handler, requestContextMapper request.Reques
Groups: groups,
Extra: userExtra,
}
requestContextMapper.Update(req, request.WithUser(ctx, newUser))
req = req.WithContext(request.WithUser(ctx, newUser))
oldUser, _ := request.UserFrom(ctx)
httplog.LogOf(req, w).Addf("%v is acting as %v", oldUser, newUser)

View File

@ -308,13 +308,12 @@ func TestImpersonationFilter(t *testing.T) {
},
}
requestContextMapper := request.NewRequestContextMapper()
var ctx request.Context
var actualUser user.Info
var lock sync.Mutex
doNothingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
currentCtx, _ := requestContextMapper.Get(req)
currentCtx := req.Context()
user, exists := request.UserFrom(currentCtx)
if !exists {
actualUser = nil
@ -345,8 +344,8 @@ func TestImpersonationFilter(t *testing.T) {
}()
lock.Lock()
defer lock.Unlock()
requestContextMapper.Update(req, ctx)
currentCtx, _ := requestContextMapper.Get(req)
req = req.WithContext(ctx)
currentCtx := req.Context()
user, exists := request.UserFrom(currentCtx)
if !exists {
@ -358,8 +357,7 @@ func TestImpersonationFilter(t *testing.T) {
delegate.ServeHTTP(w, req)
})
}(WithImpersonation(doNothingHandler, requestContextMapper, impersonateAuthorizer{}, serializer.NewCodecFactory(runtime.NewScheme())))
handler = request.WithRequestContext(handler, requestContextMapper)
}(WithImpersonation(doNothingHandler, impersonateAuthorizer{}, serializer.NewCodecFactory(runtime.NewScheme())))
server := httptest.NewServer(handler)
defer server.Close()

View File

@ -18,7 +18,6 @@ package filters
import (
"bufio"
"errors"
"fmt"
"io"
"net"
@ -32,7 +31,6 @@ import (
authenticationapi "k8s.io/api/authentication/v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
)
var _ http.ResponseWriter = &legacyAuditResponseWriter{}
@ -96,16 +94,12 @@ var _ http.Hijacker = &fancyLegacyResponseWriterDelegator{}
// 2. the response line containing:
// - the unique id from 1
// - response code
func WithLegacyAudit(handler http.Handler, requestContextMapper request.RequestContextMapper, out io.Writer) http.Handler {
func WithLegacyAudit(handler http.Handler, out io.Writer) http.Handler {
if out == nil {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ok := requestContextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
attribs, err := GetAuthorizerAttributes(ctx)
if err != nil {
responsewriters.InternalError(w, req, err)

View File

@ -48,12 +48,11 @@ func TestLegacyConstructResponseWriter(t *testing.T) {
func TestLegacyAudit(t *testing.T) {
var buf bytes.Buffer
handler := WithLegacyAudit(&fakeHTTPHandler{}, &fakeRequestContextMapper{
user: &user.DefaultInfo{Name: "admin"},
}, &buf)
handler := WithLegacyAudit(&fakeHTTPHandler{}, &buf)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
line := strings.Split(strings.TrimSpace(buf.String()), "\n")
if len(line) != 2 {
@ -78,10 +77,11 @@ func TestLegacyAudit(t *testing.T) {
func TestLegacyAuditNoPanicOnNilUser(t *testing.T) {
var buf bytes.Buffer
handler := WithLegacyAudit(&fakeHTTPHandler{}, &fakeRequestContextMapper{}, &buf)
handler := WithLegacyAudit(&fakeHTTPHandler{}, &buf)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
req = withTestContext(req, nil, nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
line := strings.Split(strings.TrimSpace(buf.String()), "\n")
if len(line) != 2 {

View File

@ -17,7 +17,6 @@ limitations under the License.
package filters
import (
"errors"
"fmt"
"net/http"
@ -26,21 +25,16 @@ import (
)
// WithRequestInfo attaches a RequestInfo to the context.
func WithRequestInfo(handler http.Handler, resolver request.RequestInfoResolver, requestContextMapper request.RequestContextMapper) http.Handler {
func WithRequestInfo(handler http.Handler, resolver request.RequestInfoResolver) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ok := requestContextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
ctx := req.Context()
info, err := resolver.NewRequestInfo(req)
if err != nil {
responsewriters.InternalError(w, req, fmt.Errorf("failed to create RequestInfo: %v", err))
return
}
requestContextMapper.Update(req, request.WithRequestInfo(ctx, info))
req = req.WithContext(request.WithRequestInfo(ctx, info))
handler.ServeHTTP(w, req)
})

View File

@ -29,7 +29,6 @@ import (
utilerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apiserver/pkg/admission"
"k8s.io/apiserver/pkg/endpoints/discovery"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/registry/rest"
)
@ -70,8 +69,7 @@ type APIGroupVersion struct {
Linker runtime.SelfLinker
UnsafeConvertor runtime.ObjectConvertor
Admit admission.Interface
Context request.RequestContextMapper
Admit admission.Interface
MinRequestTimeout time.Duration
@ -93,7 +91,7 @@ func (g *APIGroupVersion) InstallREST(container *restful.Container) error {
}
apiResources, ws, registrationErrors := installer.Install()
versionDiscoveryHandler := discovery.NewAPIVersionHandler(g.Serializer, g.GroupVersion, staticLister{apiResources}, g.Context)
versionDiscoveryHandler := discovery.NewAPIVersionHandler(g.Serializer, g.GroupVersion, staticLister{apiResources})
versionDiscoveryHandler.AddToWebService(ws)
container.Add(ws)
return utilerrors.NewAggregate(registrationErrors)

View File

@ -57,7 +57,7 @@ func createHandler(r rest.NamedCreater, scope RequestScope, typer runtime.Object
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
gv := scope.Kind.GroupVersion()

View File

@ -49,7 +49,7 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope RequestSco
scope.err(err, w, req)
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
options := &metav1.DeleteOptions{}
@ -176,7 +176,7 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope RequestSco
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
if mutatingAdmission, ok := admit.(admission.MutationInterface); ok && mutatingAdmission.Handles(admission.Delete) {

View File

@ -54,7 +54,7 @@ func getResourceHandler(scope RequestScope, getter getterFunc) http.HandlerFunc
scope.err(err, w, req)
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
result, err := getter(ctx, name, req, trace)
@ -137,7 +137,7 @@ func getRequestOptions(req *http.Request, scope RequestScope, into runtime.Objec
newQuery[k] = v
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
requestInfo, _ := request.RequestInfoFrom(ctx)
startingIndex := 2
if isSubresource {
@ -181,7 +181,7 @@ func ListResource(r rest.Lister, rw rest.Watcher, scope RequestScope, forceWatch
hasName = false
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
opts := metainternalversion.ListOptions{}

View File

@ -26,9 +26,6 @@ import (
"k8s.io/apiserver/pkg/endpoints/request"
)
// ContextFunc returns a Context given a request - a context must be returned
type ContextFunc func(req *http.Request) request.Context
// ScopeNamer handles accessing names from requests and objects
type ScopeNamer interface {
// Namespace returns the appropriate namespace value from the request (may be empty) or an
@ -51,7 +48,6 @@ type ScopeNamer interface {
}
type ContextBasedNaming struct {
GetContext ContextFunc
SelfLinker runtime.SelfLinker
ClusterScoped bool
@ -67,7 +63,7 @@ func (n ContextBasedNaming) SetSelfLink(obj runtime.Object, url string) error {
}
func (n ContextBasedNaming) Namespace(req *http.Request) (namespace string, err error) {
requestInfo, ok := request.RequestInfoFrom(n.GetContext(req))
requestInfo, ok := request.RequestInfoFrom(req.Context())
if !ok {
return "", fmt.Errorf("missing requestInfo")
}
@ -75,7 +71,7 @@ func (n ContextBasedNaming) Namespace(req *http.Request) (namespace string, err
}
func (n ContextBasedNaming) Name(req *http.Request) (namespace, name string, err error) {
requestInfo, ok := request.RequestInfoFrom(n.GetContext(req))
requestInfo, ok := request.RequestInfoFrom(req.Context())
if !ok {
return "", "", fmt.Errorf("missing requestInfo")
}

View File

@ -76,7 +76,7 @@ func PatchResource(r rest.Patcher, scope RequestScope, admit admission.Interface
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
versionedObj, err := converter.ConvertToVersion(r.New(), scope.Kind.GroupVersion())

View File

@ -169,7 +169,7 @@ func transformResponseObject(ctx request.Context, scope RequestScope, req *http.
}
}
responsewriters.WriteObject(ctx, statusCode, scope.Kind.GroupVersion(), scope.Serializer, result, w, req)
responsewriters.WriteObject(statusCode, scope.Kind.GroupVersion(), scope.Serializer, result, w, req)
}
// errNotAcceptable indicates Accept negotiation has failed

View File

@ -53,7 +53,7 @@ func Forbidden(ctx request.Context, attributes authorizer.Attributes, w http.Res
}
gv := schema.GroupVersion{Group: attributes.GetAPIGroup(), Version: attributes.GetAPIVersion()}
gr := schema.GroupResource{Group: attributes.GetAPIGroup(), Resource: attributes.GetResource()}
ErrorNegotiated(ctx, apierrors.NewForbidden(gr, attributes.GetName(), fmt.Errorf(errMsg)), s, gv, w, req)
ErrorNegotiated(apierrors.NewForbidden(gr, attributes.GetName(), fmt.Errorf(errMsg)), s, gv, w, req)
}
func forbiddenMessage(attributes authorizer.Attributes) string {

View File

@ -40,25 +40,25 @@ import (
// response. The Accept header and current API version will be passed in, and the output will be copied
// directly to the response body. If content type is returned it is used, otherwise the content type will
// be "application/octet-stream". All other objects are sent to standard JSON serialization.
func WriteObject(ctx request.Context, statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSerializer, object runtime.Object, w http.ResponseWriter, req *http.Request) {
func WriteObject(statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSerializer, object runtime.Object, w http.ResponseWriter, req *http.Request) {
stream, ok := object.(rest.ResourceStreamer)
if ok {
requestInfo, _ := request.RequestInfoFrom(ctx)
requestInfo, _ := request.RequestInfoFrom(req.Context())
metrics.RecordLongRunning(req, requestInfo, func() {
StreamObject(ctx, statusCode, gv, s, stream, w, req)
StreamObject(statusCode, gv, s, stream, w, req)
})
return
}
WriteObjectNegotiated(ctx, s, gv, w, req, statusCode, object)
WriteObjectNegotiated(s, gv, w, req, statusCode, object)
}
// StreamObject performs input stream negotiation from a ResourceStreamer and writes that to the response.
// If the client requests a websocket upgrade, negotiate for a websocket reader protocol (because many
// browser clients cannot easily handle binary streaming protocols).
func StreamObject(ctx request.Context, statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSerializer, stream rest.ResourceStreamer, w http.ResponseWriter, req *http.Request) {
func StreamObject(statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSerializer, stream rest.ResourceStreamer, w http.ResponseWriter, req *http.Request) {
out, flush, contentType, err := stream.InputStream(gv.String(), req.Header.Get("Accept"))
if err != nil {
ErrorNegotiated(ctx, err, s, gv, w, req)
ErrorNegotiated(err, s, gv, w, req)
return
}
if out == nil {
@ -101,7 +101,7 @@ func SerializeObject(mediaType string, encoder runtime.Encoder, w http.ResponseW
// WriteObjectNegotiated renders an object in the content type negotiated by the client.
// The context is optional and can be nil.
func WriteObjectNegotiated(ctx request.Context, s runtime.NegotiatedSerializer, gv schema.GroupVersion, w http.ResponseWriter, req *http.Request, statusCode int, object runtime.Object) {
func WriteObjectNegotiated(s runtime.NegotiatedSerializer, gv schema.GroupVersion, w http.ResponseWriter, req *http.Request, statusCode int, object runtime.Object) {
serializer, err := negotiation.NegotiateOutputSerializer(req, s)
if err != nil {
// if original statusCode was not successful we need to return the original error
@ -115,7 +115,7 @@ func WriteObjectNegotiated(ctx request.Context, s runtime.NegotiatedSerializer,
return
}
if ae := request.AuditEventFrom(ctx); ae != nil {
if ae := request.AuditEventFrom(req.Context()); ae != nil {
audit.LogResponseObject(ae, object, gv, s)
}
@ -125,7 +125,7 @@ func WriteObjectNegotiated(ctx request.Context, s runtime.NegotiatedSerializer,
// ErrorNegotiated renders an error to the response. Returns the HTTP status code of the error.
// The context is optional and may be nil.
func ErrorNegotiated(ctx request.Context, err error, s runtime.NegotiatedSerializer, gv schema.GroupVersion, w http.ResponseWriter, req *http.Request) int {
func ErrorNegotiated(err error, s runtime.NegotiatedSerializer, gv schema.GroupVersion, w http.ResponseWriter, req *http.Request) int {
status := ErrorToAPIStatus(err)
code := int(status.Code)
// when writing an error, check to see if the status indicates a retry after period
@ -139,7 +139,7 @@ func ErrorNegotiated(ctx request.Context, err error, s runtime.NegotiatedSeriali
return code
}
WriteObjectNegotiated(ctx, s, gv, w, req, code, status)
WriteObjectNegotiated(s, gv, w, req, code, status)
return code
}

View File

@ -42,7 +42,6 @@ import (
// RequestScope encapsulates common fields across all RESTful handler methods.
type RequestScope struct {
Namer ScopeNamer
ContextFunc
Serializer runtime.NegotiatedSerializer
runtime.ParameterCodec
@ -63,8 +62,7 @@ type RequestScope struct {
}
func (scope *RequestScope) err(err error, w http.ResponseWriter, req *http.Request) {
ctx := scope.ContextFunc(req)
responsewriters.ErrorNegotiated(ctx, err, scope.Serializer, scope.Kind.GroupVersion(), w, req)
responsewriters.ErrorNegotiated(err, scope.Serializer, scope.Kind.GroupVersion(), w, req)
}
func (scope *RequestScope) AllowsConversion(gvk schema.GroupVersionKind) bool {
@ -102,7 +100,7 @@ func ConnectResource(connecter rest.Connecter, scope RequestScope, admit admissi
scope.err(err, w, req)
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
opts, subpath, subpathKey := connecter.NewConnectOptions()
if err := getRequestOptions(req, scope, opts, subpath, subpathKey, isSubresource); err != nil {
@ -153,8 +151,7 @@ type responder struct {
}
func (r *responder) Object(statusCode int, obj runtime.Object) {
ctx := r.scope.ContextFunc(r.req)
responsewriters.WriteObject(ctx, statusCode, r.scope.Kind.GroupVersion(), r.scope.Serializer, obj, r.w, r.req)
responsewriters.WriteObject(statusCode, r.scope.Kind.GroupVersion(), r.scope.Serializer, obj, r.w, r.req)
}
func (r *responder) Error(err error) {

View File

@ -47,7 +47,7 @@ func UpdateResource(r rest.Updater, scope RequestScope, typer runtime.ObjectType
scope.err(err, w, req)
return
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
ctx = request.WithNamespace(ctx, namespace)
body, err := readBody(req)

View File

@ -89,7 +89,7 @@ func serveWatch(watcher watch.Interface, scope RequestScope, req *http.Request,
mediaType += ";stream=watch"
}
ctx := scope.ContextFunc(req)
ctx := req.Context()
requestInfo, ok := request.RequestInfoFrom(ctx)
if !ok {
scope.err(fmt.Errorf("missing requestInfo"), w, req)

View File

@ -38,7 +38,6 @@ import (
"k8s.io/apiserver/pkg/endpoints/handlers"
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
"k8s.io/apiserver/pkg/endpoints/metrics"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/registry/rest"
genericfilters "k8s.io/apiserver/pkg/server/filters"
)
@ -188,10 +187,6 @@ func (a *APIInstaller) restMapping(resource string) (*meta.RESTMapping, error) {
func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storage, ws *restful.WebService) (*metav1.APIResource, error) {
admit := a.group.Admit
context := a.group.Context
if context == nil {
return nil, fmt.Errorf("%v missing Context", a.group.GroupVersion)
}
optionsExternalVersion := a.group.GroupVersion
if a.group.OptionsExternalVersion != nil {
@ -342,14 +337,6 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
}
}
var ctxFn handlers.ContextFunc
ctxFn = func(req *http.Request) request.Context {
if ctx, ok := context.Get(req); ok {
return request.WithUserAgent(ctx, req.Header.Get("User-Agent"))
}
return request.WithUserAgent(request.NewContext(), req.Header.Get("User-Agent"))
}
allowWatchList := isWatcher && isLister // watching on lists is allowed only for kinds that support both watch and list.
scope := mapping.Scope
nameParam := ws.PathParameter("name", "name of the "+kind).DataType("string")
@ -389,7 +376,6 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
apiResource.Namespaced = false
apiResource.Kind = resourceKind
namer := handlers.ContextBasedNaming{
GetContext: ctxFn,
SelfLinker: a.group.Linker,
ClusterScoped: true,
SelfLinkPathPrefix: gpath.Join(a.prefix, resource) + "/",
@ -438,7 +424,6 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
apiResource.Namespaced = true
apiResource.Kind = resourceKind
namer := handlers.ContextBasedNaming{
GetContext: ctxFn,
SelfLinker: a.group.Linker,
ClusterScoped: false,
SelfLinkPathPrefix: gpath.Join(a.prefix, scope.ParamName()) + "/",
@ -497,7 +482,6 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
kubeVerbs := map[string]struct{}{}
reqScope := handlers.RequestScope{
ContextFunc: ctxFn,
Serializer: a.group.Serializer,
ParameterCodec: a.group.ParameterCodec,
Creater: a.group.Creater,
@ -581,7 +565,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
}
if a.enableAPIResponseCompression {
handler = genericfilters.RestfulWithCompression(handler, a.group.Context)
handler = genericfilters.RestfulWithCompression(handler)
}
doc := "read the specified " + kind
if hasSubresource {
@ -613,7 +597,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage rest.Storag
}
handler := metrics.InstrumentRouteFunc(action.Verb, resource, subresource, requestScope, restfulListResource(lister, watcher, reqScope, false, a.minRequestTimeout))
if a.enableAPIResponseCompression {
handler = genericfilters.RestfulWithCompression(handler, a.group.Context)
handler = genericfilters.RestfulWithCompression(handler)
}
route := ws.GET(action.Path).To(handler).
Doc(doc).

View File

@ -10,7 +10,6 @@ go_test(
name = "go_default_test",
srcs = [
"context_test.go",
"requestcontext_test.go",
"requestinfo_test.go",
],
embed = [":go_default_library"],
@ -27,12 +26,10 @@ go_library(
srcs = [
"context.go",
"doc.go",
"requestcontext.go",
"requestinfo.go",
],
importpath = "k8s.io/apiserver/pkg/endpoints/request",
deps = [
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/types:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/sets:go_default_library",

View File

@ -136,17 +136,6 @@ func UIDFrom(ctx Context) (types.UID, bool) {
return uid, ok
}
// WithUserAgent returns a copy of parent in which the user value is set
func WithUserAgent(parent Context, userAgent string) Context {
return WithValue(parent, userAgentKey, userAgent)
}
// UserAgentFrom returns the value of the userAgent key on the ctx
func UserAgentFrom(ctx Context) (string, bool) {
userAgent, ok := ctx.Value(userAgentKey).(string)
return userAgent, ok
}
// WithAuditEvent returns set audit event struct.
func WithAuditEvent(parent Context, ev *audit.Event) Context {
return WithValue(parent, auditKey, ev)

View File

@ -109,25 +109,3 @@ func TestUIDContext(t *testing.T) {
t.Fatalf("Error getting UID")
}
}
//TestUserAgentContext validates that a useragent can be get/set on a context object
func TestUserAgentContext(t *testing.T) {
ctx := NewContext()
_, ok := UserAgentFrom(ctx)
if ok {
t.Fatalf("Should not be ok because there is no UserAgent on the context")
}
ctx = WithUserAgent(
ctx,
"TestUserAgent",
)
result, ok := UserAgentFrom(ctx)
if !ok {
t.Fatalf("Error getting UserAgent")
}
expectedResult := "TestUserAgent"
if result != expectedResult {
t.Fatalf("Get user agent error, Expected: %s, Actual: %s", expectedResult, result)
}
}

View File

@ -1,149 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"errors"
"net/http"
"sync"
"sync/atomic"
"github.com/golang/glog"
)
// LongRunningRequestCheck is a predicate which is true for long-running http requests.
type LongRunningRequestCheck func(r *http.Request, requestInfo *RequestInfo) bool
// RequestContextMapper keeps track of the context associated with a particular request
type RequestContextMapper interface {
// Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not.
Get(req *http.Request) (Context, bool)
// Update maps the request to the given context. If no context was previously associated with the request, an error is returned.
// Update should only be called with a descendant context of the previously associated context.
// Updating to an unrelated context may return an error in the future.
// The context associated with a request should only be updated by a limited set of callers.
// Valid examples include the authentication layer, or an audit/tracing layer.
Update(req *http.Request, context Context) error
}
type requestContextMap struct {
// contexts contains a request Context map
// atomic.Value has a very good read performance compared to sync.RWMutex
// almost all requests have 3-4 context updates associated with them,
// and they can use only read lock to protect updating context, which is of higher performance with higher burst.
contexts map[*http.Request]*atomic.Value
lock sync.RWMutex
}
// NewRequestContextMapper returns a new RequestContextMapper.
// The returned mapper must be added as a request filter using NewRequestContextFilter.
func NewRequestContextMapper() RequestContextMapper {
return &requestContextMap{
contexts: make(map[*http.Request]*atomic.Value),
}
}
func (c *requestContextMap) getValue(req *http.Request) (*atomic.Value, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
value, ok := c.contexts[req]
return value, ok
}
// contextWrap is a wrapper of Context to prevent atomic.Value to be copied
type contextWrap struct {
Context
}
// Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not.
// Get will only return a valid context when called from inside the filter chain set up by NewRequestContextFilter()
func (c *requestContextMap) Get(req *http.Request) (Context, bool) {
value, ok := c.getValue(req)
if !ok {
return nil, false
}
if context, ok := value.Load().(contextWrap); ok {
return context.Context, ok
}
return nil, false
}
// Update maps the request to the given context.
// If no context was previously associated with the request, an error is returned and the context is ignored.
func (c *requestContextMap) Update(req *http.Request, context Context) error {
value, ok := c.getValue(req)
if !ok {
return errors.New("no context associated")
}
wrapper, ok := value.Load().(contextWrap)
if !ok {
return errors.New("value type does not match")
}
wrapper.Context = context
value.Store(wrapper)
return nil
}
// init maps the request to the given context and returns true if there was no context associated with the request already.
// if a context was already associated with the request, it ignores the given context and returns false.
// init is intentionally unexported to ensure that all init calls are paired with a remove after a request is handled
func (c *requestContextMap) init(req *http.Request, context Context) bool {
c.lock.Lock()
defer c.lock.Unlock()
if _, exists := c.contexts[req]; exists {
return false
}
value := &atomic.Value{}
value.Store(contextWrap{context})
c.contexts[req] = value
return true
}
// remove is intentionally unexported to ensure that the context is not removed until a request is handled
func (c *requestContextMap) remove(req *http.Request) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.contexts, req)
}
// WithRequestContext ensures there is a Context object associated with the request before calling the passed handler.
// After the passed handler runs, the context is cleaned up.
func WithRequestContext(handler http.Handler, mapper RequestContextMapper) http.Handler {
rcMap, ok := mapper.(*requestContextMap)
if !ok {
glog.Fatal("Unknown RequestContextMapper implementation.")
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if rcMap.init(req, NewContext()) {
// If we were the ones to successfully initialize, pair with a remove
defer rcMap.remove(req)
}
handler.ServeHTTP(w, req)
})
}
// IsEmpty returns true if there are no contexts registered, or an error if it could not be determined. Intended for use by tests.
func IsEmpty(requestsToContexts RequestContextMapper) (bool, error) {
if requestsToContexts, ok := requestsToContexts.(*requestContextMap); ok {
return len(requestsToContexts.contexts) == 0, nil
}
return true, errors.New("Unknown RequestContextMapper implementation")
}

View File

@ -1,154 +0,0 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"net/http"
"sync"
"testing"
)
func TestRequestContextMapperGet(t *testing.T) {
mapper := NewRequestContextMapper()
context := NewContext()
req, _ := http.NewRequest("GET", "/api/version/resource", nil)
// empty mapper
if _, ok := mapper.Get(req); ok {
t.Fatalf("got unexpected context")
}
// init mapper
mapper.(*requestContextMap).init(req, context)
if _, ok := mapper.Get(req); !ok {
t.Fatalf("got no context")
}
// remove request context
mapper.(*requestContextMap).remove(req)
if _, ok := mapper.Get(req); ok {
t.Fatalf("got unexpected context")
}
}
func TestRequestContextMapperUpdate(t *testing.T) {
mapper := NewRequestContextMapper()
context := NewContext()
req, _ := http.NewRequest("GET", "/api/version/resource", nil)
// empty mapper
if err := mapper.Update(req, context); err == nil {
t.Fatalf("got no error")
}
// init mapper
if !mapper.(*requestContextMap).init(req, context) {
t.Fatalf("unexpected error, should init mapper")
}
context = WithNamespace(context, "default")
if err := mapper.Update(req, context); err != nil {
t.Fatalf("unexpected error")
}
if context, ok := mapper.Get(req); !ok {
t.Fatalf("go no context")
} else {
if ns, _ := NamespaceFrom(context); ns != "default" {
t.Fatalf("unexpected namespace %s", ns)
}
}
}
func TestRequestContextMapperConcurrent(t *testing.T) {
mapper := NewRequestContextMapper()
testCases := []struct{ url, namespace string }{
{"/api/version/resource1", "ns1"},
{"/api/version/resource2", "ns2"},
{"/api/version/resource3", "ns3"},
{"/api/version/resource4", "ns4"},
{"/api/version/resource5", "ns5"},
}
wg := sync.WaitGroup{}
for _, testcase := range testCases {
wg.Add(1)
go func(testcase struct{ url, namespace string }) {
defer wg.Done()
context := NewContext()
req, _ := http.NewRequest("GET", testcase.url, nil)
if !mapper.(*requestContextMap).init(req, context) {
t.Errorf("unexpected init error")
return
}
if _, ok := mapper.Get(req); !ok {
t.Errorf("got no context")
return
}
context2 := WithNamespace(context, testcase.namespace)
if err := mapper.Update(req, context2); err != nil {
t.Errorf("unexpected update error")
return
}
if context, ok := mapper.Get(req); !ok {
t.Errorf("got no context")
return
} else {
if ns, _ := NamespaceFrom(context); ns != testcase.namespace {
t.Errorf("unexpected namespace %s", ns)
return
}
}
}(testcase)
}
wg.Wait()
}
func BenchmarkRequestContextMapper(b *testing.B) {
mapper := NewRequestContextMapper()
b.SetParallelism(500)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
context := NewContext()
req, _ := http.NewRequest("GET", "/api/version/resource", nil)
// 1 init
mapper.(*requestContextMap).init(req, context)
// 5 Get + 4 Update
mapper.Get(req)
context = WithNamespace(context, "default1")
mapper.Update(req, context)
mapper.Get(req)
context = WithNamespace(context, "default2")
mapper.Update(req, context)
mapper.Get(req)
context = WithNamespace(context, "default3")
mapper.Update(req, context)
mapper.Get(req)
context = WithNamespace(context, "default4")
mapper.Update(req, context)
mapper.Get(req)
// 1 remove
mapper.(*requestContextMap).remove(req)
}
})
}

View File

@ -25,6 +25,9 @@ import (
"k8s.io/apimachinery/pkg/util/sets"
)
// LongRunningRequestCheck is a predicate which is true for long-running http requests.
type LongRunningRequestCheck func(r *http.Request, requestInfo *RequestInfo) bool
type RequestInfoResolver interface {
NewRequestInfo(req *http.Request) (*RequestInfo, error)
}

View File

@ -127,7 +127,7 @@ type Config struct {
//===========================================================================
// BuildHandlerChainFunc allows you to build custom handler chains by decorating the apiHandler.
BuildHandlerChainFunc func(apiHandler http.Handler, c *Config, contextMapper apirequest.RequestContextMapper) (secure http.Handler)
BuildHandlerChainFunc func(apiHandler http.Handler, c *Config) (secure http.Handler)
// HandlerChainWaitGroup allows you to wait for all chain handlers exit after the server shutdown.
HandlerChainWaitGroup *utilwaitgroup.SafeWaitGroup
// DiscoveryAddresses is used to build the IPs pass to discovery. If nil, the ExternalAddress is
@ -452,11 +452,10 @@ func (c completedConfig) New(name string, delegationTarget DelegationTarget) (*G
return nil, fmt.Errorf("Genericapiserver.New() called with config.LoopbackClientConfig == nil")
}
contextMapper := delegationTarget.RequestContextMapper()
handlerChainBuilder := func(handler http.Handler) http.Handler {
return c.BuildHandlerChainFunc(handler, c.Config, contextMapper)
return c.BuildHandlerChainFunc(handler, c.Config)
}
apiServerHandler := NewAPIServerHandler(name, contextMapper, c.Serializer, handlerChainBuilder, delegationTarget.UnprotectedHandler())
apiServerHandler := NewAPIServerHandler(name, c.Serializer, handlerChainBuilder, delegationTarget.UnprotectedHandler())
s := &GenericAPIServer{
discoveryAddresses: c.DiscoveryAddresses,
@ -487,7 +486,7 @@ func (c completedConfig) New(name string, delegationTarget DelegationTarget) (*G
healthzChecks: c.HealthzChecks,
DiscoveryGroupManager: discovery.NewRootAPIsHandler(c.DiscoveryAddresses, c.Serializer, contextMapper),
DiscoveryGroupManager: discovery.NewRootAPIsHandler(c.DiscoveryAddresses, c.Serializer),
enableAPIResponseCompression: c.EnableAPIResponseCompression,
}
@ -542,25 +541,24 @@ func (c completedConfig) New(name string, delegationTarget DelegationTarget) (*G
return s, nil
}
func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config, contextMapper apirequest.RequestContextMapper) http.Handler {
handler := genericapifilters.WithAuthorization(apiHandler, contextMapper, c.Authorization.Authorizer, c.Serializer)
handler = genericfilters.WithMaxInFlightLimit(handler, c.MaxRequestsInFlight, c.MaxMutatingRequestsInFlight, contextMapper, c.LongRunningFunc)
handler = genericapifilters.WithImpersonation(handler, contextMapper, c.Authorization.Authorizer, c.Serializer)
func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
handler := genericapifilters.WithAuthorization(apiHandler, c.Authorization.Authorizer, c.Serializer)
handler = genericfilters.WithMaxInFlightLimit(handler, c.MaxRequestsInFlight, c.MaxMutatingRequestsInFlight, c.LongRunningFunc)
handler = genericapifilters.WithImpersonation(handler, c.Authorization.Authorizer, c.Serializer)
if utilfeature.DefaultFeatureGate.Enabled(features.AdvancedAuditing) {
handler = genericapifilters.WithAudit(handler, contextMapper, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
handler = genericapifilters.WithAudit(handler, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
} else {
handler = genericapifilters.WithLegacyAudit(handler, contextMapper, c.LegacyAuditWriter)
handler = genericapifilters.WithLegacyAudit(handler, c.LegacyAuditWriter)
}
failedHandler := genericapifilters.Unauthorized(contextMapper, c.Serializer, c.Authentication.SupportsBasicAuth)
failedHandler := genericapifilters.Unauthorized(c.Serializer, c.Authentication.SupportsBasicAuth)
if utilfeature.DefaultFeatureGate.Enabled(features.AdvancedAuditing) {
failedHandler = genericapifilters.WithFailedAuthenticationAudit(failedHandler, contextMapper, c.AuditBackend, c.AuditPolicyChecker)
failedHandler = genericapifilters.WithFailedAuthenticationAudit(failedHandler, c.AuditBackend, c.AuditPolicyChecker)
}
handler = genericapifilters.WithAuthentication(handler, contextMapper, c.Authentication.Authenticator, failedHandler)
handler = genericapifilters.WithAuthentication(handler, c.Authentication.Authenticator, failedHandler)
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, contextMapper, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithWaitGroup(handler, contextMapper, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver, contextMapper)
handler = apirequest.WithRequestContext(handler, contextMapper)
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver)
handler = genericfilters.WithPanicRecovery(handler)
return handler
}

View File

@ -46,9 +46,9 @@ const (
)
// WithCompression wraps an http.Handler with the Compression Handler
func WithCompression(handler http.Handler, ctxMapper request.RequestContextMapper) http.Handler {
func WithCompression(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
wantsCompression, encoding := wantsCompressedResponse(req, ctxMapper)
wantsCompression, encoding := wantsCompressedResponse(req)
w.Header().Set("Vary", "Accept-Encoding")
if wantsCompression {
compressionWriter, err := NewCompressionResponseWriter(w, encoding)
@ -67,12 +67,9 @@ func WithCompression(handler http.Handler, ctxMapper request.RequestContextMappe
}
// wantsCompressedResponse reads the Accept-Encoding header to see if and which encoding is requested.
func wantsCompressedResponse(req *http.Request, ctxMapper request.RequestContextMapper) (bool, string) {
func wantsCompressedResponse(req *http.Request) (bool, string) {
// don't compress watches
ctx, ok := ctxMapper.Get(req)
if !ok {
return false, ""
}
ctx := req.Context()
info, ok := request.RequestInfoFrom(ctx)
if !ok {
return false, ""
@ -172,13 +169,13 @@ func (c *compressionResponseWriter) compressorClosed() bool {
}
// RestfulWithCompression wraps WithCompression to be compatible with go-restful
func RestfulWithCompression(function restful.RouteFunction, ctxMapper request.RequestContextMapper) restful.RouteFunction {
func RestfulWithCompression(function restful.RouteFunction) restful.RouteFunction {
return restful.RouteFunction(func(request *restful.Request, response *restful.Response) {
handler := WithCompression(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
response.ResponseWriter = w
request.Request = req
function(request, response)
}), ctxMapper)
}))
handler.ServeHTTP(response.ResponseWriter, request.Request)
})
}

View File

@ -42,17 +42,13 @@ func TestCompression(t *testing.T) {
responseData := []byte("1234")
requestContextMapper := request.NewRequestContextMapper()
for _, test := range tests {
handler := WithCompression(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write(responseData)
}),
requestContextMapper,
)
handler = filters.WithRequestInfo(handler, newTestRequestInfoResolver(), requestContextMapper)
handler = request.WithRequestContext(handler, requestContextMapper)
handler = filters.WithRequestInfo(handler, newTestRequestInfoResolver())
server := httptest.NewServer(handler)
defer server.Close()
client := http.Client{

View File

@ -98,7 +98,6 @@ func WithMaxInFlightLimit(
handler http.Handler,
nonMutatingLimit int,
mutatingLimit int,
requestContextMapper apirequest.RequestContextMapper,
longRunningRequestCheck apirequest.LongRunningRequestCheck,
) http.Handler {
startOnce.Do(startRecordingUsage)
@ -115,11 +114,7 @@ func WithMaxInFlightLimit(
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := requestContextMapper.Get(r)
if !ok {
handleError(w, r, fmt.Errorf("no context found for request, handler chain must be wrong"))
return
}
ctx := r.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
handleError(w, r, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong"))

View File

@ -33,7 +33,6 @@ import (
func createMaxInflightServer(callsWg, blockWg *sync.WaitGroup, disableCallsWg *bool, disableCallsWgMutex *sync.Mutex, nonMutating, mutating int) *httptest.Server {
longRunningRequestCheck := BasicLongRunningRequestCheck(sets.NewString("watch"), sets.NewString("proxy"))
requestContextMapper := apirequest.NewRequestContextMapper()
requestInfoFactory := &apirequest.RequestInfoFactory{APIPrefixes: sets.NewString("apis", "api"), GrouplessAPIPrefixes: sets.NewString("api")}
handler := WithMaxInFlightLimit(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -51,26 +50,18 @@ func createMaxInflightServer(callsWg, blockWg *sync.WaitGroup, disableCallsWg *b
}),
nonMutating,
mutating,
requestContextMapper,
longRunningRequestCheck,
)
handler = withFakeUser(handler, requestContextMapper)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory, requestContextMapper)
handler = apirequest.WithRequestContext(handler, requestContextMapper)
handler = withFakeUser(handler)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory)
return httptest.NewServer(handler)
}
func withFakeUser(handler http.Handler, requestContextMapper apirequest.RequestContextMapper) http.Handler {
func withFakeUser(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := requestContextMapper.Get(r)
if !ok {
handleError(w, r, fmt.Errorf("no context found for request, handler chain must be wrong"))
return
}
if len(r.Header["Groups"]) > 0 {
requestContextMapper.Update(r, apirequest.WithUser(ctx, &user.DefaultInfo{
r = r.WithContext(apirequest.WithUser(r.Context(), &user.DefaultInfo{
Groups: r.Header["Groups"],
}))
}

View File

@ -34,43 +34,37 @@ import (
var errConnKilled = fmt.Errorf("killing connection/stream because serving request timed out and response had been started")
// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by timeout.
func WithTimeoutForNonLongRunningRequests(handler http.Handler, requestContextMapper apirequest.RequestContextMapper, longRunning apirequest.LongRunningRequestCheck, timeout time.Duration) http.Handler {
func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, timeout time.Duration) http.Handler {
if longRunning == nil {
return handler
}
timeoutFunc := func(req *http.Request) (<-chan time.Time, func(), *apierrors.StatusError) {
timeoutFunc := func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) {
// TODO unify this with apiserver.MaxInFlightLimit
ctx, ok := requestContextMapper.Get(req)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no context mapper
return time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no context found for request during timeout"))
}
ctx := req.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no request info
return time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
return req, time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
}
if longRunning(req, requestInfo) {
return nil, nil, nil
return req, nil, nil, nil
}
ctx, cancel := context.WithCancel(ctx)
if err := requestContextMapper.Update(req, ctx); err != nil {
return time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("failed to update context during timeout"))
}
req = req.WithContext(ctx)
postTimeoutFn := func() {
cancel()
metrics.Record(req, requestInfo, "", http.StatusGatewayTimeout, 0, 0)
}
return time.After(timeout), postTimeoutFn, apierrors.NewTimeoutError(fmt.Sprintf("request did not complete within %s", timeout), 0)
return req, time.After(timeout), postTimeoutFn, apierrors.NewTimeoutError(fmt.Sprintf("request did not complete within %s", timeout), 0)
}
return WithTimeout(handler, timeoutFunc)
}
type timeoutFunc = func(*http.Request) (timeout <-chan time.Time, postTimeoutFunc func(), err *apierrors.StatusError)
type timeoutFunc = func(*http.Request) (req *http.Request, timeout <-chan time.Time, postTimeoutFunc func(), err *apierrors.StatusError)
// WithTimeout returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
@ -91,7 +85,7 @@ type timeoutHandler struct {
}
func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
after, postTimeoutFn, err := t.timeout(r)
r, after, postTimeoutFn, err := t.timeout(r)
if after == nil {
t.handler.ServeHTTP(w, r)
return

View File

@ -63,8 +63,8 @@ func TestTimeout(t *testing.T) {
_, err := w.Write([]byte(resp))
writeErrors <- err
}),
func(*http.Request) (<-chan time.Time, func(), *apierrors.StatusError) {
return timeout, record.Record, timeoutErr
func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) {
return req, timeout, record.Record, timeoutErr
}))
defer ts.Close()

View File

@ -24,15 +24,9 @@ import (
)
// WithWaitGroup adds all non long-running requests to wait group, which is used for graceful shutdown.
func WithWaitGroup(handler http.Handler, requestContextMapper apirequest.RequestContextMapper, longRunning apirequest.LongRunningRequestCheck, wg *utilwaitgroup.SafeWaitGroup) http.Handler {
func WithWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg *utilwaitgroup.SafeWaitGroup) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, ok := requestContextMapper.Get(req)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no context mapper
handler.ServeHTTP(w, req)
return
}
ctx := req.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no request info

View File

@ -39,7 +39,6 @@ import (
"k8s.io/apiserver/pkg/audit"
genericapi "k8s.io/apiserver/pkg/endpoints"
"k8s.io/apiserver/pkg/endpoints/discovery"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/registry/rest"
"k8s.io/apiserver/pkg/server/healthz"
"k8s.io/apiserver/pkg/server/routes"
@ -157,10 +156,6 @@ type DelegationTarget interface {
// UnprotectedHandler returns a handler that is NOT protected by a normal chain
UnprotectedHandler() http.Handler
// RequestContextMapper returns the existing RequestContextMapper. Because we cannot rewire all existing
// uses of this function, this will be used in any delegating API server
RequestContextMapper() apirequest.RequestContextMapper
// PostStartHooks returns the post-start hooks that need to be combined
PostStartHooks() map[string]postStartHookEntry
@ -199,13 +194,10 @@ func (s *GenericAPIServer) NextDelegate() DelegationTarget {
}
type emptyDelegate struct {
requestContextMapper apirequest.RequestContextMapper
}
func NewEmptyDelegate() DelegationTarget {
return emptyDelegate{
requestContextMapper: apirequest.NewRequestContextMapper(),
}
return emptyDelegate{}
}
func (s emptyDelegate) UnprotectedHandler() http.Handler {
@ -223,17 +215,10 @@ func (s emptyDelegate) HealthzChecks() []healthz.HealthzChecker {
func (s emptyDelegate) ListedPaths() []string {
return []string{}
}
func (s emptyDelegate) RequestContextMapper() apirequest.RequestContextMapper {
return s.requestContextMapper
}
func (s emptyDelegate) NextDelegate() DelegationTarget {
return nil
}
func (s *GenericAPIServer) RequestContextMapper() apirequest.RequestContextMapper {
return s.delegationTarget.RequestContextMapper()
}
// preparedGenericAPIServer is a private wrapper that enforces a call of PrepareRun() before Run can be invoked.
type preparedGenericAPIServer struct {
*GenericAPIServer
@ -364,7 +349,7 @@ func (s *GenericAPIServer) InstallLegacyAPIGroup(apiPrefix string, apiGroupInfo
}
// Install the version handler.
// Add a handler at /<apiPrefix> to enumerate the supported api versions.
s.Handler.GoRestfulContainer.Add(discovery.NewLegacyRootAPIHandler(s.discoveryAddresses, s.Serializer, apiPrefix, apiVersions, s.delegationTarget.RequestContextMapper()).WebService())
s.Handler.GoRestfulContainer.Add(discovery.NewLegacyRootAPIHandler(s.discoveryAddresses, s.Serializer, apiPrefix, apiVersions).WebService())
return nil
}
@ -409,7 +394,7 @@ func (s *GenericAPIServer) InstallAPIGroup(apiGroupInfo *APIGroupInfo) error {
}
s.DiscoveryGroupManager.AddGroup(apiGroup)
s.Handler.GoRestfulContainer.Add(discovery.NewAPIGroupHandler(s.Serializer, apiGroup, s.delegationTarget.RequestContextMapper()).WebService())
s.Handler.GoRestfulContainer.Add(discovery.NewAPIGroupHandler(s.Serializer, apiGroup).WebService())
return nil
}
@ -441,7 +426,6 @@ func (s *GenericAPIServer) newAPIGroupVersion(apiGroupInfo *APIGroupInfo, groupV
Mapper: apiGroupInfo.GroupMeta.RESTMapper,
Admit: s.admissionControl,
Context: s.RequestContextMapper(),
MinRequestTimeout: s.minRequestTimeout,
EnableAPIResponseCompression: s.enableAPIResponseCompression,
}

View File

@ -331,7 +331,7 @@ func TestCustomHandlerChain(t *testing.T) {
var protected, called bool
config.BuildHandlerChainFunc = func(apiHandler http.Handler, c *Config, contextMapper apirequest.RequestContextMapper) http.Handler {
config.BuildHandlerChainFunc = func(apiHandler http.Handler, c *Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
protected = true
apiHandler.ServeHTTP(w, req)
@ -507,10 +507,9 @@ func TestGracefulShutdown(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
config.BuildHandlerChainFunc = func(apiHandler http.Handler, c *Config, contextMapper apirequest.RequestContextMapper) http.Handler {
handler := genericfilters.WithWaitGroup(apiHandler, contextMapper, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver, contextMapper)
handler = apirequest.WithRequestContext(handler, contextMapper)
config.BuildHandlerChainFunc = func(apiHandler http.Handler, c *Config) http.Handler {
handler := genericfilters.WithWaitGroup(apiHandler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver)
return handler
}

View File

@ -18,7 +18,6 @@ package server
import (
"bytes"
"errors"
"fmt"
"net/http"
rt "runtime"
@ -32,7 +31,6 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/server/mux"
)
@ -72,7 +70,7 @@ type APIServerHandler struct {
// It is normally used to apply filtering like authentication and authorization
type HandlerChainBuilderFn func(apiHandler http.Handler) http.Handler
func NewAPIServerHandler(name string, contextMapper request.RequestContextMapper, s runtime.NegotiatedSerializer, handlerChainBuilder HandlerChainBuilderFn, notFoundHandler http.Handler) *APIServerHandler {
func NewAPIServerHandler(name string, s runtime.NegotiatedSerializer, handlerChainBuilder HandlerChainBuilderFn, notFoundHandler http.Handler) *APIServerHandler {
nonGoRestfulMux := mux.NewPathRecorderMux(name)
if notFoundHandler != nil {
nonGoRestfulMux.NotFoundHandler(notFoundHandler)
@ -85,11 +83,7 @@ func NewAPIServerHandler(name string, contextMapper request.RequestContextMapper
logStackOnRecover(s, panicReason, httpWriter)
})
gorestfulContainer.ServiceErrorHandler(func(serviceErr restful.ServiceError, request *restful.Request, response *restful.Response) {
ctx, ok := contextMapper.Get(request.Request)
if !ok {
responsewriters.InternalError(response.ResponseWriter, request.Request, errors.New("no context found for request"))
}
serviceErrorHandler(ctx, s, serviceErr, request, response)
serviceErrorHandler(s, serviceErr, request, response)
})
director := director{
@ -177,13 +171,11 @@ func logStackOnRecover(s runtime.NegotiatedSerializer, panicReason interface{},
if ct := w.Header().Get("Content-Type"); len(ct) > 0 {
headers.Set("Accept", ct)
}
emptyContext := request.NewContext() // best we can do here: we don't know the request
responsewriters.ErrorNegotiated(emptyContext, apierrors.NewGenericServerResponse(http.StatusInternalServerError, "", schema.GroupResource{}, "", "", 0, false), s, schema.GroupVersion{}, w, &http.Request{Header: headers})
responsewriters.ErrorNegotiated(apierrors.NewGenericServerResponse(http.StatusInternalServerError, "", schema.GroupResource{}, "", "", 0, false), s, schema.GroupVersion{}, w, &http.Request{Header: headers})
}
func serviceErrorHandler(ctx request.Context, s runtime.NegotiatedSerializer, serviceErr restful.ServiceError, request *restful.Request, resp *restful.Response) {
func serviceErrorHandler(s runtime.NegotiatedSerializer, serviceErr restful.ServiceError, request *restful.Request, resp *restful.Response) {
responsewriters.ErrorNegotiated(
ctx,
apierrors.NewGenericServerResponse(serviceErr.Code, "", schema.GroupResource{}, "", serviceErr.Message, 0, false),
s,
schema.GroupVersion{},

View File

@ -23,7 +23,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/sets"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
genericapiserver "k8s.io/apiserver/pkg/server"
serverstorage "k8s.io/apiserver/pkg/server/storage"
"k8s.io/client-go/pkg/version"
@ -92,8 +91,6 @@ type APIAggregator struct {
delegateHandler http.Handler
contextMapper genericapirequest.RequestContextMapper
// proxyClientCert/Key are the client cert used to identify this proxy. Backing APIServices use
// this to confirm the proxy's identity
proxyClientCert []byte
@ -158,7 +155,6 @@ func (c completedConfig) NewWithDelegate(delegationTarget genericapiserver.Deleg
s := &APIAggregator{
GenericAPIServer: genericServer,
delegateHandler: delegationTarget.UnprotectedHandler(),
contextMapper: genericServer.RequestContextMapper(),
proxyClientCert: c.ExtraConfig.ProxyClientCert,
proxyClientKey: c.ExtraConfig.ProxyClientKey,
proxyTransport: c.ExtraConfig.ProxyTransport,
@ -177,7 +173,6 @@ func (c completedConfig) NewWithDelegate(delegationTarget genericapiserver.Deleg
apisHandler := &apisHandler{
codecs: aggregatorscheme.Codecs,
lister: s.lister,
mapper: s.contextMapper,
}
s.GenericAPIServer.Handler.NonGoRestfulMux.Handle("/apis", apisHandler)
s.GenericAPIServer.Handler.NonGoRestfulMux.UnlistedHandle("/apis/", apisHandler)
@ -208,7 +203,7 @@ func (c completedConfig) NewWithDelegate(delegationTarget genericapiserver.Deleg
})
if openApiConfig != nil {
specDownloader := openapicontroller.NewDownloader(s.contextMapper)
specDownloader := openapicontroller.NewDownloader()
openAPIAggregator, err := openapicontroller.BuildAndRegisterAggregator(
&specDownloader,
delegationTarget,
@ -250,7 +245,6 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService) er
// register the proxy handler
proxyHandler := &proxyHandler{
contextMapper: s.contextMapper,
localDelegate: s.delegateHandler,
proxyClientCert: s.proxyClientCert,
proxyClientKey: s.proxyClientKey,
@ -278,11 +272,10 @@ func (s *APIAggregator) AddAPIService(apiService *apiregistration.APIService) er
// it's time to register the group aggregation endpoint
groupPath := "/apis/" + apiService.Spec.Group
groupDiscoveryHandler := &apiGroupHandler{
codecs: aggregatorscheme.Codecs,
groupName: apiService.Spec.Group,
lister: s.lister,
delegate: s.delegateHandler,
contextMapper: s.contextMapper,
codecs: aggregatorscheme.Codecs,
groupName: apiService.Spec.Group,
lister: s.lister,
delegate: s.delegateHandler,
}
// aggregation is protected
s.GenericAPIServer.Handler.NonGoRestfulMux.Handle(groupPath, groupDiscoveryHandler)

View File

@ -17,7 +17,6 @@ limitations under the License.
package apiserver
import (
"errors"
"net/http"
apierrors "k8s.io/apimachinery/pkg/api/errors"
@ -26,7 +25,6 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
apiregistrationapi "k8s.io/kube-aggregator/pkg/apis/apiregistration"
apiregistrationv1api "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1"
@ -39,7 +37,6 @@ import (
type apisHandler struct {
codecs serializer.CodecFactory
lister listers.APIServiceLister
mapper request.RequestContextMapper
}
var discoveryGroup = metav1.APIGroup{
@ -61,12 +58,6 @@ var discoveryGroup = metav1.APIGroup{
}
func (r *apisHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, ok := r.mapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
discoveryGroupList := &metav1.APIGroupList{
// always add OUR api group to the list first. Since we'll never have a registered APIService for it
// and since this is the crux of the API, having this first will give our names priority. It's good to be king.
@ -90,7 +81,7 @@ func (r *apisHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}
responsewriters.WriteObjectNegotiated(ctx, r.codecs, schema.GroupVersion{}, w, req, http.StatusOK, discoveryGroupList)
responsewriters.WriteObjectNegotiated(r.codecs, schema.GroupVersion{}, w, req, http.StatusOK, discoveryGroupList)
}
// convertToDiscoveryAPIGroup takes apiservices in a single group and returns a discovery compatible object.
@ -129,9 +120,8 @@ func convertToDiscoveryAPIGroup(apiServices []*apiregistrationapi.APIService) *m
// apiGroupHandler serves the `/apis/<group>` endpoint.
type apiGroupHandler struct {
codecs serializer.CodecFactory
groupName string
contextMapper request.RequestContextMapper
codecs serializer.CodecFactory
groupName string
lister listers.APIServiceLister
@ -139,12 +129,6 @@ type apiGroupHandler struct {
}
func (r *apiGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx, ok := r.contextMapper.Get(req)
if !ok {
responsewriters.InternalError(w, req, errors.New("no context found for request"))
return
}
apiServices, err := r.lister.List(labels.Everything())
if statusErr, ok := err.(*apierrors.StatusError); ok && err != nil {
responsewriters.WriteRawJSON(int(statusErr.Status().Code), statusErr.Status(), w)
@ -172,5 +156,5 @@ func (r *apiGroupHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.Error(w, "", http.StatusNotFound)
return
}
responsewriters.WriteObjectNegotiated(ctx, r.codecs, schema.GroupVersion{}, w, req, http.StatusOK, discoveryGroup)
responsewriters.WriteObjectNegotiated(r.codecs, schema.GroupVersion{}, w, req, http.StatusOK, discoveryGroup)
}

View File

@ -27,7 +27,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/diff"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/client-go/tools/cache"
"k8s.io/kube-aggregator/pkg/apis/apiregistration"
@ -240,18 +239,16 @@ func TestAPIs(t *testing.T) {
}
for _, tc := range tests {
mapper := request.NewRequestContextMapper()
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc})
handler := &apisHandler{
codecs: aggregatorscheme.Codecs,
lister: listers.NewAPIServiceLister(indexer),
mapper: mapper,
}
for _, o := range tc.apiservices {
indexer.Add(o)
}
server := httptest.NewServer(request.WithRequestContext(handler, mapper))
server := httptest.NewServer(handler)
defer server.Close()
resp, err := http.Get(server.URL + "/apis")
@ -278,7 +275,6 @@ func TestAPIs(t *testing.T) {
}
func TestAPIGroupMissing(t *testing.T) {
mapper := request.NewRequestContextMapper()
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc})
handler := &apiGroupHandler{
codecs: aggregatorscheme.Codecs,
@ -287,10 +283,9 @@ func TestAPIGroupMissing(t *testing.T) {
delegate: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}),
contextMapper: mapper,
}
server := httptest.NewServer(request.WithRequestContext(handler, mapper))
server := httptest.NewServer(handler)
defer server.Close()
// this call should delegate
@ -425,19 +420,17 @@ func TestAPIGroup(t *testing.T) {
}
for _, tc := range tests {
mapper := request.NewRequestContextMapper()
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc})
handler := &apiGroupHandler{
codecs: aggregatorscheme.Codecs,
lister: listers.NewAPIServiceLister(indexer),
groupName: "foo",
contextMapper: mapper,
codecs: aggregatorscheme.Codecs,
lister: listers.NewAPIServiceLister(indexer),
groupName: "foo",
}
for _, o := range tc.apiservices {
indexer.Add(o)
}
server := httptest.NewServer(request.WithRequestContext(handler, mapper))
server := httptest.NewServer(handler)
defer server.Close()
resp, err := http.Get(server.URL + "/apis/" + tc.group)

View File

@ -42,8 +42,6 @@ import (
// proxyHandler provides a http.Handler which will proxy traffic to locations
// specified by items implementing Redirector.
type proxyHandler struct {
contextMapper genericapirequest.RequestContextMapper
// localDelegate is used to satisfy local APIServices
localDelegate http.Handler
@ -104,12 +102,7 @@ func (r *proxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
ctx, ok := r.contextMapper.Get(req)
if !ok {
http.Error(w, "missing context", http.StatusInternalServerError)
return
}
user, ok := genericapirequest.UserFrom(ctx)
user, ok := genericapirequest.UserFrom(req.Context())
if !ok {
http.Error(w, "missing user", http.StatusInternalServerError)
return

View File

@ -57,30 +57,23 @@ func (d *targetHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
type fakeRequestContextMapper struct {
user user.Info
}
func (m *fakeRequestContextMapper) Get(req *http.Request) (genericapirequest.Context, bool) {
ctx := genericapirequest.NewContext()
if m.user != nil {
ctx = genericapirequest.WithUser(ctx, m.user)
}
resolver := &genericapirequest.RequestInfoFactory{
APIPrefixes: sets.NewString("api", "apis"),
GrouplessAPIPrefixes: sets.NewString("api"),
}
info, err := resolver.NewRequestInfo(req)
if err == nil {
ctx = genericapirequest.WithRequestInfo(ctx, info)
}
return ctx, true
}
func (*fakeRequestContextMapper) Update(req *http.Request, context genericapirequest.Context) error {
return nil
func contextHandler(handler http.Handler, user user.Info) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if user != nil {
ctx = genericapirequest.WithUser(ctx, user)
}
resolver := &genericapirequest.RequestInfoFactory{
APIPrefixes: sets.NewString("api", "apis"),
GrouplessAPIPrefixes: sets.NewString("api"),
}
info, err := resolver.NewRequestInfo(req)
if err == nil {
ctx = genericapirequest.WithRequestInfo(ctx, info)
}
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
type mockedRouter struct {
@ -280,8 +273,7 @@ func TestProxyHandler(t *testing.T) {
serviceResolver: serviceResolver,
proxyTransport: &http.Transport{},
}
handler.contextMapper = &fakeRequestContextMapper{user: tc.user}
server := httptest.NewServer(handler)
server := httptest.NewServer(contextHandler(handler, tc.user))
defer server.Close()
if tc.apiService != nil {
@ -417,12 +409,11 @@ func TestProxyUpgrade(t *testing.T) {
serverURL, _ := url.Parse(backendServer.URL)
proxyHandler := &proxyHandler{
contextMapper: &fakeRequestContextMapper{user: &user.DefaultInfo{Name: "username"}},
serviceResolver: &mockedRouter{destinationHost: serverURL.Host},
proxyTransport: &http.Transport{},
}
proxyHandler.updateAPIService(tc.APIService)
aggregator := httptest.NewServer(proxyHandler)
aggregator := httptest.NewServer(contextHandler(proxyHandler, &user.DefaultInfo{Name: "username"}))
defer aggregator.Close()
ws, err := websocket.Dial("ws://"+aggregator.Listener.Addr().String()+path, "", "http://127.0.0.1/")

View File

@ -34,7 +34,6 @@ go_test(
deps = [
"//vendor/github.com/go-openapi/spec:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
"//vendor/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//vendor/k8s.io/kube-aggregator/pkg/apis/apiregistration:go_default_library",
],
)

View File

@ -25,7 +25,6 @@ import (
"github.com/go-openapi/spec"
"github.com/stretchr/testify/assert"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/kube-aggregator/pkg/apis/apiregistration"
)
@ -138,7 +137,7 @@ func assertDownloadedSpec(actualSpec *spec.Swagger, actualEtag string, err error
func TestDownloadOpenAPISpec(t *testing.T) {
s := Downloader{contextMapper: request.NewRequestContextMapper()}
s := Downloader{}
// Test with no eTag
actualSpec, actualEtag, _, err := s.Download(handlerTest{data: []byte("{\"id\": \"test\"}")}, "")

View File

@ -31,12 +31,11 @@ import (
// Downloader is the OpenAPI downloader type. It will try to download spec from /swagger.json endpoint.
type Downloader struct {
contextMapper request.RequestContextMapper
}
// NewDownloader creates a new OpenAPI Downloader.
func NewDownloader(contextMapper request.RequestContextMapper) Downloader {
return Downloader{contextMapper}
func NewDownloader() Downloader {
return Downloader{}
}
// inMemoryResponseWriter is a http.Writer that keep the response in memory.
@ -81,9 +80,7 @@ func (r *inMemoryResponseWriter) String() string {
func (s *Downloader) handlerWithUser(handler http.Handler, info user.Info) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if ctx, ok := s.contextMapper.Get(req); ok {
s.contextMapper.Update(req, request.WithUser(ctx, info))
}
req = req.WithContext(request.WithUser(req.Context(), info))
handler.ServeHTTP(w, req)
})
}
@ -96,7 +93,6 @@ func etagFor(data []byte) string {
// httpStatus is only valid if err == nil
func (s *Downloader) Download(handler http.Handler, etag string) (returnSpec *spec.Swagger, newEtag string, httpStatus int, err error) {
handler = s.handlerWithUser(handler, &user.DefaultInfo{Name: aggregatorUser})
handler = request.WithRequestContext(handler, s.contextMapper)
handler = http.TimeoutHandler(handler, specDownloadTimeout, "request timed out")
req, err := http.NewRequest("GET", "/openapi/v2", nil)