k3s/pkg/conversion/deep_equal.go

236 lines
6.0 KiB
Go

/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package conversion
import (
"fmt"
"reflect"
)
// Equalities is a map from type to a function comparing two values of
// that type.
//
// The key is the parameter type of the comparison function and the value is
// the function itself wrapped in a reflect.Value, so that it can be invoked
// through reflection with two arguments of the key type. Entries are added
// with AddFunc, which validates the function's signature before storing it.
type Equalities map[reflect.Type]reflect.Value
// EqualitiesOrDie builds a fresh Equalities from the given equality
// functions. It is a convenience wrapper around AddFuncs that panics with
// the returned error if any of the functions is rejected.
func EqualitiesOrDie(funcs ...interface{}) Equalities {
	eq := make(Equalities)
	err := eq.AddFuncs(funcs...)
	if err != nil {
		panic(err)
	}
	return eq
}
// AddFuncs registers every element of funcs, in order, by calling AddFunc.
// It stops at the first function that AddFunc rejects and returns that
// error; functions registered before the failure stay registered.
func (e Equalities) AddFuncs(funcs ...interface{}) error {
	for i := range funcs {
		err := e.AddFunc(funcs[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// AddFunc registers eqFunc as the equality function for its parameter type.
//
// eqFunc must be a func that takes exactly two parameters of the same type
// and returns a single bool, i.e. func(a, b T) bool. If it does not, an
// error describing the first violated requirement is returned and nothing
// is registered. Registering a second function for the same parameter type
// replaces the earlier one.
func (e Equalities) AddFunc(eqFunc interface{}) error {
	fv := reflect.ValueOf(eqFunc)
	// reflect.ValueOf(nil) yields the zero Value, whose Type method panics;
	// report it as an ordinary validation error instead.
	if !fv.IsValid() {
		return fmt.Errorf("expected func, got: %v", eqFunc)
	}
	ft := fv.Type()
	if ft.Kind() != reflect.Func {
		return fmt.Errorf("expected func, got: %v", ft)
	}
	if ft.NumIn() != 2 {
		return fmt.Errorf("expected two 'in' params, got: %v", ft)
	}
	if ft.NumOut() != 1 {
		return fmt.Errorf("expected one 'out' param, got: %v", ft)
	}
	if ft.In(0) != ft.In(1) {
		return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
	}
	if ft.Out(0) != reflect.TypeOf(false) {
		return fmt.Errorf("expected bool return, got: %v", ft)
	}
	// Index by the (shared) parameter type so lookups can use Value.Type().
	e[ft.In(0)] = fv
	return nil
}
// Equal reports whether a and b are equal.
//
// Two nil interfaces are equal, and a nil interface never equals a non-nil
// one. Values whose dynamic types differ are never equal. If an equality
// function is registered for the shared type, it decides the result;
// otherwise the values are compared with Go's == operator, which panics at
// run time when the shared type is not comparable (for example a slice or
// a map).
func (e Equalities) Equal(a, b interface{}) bool {
	av, bv := reflect.ValueOf(a), reflect.ValueOf(b)
	// A nil interface produces an invalid Value; handle it before Type().
	if !av.IsValid() || !bv.IsValid() {
		return av.IsValid() == bv.IsValid()
	}
	if av.Type() != bv.Type() {
		return false
	}
	if fv, ok := e[av.Type()]; ok {
		return fv.Call([]reflect.Value{av, bv})[0].Bool()
	}
	return av.Interface() == bv.Interface()
}
// Below here is forked from go's reflect/deepequal.go
// visit identifies one comparison that deepValueEqual has already started.
//
// During deepValueEqual, we must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
type visit struct {
// a1 is the lower of the two addresses under comparison; deepValueEqual
// orders the pair before building a visit so (x, y) and (y, x) share a key.
a1 uintptr
// a2 is the higher of the two addresses under comparison.
a2 uintptr
// typ is the type shared by the two values under comparison.
typ reflect.Type
}
// deepValueEqual reports whether v1 and v2 are semantically equal, walking
// both values in lockstep through reflection.
//
// Rules, in order of precedence:
//   - Two invalid values are equal; an invalid value never equals a valid one.
//   - Values of different types are never equal.
//   - If e holds an equality function for the type, its result is final and
//     the values are not descended into.
//   - A nil slice equals an empty slice, and a nil map equals an empty map.
//     Non-empty slices and maps must have the same length and equal elements.
//   - Funcs are equal only when both are nil.
//   - Any other kind is compared with == on the interface values. When the
//     values cannot be turned into interfaces (reflect refuses this for
//     values reached through unexported struct fields) they are treated as
//     equal as long as both sides are unreadable.
//
// visited records the (address, address, type) triples whose comparison is
// already in progress so that recursive data structures terminate; a triple
// that is seen again is assumed to be equal. depth is the current recursion
// depth; it is threaded through the recursive calls but not consulted.
func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
	if !v1.IsValid() || !v2.IsValid() {
		return v1.IsValid() == v2.IsValid()
	}
	if v1.Type() != v2.Type() {
		return false
	}
	if fv, ok := e[v1.Type()]; ok {
		return fv.Call([]reflect.Value{v1, v2})[0].Bool()
	}

	// Only composite kinds can take part in a reference cycle, so only they
	// are worth tracking in visited.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
			return true
		}
		return false
	}

	if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
		addr1 := v1.UnsafeAddr()
		addr2 := v2.UnsafeAddr()
		if addr1 > addr2 {
			// Canonicalize order to reduce number of entries in visited.
			addr1, addr2 = addr2, addr1
		}

		// Short circuit if references are identical ...
		if addr1 == addr2 {
			return true
		}

		// ... or already seen
		typ := v1.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			return true
		}

		// Remember for later.
		visited[v] = true
	}

	switch v1.Kind() {
	case reflect.Array:
		// Types are identical, so both arrays have the same length.
		for i := 0; i < v1.Len(); i++ {
			if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Slice:
		// Nil and empty slices are interchangeable.
		if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
			return false
		}
		if v1.IsNil() || v1.Len() == 0 {
			return true
		}
		// Must precede both the pointer shortcut (two slices can share a
		// start address yet differ in length) and the element loop (which
		// would otherwise index v2 out of range or ignore its tail).
		if v1.Len() != v2.Len() {
			return false
		}
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for i := 0; i < v1.Len(); i++ {
			if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Interface:
		if v1.IsNil() || v2.IsNil() {
			return v1.IsNil() == v2.IsNil()
		}
		return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Ptr:
		// Elem of a nil pointer is the invalid Value, which the validity
		// check at the top of the recursive call handles.
		return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Struct:
		for i, n := 0, v1.NumField(); i < n; i++ {
			if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Map:
		// Nil and empty maps are interchangeable.
		if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
			return false
		}
		if v1.IsNil() || v1.Len() == 0 {
			return true
		}
		// Without this check a map whose keys are a strict subset of the
		// other's would compare equal, since only v1's keys are visited.
		if v1.Len() != v2.Len() {
			return false
		}
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for _, k := range v1.MapKeys() {
			// A key missing from v2 yields the invalid Value, which never
			// equals the valid value found in v1.
			if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Func:
		if v1.IsNil() && v2.IsNil() {
			return true
		}
		// Can't do better than this:
		return false
	default:
		// Normal equality suffices
		if v1.CanInterface() && v2.CanInterface() {
			return v1.Interface() == v2.Interface()
		}
		return v1.CanInterface() == v2.CanInterface()
	}
}
// DeepEqual compares a1 and a2 the way reflect.DeepEqual does, except that
// it aims at semantic equality rather than memory equality.
//
// Whenever a value's type has an equality function registered in e, that
// function is used instead of a structural comparison.
//
// For our purposes an empty slice *is* equal to a nil slice, and an empty
// map *is* equal to a nil map.
//
// Unexported field members are not compared.
func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
	switch {
	case a1 == nil, a2 == nil:
		// With at least one nil interface, they match only if both are nil.
		return a1 == a2
	}
	left, right := reflect.ValueOf(a1), reflect.ValueOf(a2)
	if left.Type() == right.Type() {
		return e.deepValueEqual(left, right, map[visit]bool{}, 0)
	}
	return false
}