Merge pull request #5949 from thockin/deepequal_len

Fix deep equal to not panic on bad slice lengths
pull/6/head
Daniel Smith 2015-03-26 15:28:47 -07:00
commit 557da59302
2 changed files with 101 additions and 7 deletions

View File

@ -174,6 +174,8 @@ func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool,
switch v1.Kind() { switch v1.Kind() {
case reflect.Array: case reflect.Array:
// We don't need to check length here because length is part of
// an array's type, which has already been filtered for.
for i := 0; i < v1.Len(); i++ { for i := 0; i < v1.Len(); i++ {
if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
return false return false
@ -187,6 +189,9 @@ func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool,
if v1.IsNil() || v1.Len() == 0 { if v1.IsNil() || v1.Len() == 0 {
return true return true
} }
if v1.Len() != v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() { if v1.Pointer() == v2.Pointer() {
return true return true
} }
@ -217,6 +222,9 @@ func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool,
if v1.IsNil() || v1.Len() == 0 { if v1.IsNil() || v1.Len() == 0 {
return true return true
} }
if v1.Len() != v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() { if v1.Pointer() == v2.Pointer() {
return true return true
} }
@ -309,6 +317,8 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
switch v1.Kind() { switch v1.Kind() {
case reflect.Array: case reflect.Array:
// We don't need to check length here because length is part of
// an array's type, which has already been filtered for.
for i := 0; i < v1.Len(); i++ { for i := 0; i < v1.Len(); i++ {
if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) { if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
return false return false
@ -316,12 +326,12 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
} }
return true return true
case reflect.Slice: case reflect.Slice:
if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
return false
}
if v1.IsNil() || v1.Len() == 0 { if v1.IsNil() || v1.Len() == 0 {
return true return true
} }
if v1.Len() > v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() { if v1.Pointer() == v2.Pointer() {
return true return true
} }
@ -335,6 +345,9 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
if v1.Len() == 0 { if v1.Len() == 0 {
return true return true
} }
if v1.Len() > v2.Len() {
return false
}
return v1.String() == v2.String() return v1.String() == v2.String()
case reflect.Interface: case reflect.Interface:
if v1.IsNil() { if v1.IsNil() {
@ -354,12 +367,12 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
} }
return true return true
case reflect.Map: case reflect.Map:
if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
return false
}
if v1.IsNil() || v1.Len() == 0 { if v1.IsNil() || v1.Len() == 0 {
return true return true
} }
if v1.Len() > v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() { if v1.Pointer() == v2.Pointer() {
return true return true
} }

View File

@ -50,14 +50,31 @@ func TestEqualities(t *testing.T) {
}{ }{
{1, 2, true}, {1, 2, true},
{2, 1, false}, {2, 1, false},
{"foo", "fo", false},
{"foo", "foo", true}, {"foo", "foo", true},
{"foo", "foobar", false},
{Foo{1}, Foo{2}, true}, {Foo{1}, Foo{2}, true},
{Foo{2}, Foo{1}, false},
{Bar{1}, Bar{10}, true}, {Bar{1}, Bar{10}, true},
{&Bar{1}, &Bar{10}, true}, {&Bar{1}, &Bar{10}, true},
{Baz{Bar{1}}, Baz{Bar{10}}, true}, {Baz{Bar{1}}, Baz{Bar{10}}, true},
{[...]string{}, [...]string{"1", "2", "3"}, false},
{[...]string{"1"}, [...]string{"1", "2", "3"}, false},
{[...]string{"1", "2", "3"}, [...]string{}, false},
{[...]string{"1", "2", "3"}, [...]string{"1", "2", "3"}, true},
{map[string]int{"foo": 1}, map[string]int{}, false},
{map[string]int{"foo": 1}, map[string]int{"foo": 2}, true}, {map[string]int{"foo": 1}, map[string]int{"foo": 2}, true},
{map[string]int{"foo": 2}, map[string]int{"foo": 1}, false},
{map[string]int{"foo": 1}, map[string]int{"foo": 2, "bar": 6}, false},
{map[string]int{"foo": 1, "bar": 6}, map[string]int{"foo": 2}, false},
{map[string]int{}, map[string]int(nil), true}, {map[string]int{}, map[string]int(nil), true},
{[]int{}, []int(nil), true}, {[]string(nil), []string(nil), true},
{[]string{}, []string(nil), true},
{[]string(nil), []string{}, true},
{[]string{"1"}, []string(nil), false},
{[]string{}, []string{"1", "2", "3"}, false},
{[]string{"1"}, []string{"1", "2", "3"}, false},
{[]string{"1", "2", "3"}, []string{}, false},
} }
for _, item := range table { for _, item := range table {
@ -66,3 +83,67 @@ func TestEqualities(t *testing.T) {
} }
} }
} }
func TestDerivates(t *testing.T) {
e := Equalities{}
type Bar struct {
X int
}
type Baz struct {
Y Bar
}
err := e.AddFuncs(
func(a, b int) bool {
return a+1 == b
},
func(a, b Bar) bool {
return a.X*10 == b.X
},
)
if err != nil {
t.Fatalf("Unexpected: %v", err)
}
type Foo struct {
X int
}
table := []struct {
a, b interface{}
equal bool
}{
{1, 2, true},
{2, 1, false},
{"foo", "fo", false},
{"foo", "foo", true},
{"foo", "foobar", false},
{Foo{1}, Foo{2}, true},
{Foo{2}, Foo{1}, false},
{Bar{1}, Bar{10}, true},
{&Bar{1}, &Bar{10}, true},
{Baz{Bar{1}}, Baz{Bar{10}}, true},
{[...]string{}, [...]string{"1", "2", "3"}, false},
{[...]string{"1"}, [...]string{"1", "2", "3"}, false},
{[...]string{"1", "2", "3"}, [...]string{}, false},
{[...]string{"1", "2", "3"}, [...]string{"1", "2", "3"}, true},
{map[string]int{"foo": 1}, map[string]int{}, false},
{map[string]int{"foo": 1}, map[string]int{"foo": 2}, true},
{map[string]int{"foo": 2}, map[string]int{"foo": 1}, false},
{map[string]int{"foo": 1}, map[string]int{"foo": 2, "bar": 6}, true},
{map[string]int{"foo": 1, "bar": 6}, map[string]int{"foo": 2}, false},
{map[string]int{}, map[string]int(nil), true},
{[]string(nil), []string(nil), true},
{[]string{}, []string(nil), true},
{[]string(nil), []string{}, true},
{[]string{"1"}, []string(nil), false},
{[]string{}, []string{"1", "2", "3"}, true},
{[]string{"1"}, []string{"1", "2", "3"}, true},
{[]string{"1", "2", "3"}, []string{}, false},
}
for _, item := range table {
if e, a := item.equal, e.DeepDerivative(item.a, item.b); e != a {
t.Errorf("Expected (%+v ~ %+v) == %v, but got %v", item.a, item.b, e, a)
}
}
}