package propertyoverride import ( "fmt" "strings" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/wrapperspb" ) // PatchStruct patches the given ProtoMessage according to the documented behavior of Patch and returns the result. // If an error is returned, the returned K (which may be a partially modified copy) should be ignored. func PatchStruct[K proto.Message](k K, patch Patch, debug bool) (result K, e error) { // Don't panic due to misconfiguration. defer func() { if err := recover(); err != nil { e = fmt.Errorf("unexpected panic: %v", err) } }() protoM := k.ProtoReflect() if patch.Path == "" || patch.Path == "/" { // Return error with all possible fields at root of target. return k, fmt.Errorf("non-empty, non-root Path is required;\n%s", fieldListStr(protoM.Descriptor(), debug)) } parsedPath := parsePath(patch.Path) targetM, fieldDesc, err := findTargetMessageAndField(protoM, parsedPath, patch, debug) if err != nil { return k, err } if err := patchField(targetM, fieldDesc, patch, debug); err != nil { return k, err } // ProtoReflect returns a reflective _view_ of the underlying K, so // we don't need to convert back explictly when we return k here. return k, nil } // parsePath returns the path tokens from a JSON Pointer (https://datatracker.ietf.org/doc/html/rfc6901/) string as // expected by JSON Patch (e.g. "/foo/bar/baz"). func parsePath(path string) []string { return strings.Split(strings.TrimLeft(path, "/"), "/") } // findTargetMessageAndField takes a root-level protoreflect.Message and slice of path elements and returns the // protoreflect.FieldDescriptor that corresponds to the full path, along with the parent protoreflect.Message of // that field. parsedPath must be non-empty. // If any field in parsedPath cannot be matched while traversing the tree of messages, an error is returned. func findTargetMessageAndField(m protoreflect.Message, parsedPath []string, patch Patch, debug bool) (protoreflect.Message, protoreflect.FieldDescriptor, error) { if len(parsedPath) == 0 { return nil, nil, fmt.Errorf("unexpected error: non-empty path is required") } // Iterate until we've matched the (potentially nested) target field. for { fieldName := parsedPath[0] if fieldName == "" { return nil, nil, fmt.Errorf("empty field name in path") } fieldDesc, err := childFieldDescriptor(m.Descriptor(), fieldName, debug) if err != nil { return nil, nil, err } parsedPath = parsedPath[1:] if len(parsedPath) == 0 { // We've reached the end of the path, return current message and target field. return m, fieldDesc, nil } // Check whether we have a non-terminal (parent) field in the path for which we // don't support child lookup. switch { case fieldDesc.IsList(): return nil, nil, fmt.Errorf("path contains member of repeated field '%s'; repeated field member access is not supported", fieldName) case fieldDesc.IsMap(): return nil, nil, fmt.Errorf("path contains member of map field '%s'; map field member access is not supported", fieldName) } fieldM := m.Get(fieldDesc).Message() if !fieldM.IsValid() && patch.Op == OpAdd { // Init this message field to a valid, empty value so that we can keep walking // the path and set inner fields. Only do this for add, as remove should not // initialize fields on its own. m.Set(fieldDesc, protoreflect.ValueOfMessage(fieldM.New())) fieldM = m.Get(fieldDesc).Message() } // Advance our parent message "pointer" to the next field. m = fieldM } } // patchField applies the given patch op to the target field on given parent message. func patchField(parentM protoreflect.Message, fieldDesc protoreflect.FieldDescriptor, patch Patch, debug bool) error { switch patch.Op { case OpAdd: return applyAdd(parentM, fieldDesc, patch, debug) case OpRemove: // Ignore Value if provided, per JSON Patch: "Members that are not explicitly defined for the // operation in question MUST be ignored (i.e., the operation will complete as if the undefined // member did not appear in the object)." return applyRemove(parentM, fieldDesc) } return fmt.Errorf("unexpected error: no op implementation found") } // removeField clears the target field on the given message. func applyRemove(parentM protoreflect.Message, fieldDesc protoreflect.FieldDescriptor) error { // Check whether the parent has this field, as clearing a field on an unset parent may panic. if parentM.Has(fieldDesc) { parentM.Clear(fieldDesc) } return nil } // applyAdd updates the target field(s) on the given message based on the content of patch. // If the patch value is a scalar, scalar wrapper, or scalar array, we set the target field directly with that value. // If the patch value is a map, we set the indicated child fields on a new (empty) message matching the target field. // Regardless, the target field is replaced entirely. This conforms to the PUT-style semantics of the JSON Patch "add" // operation for objects (https://www.rfc-editor.org/rfc/rfc6902#section-4.1). func applyAdd(parentM protoreflect.Message, fieldDesc protoreflect.FieldDescriptor, patch Patch, debug bool) error { if patch.Value == nil { return fmt.Errorf("non-nil Value is required; use an empty map to reset all fields on a message or the 'remove' op to unset fields") } mapValue, isMapValue := patch.Value.(map[string]interface{}) // If the field is a proto map type, we'll treat it as a "single" field for error handling purposes. // If we support proto map targets in the future, it will still likely be treated as a single field, // similar to a list (repeated field). This map handling is specific to _our_ patch semantics for // updating multiple message fields at once. if isMapValue && !fieldDesc.IsMap() { // Get a fresh copy of the target field's message, then set the children indicated by the patch. fieldM := parentM.Get(fieldDesc).Message().New() for k, v := range mapValue { targetFieldDesc, err := childFieldDescriptor(fieldDesc.Message(), k, debug) if err != nil { return err } val, err := toProtoValue(fieldM, targetFieldDesc, v) if err != nil { return err } fieldM.Set(targetFieldDesc, val) } parentM.Set(fieldDesc, protoreflect.ValueOf(fieldM)) } else { // Just set the field directly, as our patch value is not a map. val, err := toProtoValue(parentM, fieldDesc, patch.Value) if err != nil { return err } parentM.Set(fieldDesc, val) } return nil } func childFieldDescriptor(parentDesc protoreflect.MessageDescriptor, fieldName string, debug bool) (protoreflect.FieldDescriptor, error) { if childFieldDesc := parentDesc.Fields().ByName(protoreflect.Name(fieldName)); childFieldDesc != nil { return childFieldDesc, nil } return nil, fmt.Errorf("no match for field '%s'!\n%s", fieldName, fieldListStr(parentDesc, debug)) } // fieldListStr prints all possible fields (debug) or the first 10 fields of a given MessageDescriptor. func fieldListStr(messageDesc protoreflect.MessageDescriptor, debug bool) string { // Future: it might be nice to use something like https://github.com/agnivade/levenshtein // or https://github.com/schollz/closestmatch to inform these choices. fields := messageDesc.Fields() // If we're not in debug mode and there's > 10 possible fields, only print the first 10. printCount := fields.Len() truncateFields := false if !debug && printCount > 10 { truncateFields = true printCount = 10 } msg := strings.Builder{} msg.WriteString("available ") msg.WriteString(string(messageDesc.FullName())) msg.WriteString(" fields:\n") for i := 0; i < printCount; i++ { msg.WriteString(fields.Get(i).TextName()) msg.WriteString("\n") } if truncateFields { msg.WriteString("First 10 fields for this message included, configure with `Debug = true` to print all.") } return msg.String() } func toProtoValue(parentM protoreflect.Message, fieldDesc protoreflect.FieldDescriptor, patchValue interface{}) (v protoreflect.Value, e error) { // Repeated fields. Check for these first, so we can use Kind below for single value // fields (repeated fields have a Kind corresponding to their members). // We have to do special handling for int types and strings since they could be enums. if fieldDesc.IsList() { list := parentM.NewField(fieldDesc).List() switch val := patchValue.(type) { case []int: return toProtoIntOrEnumList(val, list, fieldDesc) case []int32: return toProtoIntOrEnumList(val, list, fieldDesc) case []int64: return toProtoIntOrEnumList(val, list, fieldDesc) case []uint: return toProtoIntOrEnumList(val, list, fieldDesc) case []uint32: return toProtoIntOrEnumList(val, list, fieldDesc) case []uint64: return toProtoIntOrEnumList(val, list, fieldDesc) case []float32: return toProtoNumericList(val, list, fieldDesc) case []float64: return toProtoNumericList(val, list, fieldDesc) case []bool: return toProtoList(val, list) case []string: if fieldDesc.Kind() == protoreflect.EnumKind { return toProtoEnumList(val, list, fieldDesc) } return toProtoList(val, list) default: if fieldDesc.Kind() == protoreflect.MessageKind || fieldDesc.Kind() == protoreflect.GroupKind || fieldDesc.Kind() == protoreflect.BytesKind { return unsupportedTargetTypeErr(fieldDesc) } return typeMismatchErr(fieldDesc, val) } } switch fieldDesc.Kind() { case protoreflect.MessageKind: // google.protobuf wrapper types are used for detecting presence of scalars. If the // target field is a message, and the patch value is not a map, we assume the user // is targeting a wrapper type or has misconfigured the path. return toProtoWrapperValue(fieldDesc, patchValue) case protoreflect.EnumKind: return toProtoEnumValue(fieldDesc, patchValue) case protoreflect.Int32Kind, protoreflect.Int64Kind, protoreflect.Sint32Kind, protoreflect.Sint64Kind, protoreflect.Uint32Kind, protoreflect.Uint64Kind, protoreflect.Fixed32Kind, protoreflect.Fixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind, protoreflect.FloatKind, protoreflect.DoubleKind: // We have to be careful specifically obtaining the correct proto field type here, // since conversion checking by protoreflect is stringent and will not accept e.g. // int->uint32 mismatches, or float->int downcasts. switch val := patchValue.(type) { case int: return toProtoNumericValue(fieldDesc, val) case int32: return toProtoNumericValue(fieldDesc, val) case int64: return toProtoNumericValue(fieldDesc, val) case uint: return toProtoNumericValue(fieldDesc, val) case uint32: return toProtoNumericValue(fieldDesc, val) case uint64: return toProtoNumericValue(fieldDesc, val) case float32: return toProtoNumericValue(fieldDesc, val) case float64: return toProtoNumericValue(fieldDesc, val) } } // Fall back to protoreflect.ValueOf, which may panic if an unexpected type is passed. defer func() { if err := recover(); err != nil { _, e = typeMismatchErr(fieldDesc, patchValue) } }() return protoreflect.ValueOf(patchValue), nil } func toProtoList[V float32 | float64 | bool | string](vs []V, l protoreflect.List) (protoreflect.Value, error) { for _, v := range vs { l.Append(protoreflect.ValueOf(v)) } return protoreflect.ValueOfList(l), nil } // toProtoIntOrEnumList takes a slice of some integer type V and returns a protoreflect.Value List of either the // corresponding proto integer type or enum values, depending on the type of the target field. func toProtoIntOrEnumList[V int | int32 | int64 | uint | uint32 | uint64](vs []V, l protoreflect.List, fieldDesc protoreflect.FieldDescriptor) (protoreflect.Value, error) { if fieldDesc.Kind() == protoreflect.EnumKind { return toProtoEnumList(vs, l, fieldDesc) } return toProtoNumericList(vs, l, fieldDesc) } // toProtoNumericList takes a slice of some numeric type V and returns a protoreflect.Value List of the corresponding // proto integer or float values, depending on the type of the target field. func toProtoNumericList[V int | int32 | int64 | uint | uint32 | uint64 | float32 | float64](vs []V, l protoreflect.List, fieldDesc protoreflect.FieldDescriptor) (protoreflect.Value, error) { for _, v := range vs { i, err := toProtoNumericValue(fieldDesc, v) if err != nil { return protoreflect.Value{}, err } l.Append(i) } return protoreflect.ValueOfList(l), nil } func toProtoEnumList[V int | int32 | int64 | uint | uint32 | uint64 | string](vs []V, l protoreflect.List, fieldDesc protoreflect.FieldDescriptor) (protoreflect.Value, error) { for _, v := range vs { e, err := toProtoEnumValue(fieldDesc, v) if err != nil { return protoreflect.Value{}, err } l.Append(e) } return protoreflect.ValueOfList(l), nil } // toProtoNumericValue aids converting from a Go numeric type to a specific protoreflect.Value type, casting to match // the target field type. // // It supports converting numeric Go slices to List values, and scalar numeric values, in a protoreflect // validation-compatible manner by replacing the internal conversion logic of protoreflect.ValueOf with a more lenient // "blind" cast, with the assumption that inputs to structpatcher may have been deserialized in such a way as to make // strict type checking infeasible, even if the actual value and target type are compatible. This function does _not_ // attempt to handle overflows, only type conversion. // // See https://protobuf.dev/programming-guides/proto3/#scalar for canonical proto3 type mappings, reflected here. func toProtoNumericValue[V int | int32 | int64 | uint | uint32 | uint64 | float32 | float64](fieldDesc protoreflect.FieldDescriptor, v V) (protoreflect.Value, error) { switch fieldDesc.Kind() { case protoreflect.FloatKind: return protoreflect.ValueOfFloat32(float32(v)), nil case protoreflect.DoubleKind: return protoreflect.ValueOfFloat64(float64(v)), nil case protoreflect.Int32Kind: return protoreflect.ValueOfInt32(int32(v)), nil case protoreflect.Int64Kind: return protoreflect.ValueOfInt64(int64(v)), nil case protoreflect.Uint32Kind: return protoreflect.ValueOfUint32(uint32(v)), nil case protoreflect.Uint64Kind: return protoreflect.ValueOfUint64(uint64(v)), nil case protoreflect.Sint32Kind: return protoreflect.ValueOfInt32(int32(v)), nil case protoreflect.Sint64Kind: return protoreflect.ValueOfInt64(int64(v)), nil case protoreflect.Fixed32Kind: return protoreflect.ValueOfUint32(uint32(v)), nil case protoreflect.Fixed64Kind: return protoreflect.ValueOfUint64(uint64(v)), nil case protoreflect.Sfixed32Kind: return protoreflect.ValueOfInt32(int32(v)), nil case protoreflect.Sfixed64Kind: return protoreflect.ValueOfInt64(int64(v)), nil default: // Fall back to protoreflect.ValueOf, which may panic if an unexpected type is passed. return protoreflect.ValueOf(v), nil } } func toProtoEnumValue(fieldDesc protoreflect.FieldDescriptor, patchValue any) (protoreflect.Value, error) { // For enums, we accept the field number or a string representation of the name // (both supported by protojson). EnumNumber is a type alias for int32, but we // may be dealing with a 64-bit number in Go depending on how it was obtained. var enumValue protoreflect.EnumValueDescriptor switch val := patchValue.(type) { case string: enumValue = fieldDesc.Enum().Values().ByName(protoreflect.Name(val)) case int: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) case int32: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) case int64: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) case uint: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) case uint32: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) case uint64: enumValue = fieldDesc.Enum().Values().ByNumber(protoreflect.EnumNumber(val)) } if enumValue != nil { return protoreflect.ValueOfEnum(enumValue.Number()), nil } return typeMismatchErr(fieldDesc, patchValue) } // toProtoWrapperValue converts possible runtime Go types (per https://protobuf.dev/programming-guides/proto3/#scalar) // to the appropriate proto wrapper target type based on the name of the given FieldDescriptor. // // This function does not attempt to handle overflows, only type conversion. func toProtoWrapperValue(fieldDesc protoreflect.FieldDescriptor, patchValue any) (protoreflect.Value, error) { fullName := string(fieldDesc.Message().FullName()) if !strings.HasPrefix(fullName, "google.protobuf.") || !strings.HasSuffix(fullName, "Value") { return unsupportedTargetTypeErr(fieldDesc) } switch val := patchValue.(type) { case int: return toProtoIntWrapperValue(fieldDesc, val) case int32: return toProtoIntWrapperValue(fieldDesc, val) case int64: return toProtoIntWrapperValue(fieldDesc, val) case uint: return toProtoIntWrapperValue(fieldDesc, val) case uint32: return toProtoIntWrapperValue(fieldDesc, val) case uint64: return toProtoIntWrapperValue(fieldDesc, val) case float32: switch fieldDesc.Message().FullName() { case "google.protobuf.FloatValue": v := wrapperspb.Float(val) return protoreflect.ValueOf(v.ProtoReflect()), nil case "google.protobuf.DoubleValue": v := wrapperspb.Double(float64(val)) return protoreflect.ValueOf(v.ProtoReflect()), nil default: // Fall back to int wrapper, since we may actually be targeting this instead. // Failure will result in a typical type mismatch error. return toProtoIntWrapperValue(fieldDesc, val) } case float64: switch fieldDesc.Message().FullName() { case "google.protobuf.FloatValue": v := wrapperspb.Float(float32(val)) return protoreflect.ValueOf(v.ProtoReflect()), nil case "google.protobuf.DoubleValue": v := wrapperspb.Double(val) return protoreflect.ValueOf(v.ProtoReflect()), nil default: // Fall back to int wrapper, since we may actually be targeting this instead. // Failure will result in a typical type mismatch error. return toProtoIntWrapperValue(fieldDesc, val) } case bool: switch fieldDesc.Message().FullName() { case "google.protobuf.BoolValue": v := wrapperspb.Bool(val) return protoreflect.ValueOf(v.ProtoReflect()), nil } case string: switch fieldDesc.Message().FullName() { case "google.protobuf.StringValue": v := wrapperspb.String(val) return protoreflect.ValueOf(v.ProtoReflect()), nil } } return typeMismatchErr(fieldDesc, patchValue) } func toProtoIntWrapperValue[V int | int32 | int64 | uint | uint32 | uint64 | float32 | float64](fieldDesc protoreflect.FieldDescriptor, v V) (protoreflect.Value, error) { switch fieldDesc.Message().FullName() { case "google.protobuf.UInt32Value": v := wrapperspb.UInt32(uint32(v)) return protoreflect.ValueOf(v.ProtoReflect()), nil case "google.protobuf.UInt64Value": v := wrapperspb.UInt64(uint64(v)) return protoreflect.ValueOf(v.ProtoReflect()), nil case "google.protobuf.Int32Value": v := wrapperspb.Int32(int32(v)) return protoreflect.ValueOf(v.ProtoReflect()), nil case "google.protobuf.Int64Value": v := wrapperspb.Int64(int64(v)) return protoreflect.ValueOf(v.ProtoReflect()), nil default: return typeMismatchErr(fieldDesc, v) } } func typeMismatchErr(fieldDesc protoreflect.FieldDescriptor, v interface{}) (protoreflect.Value, error) { return protoreflect.Value{}, fmt.Errorf("patch value type %T could not be applied to target field type '%s'", v, fieldType(fieldDesc)) } func unsupportedTargetTypeErr(fieldDesc protoreflect.FieldDescriptor) (protoreflect.Value, error) { return protoreflect.Value{}, fmt.Errorf("unsupported target field type '%s'", fieldType(fieldDesc)) } func fieldType(fieldDesc protoreflect.FieldDescriptor) string { v := fieldDesc.Kind().String() // Scalars have a useful Kind string, but complex fields should have their full name for clarity. if fieldDesc.Kind() == protoreflect.MessageKind { v = string(fieldDesc.Message().FullName()) } if fieldDesc.IsList() { v = "repeated " + v // Kind reflects the type of repeated elements in repeated fields. } if fieldDesc.IsMap() { v = "map" // maps are difficult to get types from, but we don't support them, so return a simple value for now. } return v }