diff --git a/pkg/capabilities/BUILD b/pkg/capabilities/BUILD index b94c142770..84d296ecf6 100644 --- a/pkg/capabilities/BUILD +++ b/pkg/capabilities/BUILD @@ -5,6 +5,7 @@ licenses(["notice"]) load( "@io_bazel_rules_go//go:def.bzl", "go_library", + "go_test", ) go_library( @@ -16,6 +17,13 @@ go_library( tags = ["automanaged"], ) +go_test( + name = "go_default_test", + srcs = ["capabilities_test.go"], + library = ":go_default_library", + tags = ["automanaged"], +) + filegroup( name = "package-srcs", srcs = glob(["**"]), diff --git a/pkg/capabilities/capabilities.go b/pkg/capabilities/capabilities.go index be721a7855..0da7b9c8b5 100644 --- a/pkg/capabilities/capabilities.go +++ b/pkg/capabilities/capabilities.go @@ -46,16 +46,17 @@ type PrivilegedSources struct { HostIPCSources []string } -// TODO: Clean these up into a singleton -var once sync.Once -var lock sync.Mutex -var capabilities *Capabilities +var capInstance struct { + once sync.Once + lock sync.Mutex + capabilities *Capabilities +} // Initialize the capability set. This can only be done once per binary, subsequent calls are ignored. func Initialize(c Capabilities) { // Only do this once - once.Do(func() { - capabilities = &c + capInstance.once.Do(func() { + capInstance.capabilities = &c }) } @@ -70,17 +71,17 @@ func Setup(allowPrivileged bool, privilegedSources PrivilegedSources, perConnect // SetForTests sets capabilities for tests. Convenience method for testing. This should only be called from tests. func SetForTests(c Capabilities) { - lock.Lock() - defer lock.Unlock() - capabilities = &c + capInstance.lock.Lock() + defer capInstance.lock.Unlock() + capInstance.capabilities = &c } // Returns a read-only copy of the system capabilities. func Get() Capabilities { - lock.Lock() - defer lock.Unlock() + capInstance.lock.Lock() + defer capInstance.lock.Unlock() // This check prevents clobbering of capabilities that might've been set via SetForTests - if capabilities == nil { + if capInstance.capabilities == nil { Initialize(Capabilities{ AllowPrivileged: false, PrivilegedSources: PrivilegedSources{ @@ -90,5 +91,5 @@ func Get() Capabilities { }, }) } - return *capabilities + return *capInstance.capabilities } diff --git a/pkg/capabilities/capabilities_test.go b/pkg/capabilities/capabilities_test.go new file mode 100644 index 0000000000..ea4434d061 --- /dev/null +++ b/pkg/capabilities/capabilities_test.go @@ -0,0 +1,50 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package capabilities + +import ( + "reflect" + "testing" +) + +func TestGet(t *testing.T) { + defaultCap := Capabilities{ + AllowPrivileged: false, + PrivilegedSources: PrivilegedSources{ + HostNetworkSources: []string{}, + HostPIDSources: []string{}, + HostIPCSources: []string{}, + }, + } + + res := Get() + if !reflect.DeepEqual(defaultCap, res) { + t.Fatalf("expected Capabilities: %#v, got a non-default: %#v", defaultCap, res) + } + + cap := Capabilities{ + PrivilegedSources: PrivilegedSources{ + HostNetworkSources: []string{"A", "B"}, + }, + } + SetForTests(cap) + + res = Get() + if !reflect.DeepEqual(cap, res) { + t.Fatalf("expected Capabilities: %#v , got a different: %#v", cap, res) + } +}