Merge pull request #2916 from lavalamp/fix3

Fix taking addr() of map value/key
pull/6/head
Clayton Coleman 2014-12-12 20:12:25 -05:00
commit 1edff23935
2 changed files with 82 additions and 10 deletions

View File

@ -300,7 +300,7 @@ func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, met
if err != nil {
return err
}
if !dv.CanAddr() {
if !dv.CanAddr() && !dv.CanSet() {
return fmt.Errorf("can't write to dest")
}
sv, err := EnforcePtr(src)
@ -318,6 +318,35 @@ func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, met
return c.convert(sv, dv, s)
}
// callCustom calls 'custom' with sv & dv. custom must be a conversion function.
func (c *Converter) callCustom(sv, dv, custom reflect.Value, scope *scope) error {
if !sv.CanAddr() {
sv2 := reflect.New(sv.Type())
sv2.Elem().Set(sv)
sv = sv2
} else {
sv = sv.Addr()
}
if !dv.CanAddr() {
if !dv.CanSet() {
return scope.error("can't addr or set dest.")
}
dvOrig := dv
dv := reflect.New(dvOrig.Type())
defer func() { dvOrig.Set(dv) }()
} else {
dv = dv.Addr()
}
args := []reflect.Value{sv, dv, reflect.ValueOf(scope)}
ret := custom.Call(args)[0].Interface()
// This convolution is necessary because nil interfaces won't convert
// to errors.
if ret == nil {
return nil
}
return ret.(error)
}
// convert recursively copies sv into dv, calling an appropriate conversion function if
// one is registered.
func (c *Converter) convert(sv, dv reflect.Value, scope *scope) error {
@ -326,14 +355,7 @@ func (c *Converter) convert(sv, dv reflect.Value, scope *scope) error {
if c.Debug != nil {
c.Debug.Logf("Calling custom conversion of '%v' to '%v'", st, dt)
}
args := []reflect.Value{sv.Addr(), dv.Addr(), reflect.ValueOf(scope)}
ret := fv.Call(args)[0].Interface()
// This convolution is necessary because nil interfaces won't convert
// to errors.
if ret == nil {
return nil
}
return ret.(error)
return c.callCustom(sv, dv, fv, scope)
}
if !scope.flags.IsSet(AllowDifferentFieldTypeNames) && c.NameFunc(dt) != c.NameFunc(st) {

View File

@ -19,6 +19,7 @@ package conversion
import (
"fmt"
"reflect"
"strconv"
"testing"
"github.com/google/gofuzz"
@ -134,6 +135,52 @@ func TestConverter_fuzz(t *testing.T) {
}
}
func TestConverter_MapElemAddr(t *testing.T) {
type Foo struct {
A map[int]int
}
type Bar struct {
A map[string]string
}
c := NewConverter()
c.Debug = t
err := c.Register(
func(in *int, out *string, s Scope) error {
*out = fmt.Sprintf("%v", *in)
return nil
},
)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = c.Register(
func(in *string, out *int, s Scope) error {
if str, err := strconv.Atoi(*in); err != nil {
return err
} else {
*out = str
return nil
}
},
)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
f := fuzz.New().NilChance(0).NumElements(3, 3)
first := Foo{}
second := Bar{}
f.Fuzz(&first)
err = c.Convert(&first, &second, AllowDifferentFieldTypeNames, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
third := Foo{}
err = c.Convert(&second, &third, AllowDifferentFieldTypeNames, nil)
if e, a := first, third; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected diff: %v", objDiff(e, a))
}
}
func TestConverter_tags(t *testing.T) {
type Foo struct {
A string `test:"foo"`
@ -157,7 +204,10 @@ func TestConverter_tags(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c.Convert(&Foo{}, &Bar{}, 0, nil)
err = c.Convert(&Foo{}, &Bar{}, AllowDifferentFieldTypeNames, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
func TestConverter_meta(t *testing.T) {