Make unexported fields panic (informatively)

...Also fix some incorrect calls to semantic.DeepEqual, and a bug where
it returned true incorrectly.
pull/6/head
Daniel Smith 2015-03-05 16:52:43 -08:00
parent 2902028476
commit 3ef3777192
4 changed files with 44 additions and 10 deletions

View File

@ -60,6 +60,9 @@ var Semantic = conversion.EqualitiesOrDie(
}
return a.Amount.Cmp(b.Amount) == 0
},
func(a, b util.Time) bool {
return a.Unix() == b.Unix()
},
)
var standardResources = util.NewStringSet(

View File

@ -27,7 +27,7 @@ import (
"net/http/httptest"
"net/url"
"os"
// "reflect"
"reflect"
"strings"
"testing"
"time"
@ -64,7 +64,7 @@ func TestRequestWithErrorWontChange(t *testing.T) {
if changed != &r {
t.Errorf("returned request should point to the same object")
}
if !api.Semantic.DeepDerivative(changed, &original) {
if !reflect.DeepEqual(changed, &original) {
t.Errorf("expected %#v, got %#v", &original, changed)
}
}

View File

@ -19,6 +19,7 @@ package conversion
import (
"fmt"
"reflect"
"strings"
)
// Equalities is a map from type to a function comparing two values of
@ -99,10 +100,36 @@ type visit struct {
typ reflect.Type
}
// unexportedTypePanic is thrown when you use this DeepEqual on something that has an
// unexported type. It indicates a programmer error, so should not occur at runtime,
// which is why it's not public and thus impossible to catch.
type unexportedTypePanic []reflect.Type
func (u unexportedTypePanic) Error() string { return u.String() }
func (u unexportedTypePanic) String() string {
strs := make([]string, len(u))
for i, t := range u {
strs[i] = fmt.Sprintf("%v", t)
}
return "an unexported field was encountered, nested like this: " + strings.Join(strs, " -> ")
}
func makeUsefulPanic(v reflect.Value) {
if x := recover(); x != nil {
if u, ok := x.(unexportedTypePanic); ok {
u = append(unexportedTypePanic{v.Type()}, u...)
x = u
}
panic(x)
}
}
// Tests for deep equality using reflected types. The map argument tracks
// comparisons that have already been seen, which allows short circuiting on
// recursive types.
func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
defer makeUsefulPanic(v1)
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
@ -207,10 +234,10 @@ func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool,
return false
default:
// Normal equality suffices
if v1.CanInterface() && v2.CanInterface() {
return v1.Interface() == v2.Interface()
if !v1.CanInterface() || !v2.CanInterface() {
panic(unexportedTypePanic{})
}
return v1.CanInterface() == v2.CanInterface()
return v1.Interface() == v2.Interface()
}
}
@ -221,7 +248,8 @@ func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool,
//
// An empty slice *is* equal to a nil slice for our purposes; same for maps.
//
// Unexported field members are not compared.
// Unexported field members cannot be compared and will cause an imformative panic; you must add an Equality
// function for these types.
func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
if a1 == nil || a2 == nil {
return a1 == a2
@ -235,6 +263,8 @@ func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
}
func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
defer makeUsefulPanic(v1)
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
@ -347,10 +377,10 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
return false
default:
// Normal equality suffices
if v1.CanInterface() && v2.CanInterface() {
return v1.Interface() == v2.Interface()
if !v1.CanInterface() || !v2.CanInterface() {
panic(unexportedTypePanic{})
}
return v1.CanInterface() == v2.CanInterface()
return v1.Interface() == v2.Interface()
}
}

View File

@ -20,6 +20,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"reflect"
"strings"
"testing"
@ -378,7 +379,7 @@ func TestFillCurrentState(t *testing.T) {
if controller.Status.Replicas != 2 {
t.Errorf("expected 2, got: %d", controller.Status.Replicas)
}
if !api.Semantic.DeepEqual(fakeLister.s, labels.Set(controller.Spec.Selector).AsSelector()) {
if !reflect.DeepEqual(fakeLister.s, labels.Set(controller.Spec.Selector).AsSelector()) {
t.Errorf("unexpected output: %#v %#v", labels.Set(controller.Spec.Selector).AsSelector(), fakeLister.s)
}
}