Tests pass now.

pull/6/head
Daniel Smith 2014-06-12 16:46:07 -07:00
parent 601f6bb4ad
commit 2abfd95d6b
2 changed files with 26 additions and 12 deletions

View File

@ -24,18 +24,32 @@ func ToWireFormat(data []byte, storage string) ([]byte, error) {
return nil, fmt.Errorf("unknown storage type: %v", storage)
}
// Try parsing as json and yaml
parsed_json := reflect.New(prototypeType).Interface()
json_err := json.Unmarshal(data, parsed_json)
parsed_yaml := reflect.New(prototypeType).Interface()
yaml_err := yaml.Unmarshal(data, parsed_yaml)
// Try parsing json
json_out, json_err := tryJSON(data, reflect.New(prototypeType).Interface())
if json_err == nil {
return json_out, json_err
}
if json_err != nil && yaml_err != nil {
// Try parsing yaml
yaml_out, yaml_err := tryYAML(data, reflect.New(prototypeType).Interface())
if yaml_err != nil {
return nil, fmt.Errorf("Could not parse input as json (error: %v) or yaml (error: %v", json_err, yaml_err)
}
if json_err != nil {
return json.Marshal(parsed_json)
}
return json.Marshal(parsed_yaml)
return yaml_out, yaml_err
}
func tryJSON(data []byte, obj interface{}) ([]byte, error) {
err := json.Unmarshal(data, obj)
if err != nil {
return nil, err
}
return json.Marshal(obj)
}
func tryYAML(data []byte, obj interface{}) ([]byte, error) {
err := yaml.Unmarshal(data, obj)
if err != nil {
return nil, err
}
return json.Marshal(obj)
}

View File

@ -18,6 +18,7 @@ func TestParseBadStorage(t *testing.T) {
func DoParseTest(t *testing.T, storage string, obj interface{}) {
json_data, _ := json.Marshal(obj)
yaml_data, _ := yaml.Marshal(obj)
t.Logf("Intermediate yaml:\n%v\n", string(yaml_data))
json_got, json_err := ToWireFormat(json_data, storage)
yaml_got, yaml_err := ToWireFormat(yaml_data, storage)
@ -28,7 +29,6 @@ func DoParseTest(t *testing.T, storage string, obj interface{}) {
if yaml_err != nil {
t.Errorf("yaml err: %#v", yaml_err)
}
t.Logf("Intermediate yaml:\n%v\n", string(yaml_data))
if string(json_got) != string(json_data) {
t.Errorf("json output didn't match:\nGot:\n%v\n\nWanted:\n%v\n",
string(json_got), string(json_data))