From f3148e996513fa9634393f5b1ee68b0ad9202491 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Fri, 12 Dec 2014 15:26:17 -0800 Subject: [PATCH] Fix taking addr() of map value/key --- pkg/conversion/converter.go | 40 ++++++++++++++++++------ pkg/conversion/converter_test.go | 52 +++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/pkg/conversion/converter.go b/pkg/conversion/converter.go index 519ab179ed..9b2f6a2a1f 100644 --- a/pkg/conversion/converter.go +++ b/pkg/conversion/converter.go @@ -300,7 +300,7 @@ func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, met if err != nil { return err } - if !dv.CanAddr() { + if !dv.CanAddr() && !dv.CanSet() { return fmt.Errorf("can't write to dest") } sv, err := EnforcePtr(src) @@ -318,6 +318,35 @@ func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, met return c.convert(sv, dv, s) } +// callCustom calls 'custom' with sv & dv. custom must be a conversion function. +func (c *Converter) callCustom(sv, dv, custom reflect.Value, scope *scope) error { + if !sv.CanAddr() { + sv2 := reflect.New(sv.Type()) + sv2.Elem().Set(sv) + sv = sv2 + } else { + sv = sv.Addr() + } + if !dv.CanAddr() { + if !dv.CanSet() { + return scope.error("can't addr or set dest.") + } + dvOrig := dv + dv := reflect.New(dvOrig.Type()) + defer func() { dvOrig.Set(dv) }() + } else { + dv = dv.Addr() + } + args := []reflect.Value{sv, dv, reflect.ValueOf(scope)} + ret := custom.Call(args)[0].Interface() + // This convolution is necessary because nil interfaces won't convert + // to errors. + if ret == nil { + return nil + } + return ret.(error) +} + // convert recursively copies sv into dv, calling an appropriate conversion function if // one is registered. func (c *Converter) convert(sv, dv reflect.Value, scope *scope) error { @@ -326,14 +355,7 @@ func (c *Converter) convert(sv, dv reflect.Value, scope *scope) error { if c.Debug != nil { c.Debug.Logf("Calling custom conversion of '%v' to '%v'", st, dt) } - args := []reflect.Value{sv.Addr(), dv.Addr(), reflect.ValueOf(scope)} - ret := fv.Call(args)[0].Interface() - // This convolution is necessary because nil interfaces won't convert - // to errors. - if ret == nil { - return nil - } - return ret.(error) + return c.callCustom(sv, dv, fv, scope) } if !scope.flags.IsSet(AllowDifferentFieldTypeNames) && c.NameFunc(dt) != c.NameFunc(st) { diff --git a/pkg/conversion/converter_test.go b/pkg/conversion/converter_test.go index 8fd45a66c3..a05a15dfa5 100644 --- a/pkg/conversion/converter_test.go +++ b/pkg/conversion/converter_test.go @@ -19,6 +19,7 @@ package conversion import ( "fmt" "reflect" + "strconv" "testing" "github.com/google/gofuzz" @@ -134,6 +135,52 @@ func TestConverter_fuzz(t *testing.T) { } } +func TestConverter_MapElemAddr(t *testing.T) { + type Foo struct { + A map[int]int + } + type Bar struct { + A map[string]string + } + c := NewConverter() + c.Debug = t + err := c.Register( + func(in *int, out *string, s Scope) error { + *out = fmt.Sprintf("%v", *in) + return nil + }, + ) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + err = c.Register( + func(in *string, out *int, s Scope) error { + if str, err := strconv.Atoi(*in); err != nil { + return err + } else { + *out = str + return nil + } + }, + ) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + f := fuzz.New().NilChance(0).NumElements(3, 3) + first := Foo{} + second := Bar{} + f.Fuzz(&first) + err = c.Convert(&first, &second, AllowDifferentFieldTypeNames, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + third := Foo{} + err = c.Convert(&second, &third, AllowDifferentFieldTypeNames, nil) + if e, a := first, third; !reflect.DeepEqual(e, a) { + t.Errorf("Unexpected diff: %v", objDiff(e, a)) + } +} + func TestConverter_tags(t *testing.T) { type Foo struct { A string `test:"foo"` @@ -157,7 +204,10 @@ func TestConverter_tags(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c.Convert(&Foo{}, &Bar{}, 0, nil) + err = c.Convert(&Foo{}, &Bar{}, AllowDifferentFieldTypeNames, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } } func TestConverter_meta(t *testing.T) {