diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go index cfc19d0935..147dd5453b 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go @@ -43,6 +43,7 @@ type Generator struct { Conditional string Clean bool OnlyIDL bool + KeepGogoproto bool SkipGeneratedRewrite bool DropEmbeddedFields string } @@ -77,6 +78,7 @@ func (g *Generator) BindFlags(flag *flag.FlagSet) { flag.StringVar(&g.Conditional, "conditional", g.Conditional, "An optional Golang build tag condition to add to the generated Go code") flag.BoolVar(&g.Clean, "clean", g.Clean, "If true, remove all generated files for the specified Packages.") flag.BoolVar(&g.OnlyIDL, "only-idl", g.OnlyIDL, "If true, only generate the IDL for each package.") + flag.BoolVar(&g.KeepGogoproto, "keep-gogoproto", g.KeepGogoproto, "If true, the generated IDL will contain gogoprotobuf extensions which are normally removed") flag.BoolVar(&g.SkipGeneratedRewrite, "skip-generated-rewrite", g.SkipGeneratedRewrite, "If true, skip fixing up the generated.pb.go file (debugging only).") flag.StringVar(&g.DropEmbeddedFields, "drop-embedded-fields", g.DropEmbeddedFields, "Comma-delimited list of embedded Go types to omit from generated protobufs") } @@ -206,8 +208,11 @@ func Run(g *Generator) { for _, outputPackage := range outputPackages { p := outputPackage.(*protobufPackage) + path := filepath.Join(g.OutputBase, p.ImportPath()) outputPath := filepath.Join(g.OutputBase, p.OutputPath()) + + // generate the gogoprotobuf protoc cmd := exec.Command("protoc", append(args, path)...) out, err := cmd.CombinedOutput() if len(out) > 0 { @@ -217,29 +222,74 @@ func Run(g *Generator) { log.Println(strings.Join(cmd.Args, " ")) log.Fatalf("Unable to generate protoc on %s: %v", p.PackageName, err) } - if !g.SkipGeneratedRewrite { - if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.HasGoType, buf.Bytes()); err != nil { - log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) - } - cmd := exec.Command("goimports", "-w", outputPath) - out, err := cmd.CombinedOutput() - if len(out) > 0 { - log.Printf(string(out)) - } - if err != nil { - log.Println(strings.Join(cmd.Args, " ")) - log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err) - } + if g.SkipGeneratedRewrite { + continue + } - cmd = exec.Command("gofmt", "-s", "-w", outputPath) - out, err = cmd.CombinedOutput() - if len(out) > 0 { - log.Printf(string(out)) + // alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the + // package statement to match the desired package name + if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.ExtractGeneratedType, buf.Bytes()); err != nil { + log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) + } + + // sort imports + cmd = exec.Command("goimports", "-w", outputPath) + out, err = cmd.CombinedOutput() + if len(out) > 0 { + log.Printf(string(out)) + } + if err != nil { + log.Println(strings.Join(cmd.Args, " ")) + log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err) + } + + // format and simplify the generated file + cmd = exec.Command("gofmt", "-s", "-w", outputPath) + out, err = cmd.CombinedOutput() + if len(out) > 0 { + log.Printf(string(out)) + } + if err != nil { + log.Println(strings.Join(cmd.Args, " ")) + log.Fatalf("Unable to apply gofmt for %s: %v", p.PackageName, err) + } + } + + if g.SkipGeneratedRewrite { + return + } + + if !g.KeepGogoproto { + // generate, but do so without gogoprotobuf extensions + for _, outputPackage := range outputPackages { + p := outputPackage.(*protobufPackage) + p.OmitGogo = true + } + if err := c.ExecutePackages(g.OutputBase, outputPackages); err != nil { + log.Fatalf("Failed executing generator: %v", err) + } + } + + for _, outputPackage := range outputPackages { + p := outputPackage.(*protobufPackage) + + if len(p.StructTags) == 0 { + continue + } + + pattern := filepath.Join(g.OutputBase, p.PackagePath, "*.go") + files, err := filepath.Glob(pattern) + if err != nil { + log.Fatalf("Can't glob pattern %q: %v", pattern, err) + } + + for _, s := range files { + if strings.HasSuffix(s, "_test.go") { + continue } - if err != nil { - log.Println(strings.Join(cmd.Args, " ")) - log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err) + if err := RewriteTypesWithProtobufStructTags(s, p.StructTags); err != nil { + log.Fatalf("Unable to rewrite with struct tags %s: %v", s, err) } } } diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go index f63307e444..f75efba700 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go @@ -42,10 +42,16 @@ type genProtoIDL struct { imports *ImportTracker generateAll bool + omitGogo bool omitFieldTypes map[types.Name]struct{} } func (g *genProtoIDL) PackageVars(c *generator.Context) []string { + if g.omitGogo { + return []string{ + fmt.Sprintf("option go_package = %q;", g.localGoPackage.Name), + } + } return []string{ "option (gogoproto.marshaler_all) = true;", "option (gogoproto.sizer_all) = true;", @@ -117,7 +123,15 @@ func isProtoable(seen map[*types.Type]bool, t *types.Type) bool { } func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) { - return g.imports.ImportLines() + lines := []string{} + // TODO: this could be expressed more cleanly + for _, line := range g.imports.ImportLines() { + if g.omitGogo && line == "github.com/gogo/protobuf/gogoproto/gogo.proto" { + continue + } + lines = append(lines, line) + } + return lines } // GenerateType makes the body of a file implementing a set for type t. @@ -130,7 +144,9 @@ func (g *genProtoIDL) GenerateType(c *generator.Context, t *types.Type, w io.Wri localGoPackage: g.localGoPackage.Package, }, - localPackage: g.localPackage, + localPackage: g.localPackage, + + omitGogo: g.omitGogo, omitFieldTypes: g.omitFieldTypes, t: t, @@ -201,6 +217,7 @@ func (p protobufLocator) ProtoTypeFor(t *types.Type) (*types.Type, error) { type bodyGen struct { locator ProtobufLocator localPackage types.Name + omitGogo bool omitFieldTypes map[types.Name]struct{} t *types.Type @@ -228,14 +245,18 @@ func (b bodyGen) doStruct(sw *generator.SnippetWriter) error { switch key { case "marshal": if v == "false" { - options = append(options, - "(gogoproto.marshaler) = false", - "(gogoproto.unmarshaler) = false", - "(gogoproto.sizer) = false", - ) + if !b.omitGogo { + options = append(options, + "(gogoproto.marshaler) = false", + "(gogoproto.unmarshaler) = false", + "(gogoproto.sizer) = false", + ) + } } default: - options = append(options, fmt.Sprintf("%s = %s", key, v)) + if !b.omitGogo || !strings.HasPrefix(key, "(gogoproto.") { + options = append(options, fmt.Sprintf("%s = %s", key, v)) + } } case k == "protobuf.embed": fields = []protoField{ @@ -289,14 +310,19 @@ func (b bodyGen) doStruct(sw *generator.SnippetWriter) error { } sw.Do(`$.Type|local$ $.Name$ = $.Tag$`, field) if len(field.Extras) > 0 { - fmt.Fprintf(out, " [") extras := []string{} for k, v := range field.Extras { + if b.omitGogo && strings.HasPrefix(k, "(gogoproto.") { + continue + } extras = append(extras, fmt.Sprintf("%s = %s", k, v)) } sort.Sort(sort.StringSlice(extras)) - fmt.Fprint(out, strings.Join(extras, ", ")) - fmt.Fprintf(out, "]") + if len(extras) > 0 { + fmt.Fprintf(out, " [") + fmt.Fprint(out, strings.Join(extras, ", ")) + fmt.Fprintf(out, "]") + } } fmt.Fprintf(out, ";\n") if i != len(fields)-1 { @@ -459,24 +485,19 @@ func protobufTagToField(tag string, field *protoField, m types.Member, t *types. Kind: typesKindProtobuf, } } else { - field.Type = &types.Type{ - Name: types.Name{ - Name: parts[0], - Package: localPackage.Package, - Path: localPackage.Path, - }, - Kind: typesKindProtobuf, + switch parts[0] { + case "varint", "bytes", "fixed64": + default: + field.Type = &types.Type{ + Name: types.Name{ + Name: parts[0], + Package: localPackage.Package, + Path: localPackage.Path, + }, + Kind: typesKindProtobuf, + } } } - switch parts[2] { - case "rep": - field.Repeated = true - case "opt": - field.Optional = true - case "req": - default: - return fmt.Errorf("member %q of %q malformed 'protobuf' tag, field mode is %q not recognized\n", m.Name, t.Name, parts[2]) - } field.OptionalSet = true protoExtra := make(map[string]string) @@ -485,7 +506,11 @@ func protobufTagToField(tag string, field *protoField, m types.Member, t *types. if len(parts) != 2 { return fmt.Errorf("member %q of %q malformed 'protobuf' tag, tag %d should be key=value, got %q\n", m.Name, t.Name, i+4, extra) } - protoExtra[parts[0]] = parts[1] + switch parts[0] { + case "casttype": + parts[0] = fmt.Sprintf("(gogoproto.%s)", parts[0]) + protoExtra[parts[0]] = parts[1] + } } field.Extras = protoExtra @@ -526,7 +551,7 @@ func membersToFields(locator ProtobufLocator, t *types.Type, localPackage types. if len(field.Name) == 0 && len(parts[0]) != 0 { field.Name = parts[0] } - if field.Name == "-" { + if field.Tag == -1 && field.Name == "-" { continue } } diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go index 6d9111c5a6..1fae676416 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go @@ -18,8 +18,13 @@ package protobuf import ( "fmt" + "log" "os" "path/filepath" + "reflect" + "strings" + + "k8s.io/kubernetes/third_party/golang/go/ast" "k8s.io/kubernetes/cmd/libs/go2idl/generator" "k8s.io/kubernetes/cmd/libs/go2idl/types" @@ -68,12 +73,18 @@ type protobufPackage struct { // A list of types to filter to; if not specified all types will be included. FilterTypes map[types.Name]struct{} + // If true, omit any gogoprotobuf extensions not defined as types. + OmitGogo bool + // A list of field types that will be excluded from the output struct OmitFieldTypes map[types.Name]struct{} // A list of names that this package exports LocalNames map[string]struct{} + // A list of struct tags to generate onto named struct fields + StructTags map[string]map[string]string + // An import tracker for this package Imports *ImportTracker } @@ -127,6 +138,43 @@ func (p *protobufPackage) HasGoType(name string) bool { return ok } +func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool { + if !p.HasGoType(t.Name.Name) { + return false + } + + switch s := t.Type.(type) { + case *ast.StructType: + for i, f := range s.Fields.List { + if len(f.Tag.Value) == 0 { + continue + } + tag := strings.Trim(f.Tag.Value, "`") + protobufTag := reflect.StructTag(tag).Get("protobuf") + if len(protobufTag) == 0 { + continue + } + if len(f.Names) > 1 { + log.Printf("WARNING: struct %s field %d %s: defined multiple names but single protobuf tag", t.Name.Name, i, f.Names[0].Name) + // TODO hard error? + } + if p.StructTags == nil { + p.StructTags = make(map[string]map[string]string) + } + m := p.StructTags[t.Name.Name] + if m == nil { + m = make(map[string]string) + p.StructTags[t.Name.Name] = m + } + m[f.Names[0].Name] = tag + } + default: + log.Printf("WARNING: unexpected Go AST type definition: %#v", t) + } + + return true +} + func (p *protobufPackage) Generators(c *generator.Context) []generator.Generator { generators := []generator.Generator{} @@ -140,6 +188,7 @@ func (p *protobufPackage) Generators(c *generator.Context) []generator.Generator localGoPackage: types.Name{Package: p.PackagePath, Name: p.GoPackageName()}, imports: p.Imports, generateAll: p.GenerateAll, + omitGogo: p.OmitGogo, omitFieldTypes: p.OmitFieldTypes, }) return generators diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go index 9a1cbfbddc..29c6a7f1c9 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go @@ -18,17 +18,26 @@ package protobuf import ( "bytes" + "errors" + "fmt" "go/format" "io/ioutil" "os" + "reflect" + "strings" "k8s.io/kubernetes/third_party/golang/go/ast" "k8s.io/kubernetes/third_party/golang/go/parser" "k8s.io/kubernetes/third_party/golang/go/printer" "k8s.io/kubernetes/third_party/golang/go/token" + customreflect "k8s.io/kubernetes/third_party/golang/reflect" ) -func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExistsFn func(string) bool, header []byte) error { +// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be +// removed from the destination file. +type ExtractFunc func(*ast.TypeSpec) bool + +func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn ExtractFunc, header []byte) error { fset := token.NewFileSet() src, err := ioutil.ReadFile(name) if err != nil { @@ -43,7 +52,7 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExist // remove types that are already declared decls := []ast.Decl{} for _, d := range file.Decls { - if !dropExistingTypeDeclarations(d, typeExistsFn) { + if !dropExistingTypeDeclarations(d, extractFn) { decls = append(decls, d) } } @@ -74,7 +83,7 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExist return f.Close() } -func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) bool { +func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { switch t := decl.(type) { case *ast.GenDecl: if t.Tok != token.TYPE { @@ -84,7 +93,7 @@ func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) boo for _, s := range t.Specs { switch spec := s.(type) { case *ast.TypeSpec: - if existsFn(spec.Name.Name) { + if extractFn(spec) { continue } specs = append(specs, spec) @@ -97,3 +106,128 @@ func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) boo } return false } + +func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error { + fset := token.NewFileSet() + src, err := ioutil.ReadFile(name) + if err != nil { + return err + } + file, err := parser.ParseFile(fset, name, src, parser.DeclarationErrors|parser.ParseComments) + if err != nil { + return err + } + + allErrs := []error{} + + // set any new struct tags + for _, d := range file.Decls { + if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 { + allErrs = append(allErrs, errs...) + } + } + + if len(allErrs) > 0 { + var s string + for _, err := range allErrs { + s += err.Error() + "\n" + } + return errors.New(s) + } + + b := &bytes.Buffer{} + if err := printer.Fprint(b, fset, file); err != nil { + return err + } + + body, err := format.Source(b.Bytes()) + if err != nil { + return fmt.Errorf("%s\n---\nunable to format %q: %v", b, name, err) + } + + f, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + if _, err := f.Write(body); err != nil { + return err + } + return f.Close() +} + +func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error { + var errs []error + t, ok := decl.(*ast.GenDecl) + if !ok { + return nil + } + if t.Tok != token.TYPE { + return nil + } + + for _, s := range t.Specs { + spec, ok := s.(*ast.TypeSpec) + if !ok { + continue + } + typeName := spec.Name.Name + fieldTags, ok := structTags[typeName] + if !ok { + continue + } + st, ok := spec.Type.(*ast.StructType) + if !ok { + continue + } + + for i := range st.Fields.List { + f := st.Fields.List[i] + var name string + if len(f.Names) == 0 { + switch t := f.Type.(type) { + case *ast.Ident: + name = t.Name + case *ast.SelectorExpr: + name = t.Sel.Name + default: + errs = append(errs, fmt.Errorf("unable to get name for tag from struct %q, field %#v", spec.Name.Name, t)) + continue + } + } else { + name = f.Names[0].Name + } + value, ok := fieldTags[name] + if !ok { + continue + } + var tags customreflect.StructTags + if f.Tag != nil { + oldTags, err := customreflect.ParseStructTags(strings.Trim(f.Tag.Value, "`")) + if err != nil { + errs = append(errs, fmt.Errorf("unable to read struct tag from struct %q, field %q: %v", spec.Name.Name, name, err)) + continue + } + tags = oldTags + } + for _, name := range toCopy { + // don't overwrite existing tags + if tags.Has(name) { + continue + } + // append new tags + if v := reflect.StructTag(value).Get(name); len(v) > 0 { + tags = append(tags, customreflect.StructTag{Name: name, Value: v}) + } + } + if len(tags) == 0 { + continue + } + if f.Tag == nil { + f.Tag = &ast.BasicLit{} + } + f.Tag.Value = tags.String() + } + } + return errs +} diff --git a/hack/verify-flags/known-flags.txt b/hack/verify-flags/known-flags.txt index 315b7ff69e..f45fd31ae7 100644 --- a/hack/verify-flags/known-flags.txt +++ b/hack/verify-flags/known-flags.txt @@ -153,6 +153,7 @@ ir-user jenkins-host jenkins-jobs k8s-build-output +keep-gogoproto km-path kube-api-burst kube-api-qps diff --git a/third_party/golang/reflect/type.go b/third_party/golang/reflect/type.go new file mode 100644 index 0000000000..67957ee33e --- /dev/null +++ b/third_party/golang/reflect/type.go @@ -0,0 +1,91 @@ +//This package is copied from Go library reflect/type.go. +//The struct tag library provides no way to extract the list of struct tags, only +//a specific tag +package reflect + +import ( + "fmt" + + "strconv" + "strings" +) + +type StructTag struct { + Name string + Value string +} + +func (t StructTag) String() string { + return fmt.Sprintf("%s:%q", t.Name, t.Value) +} + +type StructTags []StructTag + +func (tags StructTags) String() string { + s := make([]string, 0, len(tags)) + for _, tag := range tags { + s = append(s, tag.String()) + } + return "`" + strings.Join(s, " ") + "`" +} + +func (tags StructTags) Has(name string) bool { + for i := range tags { + if tags[i].Name == name { + return true + } + } + return false +} + +// ParseStructTags returns the full set of fields in a struct tag in the order they appear in +// the struct tag. +func ParseStructTags(tag string) (StructTags, error) { + tags := StructTags{} + for tag != "" { + // Skip leading space. + i := 0 + for i < len(tag) && tag[i] == ' ' { + i++ + } + tag = tag[i:] + if tag == "" { + break + } + + // Scan to colon. A space, a quote or a control character is a syntax error. + // Strictly speaking, control chars include the range [0x7f, 0x9f], not just + // [0x00, 0x1f], but in practice, we ignore the multi-byte control characters + // as it is simpler to inspect the tag's bytes than the tag's runes. + i = 0 + for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { + i++ + } + if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { + break + } + name := string(tag[:i]) + tag = tag[i+1:] + + // Scan quoted string to find value. + i = 1 + for i < len(tag) && tag[i] != '"' { + if tag[i] == '\\' { + i++ + } + i++ + } + if i >= len(tag) { + break + } + qvalue := string(tag[:i+1]) + tag = tag[i+1:] + + value, err := strconv.Unquote(qvalue) + if err != nil { + return nil, err + } + tags = append(tags, StructTag{Name: name, Value: value}) + } + return tags, nil +}