Merge pull request #19426 from smarterclayton/rewrite_tags

Auto commit by PR queue bot
pull/6/head
k8s-merge-robot 2016-01-26 13:11:58 -08:00
commit 3254df3a7c
6 changed files with 403 additions and 53 deletions

View File

@ -43,6 +43,7 @@ type Generator struct {
Conditional string
Clean bool
OnlyIDL bool
KeepGogoproto bool
SkipGeneratedRewrite bool
DropEmbeddedFields string
}
@ -77,6 +78,7 @@ func (g *Generator) BindFlags(flag *flag.FlagSet) {
flag.StringVar(&g.Conditional, "conditional", g.Conditional, "An optional Golang build tag condition to add to the generated Go code")
flag.BoolVar(&g.Clean, "clean", g.Clean, "If true, remove all generated files for the specified Packages.")
flag.BoolVar(&g.OnlyIDL, "only-idl", g.OnlyIDL, "If true, only generate the IDL for each package.")
flag.BoolVar(&g.KeepGogoproto, "keep-gogoproto", g.KeepGogoproto, "If true, the generated IDL will contain gogoprotobuf extensions which are normally removed")
flag.BoolVar(&g.SkipGeneratedRewrite, "skip-generated-rewrite", g.SkipGeneratedRewrite, "If true, skip fixing up the generated.pb.go file (debugging only).")
flag.StringVar(&g.DropEmbeddedFields, "drop-embedded-fields", g.DropEmbeddedFields, "Comma-delimited list of embedded Go types to omit from generated protobufs")
}
@ -206,8 +208,11 @@ func Run(g *Generator) {
for _, outputPackage := range outputPackages {
p := outputPackage.(*protobufPackage)
path := filepath.Join(g.OutputBase, p.ImportPath())
outputPath := filepath.Join(g.OutputBase, p.OutputPath())
// generate the gogoprotobuf protoc
cmd := exec.Command("protoc", append(args, path)...)
out, err := cmd.CombinedOutput()
if len(out) > 0 {
@ -217,29 +222,74 @@ func Run(g *Generator) {
log.Println(strings.Join(cmd.Args, " "))
log.Fatalf("Unable to generate protoc on %s: %v", p.PackageName, err)
}
if !g.SkipGeneratedRewrite {
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.HasGoType, buf.Bytes()); err != nil {
log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err)
}
cmd := exec.Command("goimports", "-w", outputPath)
out, err := cmd.CombinedOutput()
if len(out) > 0 {
log.Printf(string(out))
}
if err != nil {
log.Println(strings.Join(cmd.Args, " "))
log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err)
}
if g.SkipGeneratedRewrite {
continue
}
cmd = exec.Command("gofmt", "-s", "-w", outputPath)
out, err = cmd.CombinedOutput()
if len(out) > 0 {
log.Printf(string(out))
// alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the
// package statement to match the desired package name
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.ExtractGeneratedType, buf.Bytes()); err != nil {
log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err)
}
// sort imports
cmd = exec.Command("goimports", "-w", outputPath)
out, err = cmd.CombinedOutput()
if len(out) > 0 {
log.Printf(string(out))
}
if err != nil {
log.Println(strings.Join(cmd.Args, " "))
log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err)
}
// format and simplify the generated file
cmd = exec.Command("gofmt", "-s", "-w", outputPath)
out, err = cmd.CombinedOutput()
if len(out) > 0 {
log.Printf(string(out))
}
if err != nil {
log.Println(strings.Join(cmd.Args, " "))
log.Fatalf("Unable to apply gofmt for %s: %v", p.PackageName, err)
}
}
if g.SkipGeneratedRewrite {
return
}
if !g.KeepGogoproto {
// generate, but do so without gogoprotobuf extensions
for _, outputPackage := range outputPackages {
p := outputPackage.(*protobufPackage)
p.OmitGogo = true
}
if err := c.ExecutePackages(g.OutputBase, outputPackages); err != nil {
log.Fatalf("Failed executing generator: %v", err)
}
}
for _, outputPackage := range outputPackages {
p := outputPackage.(*protobufPackage)
if len(p.StructTags) == 0 {
continue
}
pattern := filepath.Join(g.OutputBase, p.PackagePath, "*.go")
files, err := filepath.Glob(pattern)
if err != nil {
log.Fatalf("Can't glob pattern %q: %v", pattern, err)
}
for _, s := range files {
if strings.HasSuffix(s, "_test.go") {
continue
}
if err != nil {
log.Println(strings.Join(cmd.Args, " "))
log.Fatalf("Unable to rewrite imports for %s: %v", p.PackageName, err)
if err := RewriteTypesWithProtobufStructTags(s, p.StructTags); err != nil {
log.Fatalf("Unable to rewrite with struct tags %s: %v", s, err)
}
}
}

View File

@ -42,10 +42,16 @@ type genProtoIDL struct {
imports *ImportTracker
generateAll bool
omitGogo bool
omitFieldTypes map[types.Name]struct{}
}
func (g *genProtoIDL) PackageVars(c *generator.Context) []string {
if g.omitGogo {
return []string{
fmt.Sprintf("option go_package = %q;", g.localGoPackage.Name),
}
}
return []string{
"option (gogoproto.marshaler_all) = true;",
"option (gogoproto.sizer_all) = true;",
@ -117,7 +123,15 @@ func isProtoable(seen map[*types.Type]bool, t *types.Type) bool {
}
func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) {
return g.imports.ImportLines()
lines := []string{}
// TODO: this could be expressed more cleanly
for _, line := range g.imports.ImportLines() {
if g.omitGogo && line == "github.com/gogo/protobuf/gogoproto/gogo.proto" {
continue
}
lines = append(lines, line)
}
return lines
}
// GenerateType makes the body of a file implementing a set for type t.
@ -130,7 +144,9 @@ func (g *genProtoIDL) GenerateType(c *generator.Context, t *types.Type, w io.Wri
localGoPackage: g.localGoPackage.Package,
},
localPackage: g.localPackage,
localPackage: g.localPackage,
omitGogo: g.omitGogo,
omitFieldTypes: g.omitFieldTypes,
t: t,
@ -201,6 +217,7 @@ func (p protobufLocator) ProtoTypeFor(t *types.Type) (*types.Type, error) {
type bodyGen struct {
locator ProtobufLocator
localPackage types.Name
omitGogo bool
omitFieldTypes map[types.Name]struct{}
t *types.Type
@ -228,14 +245,18 @@ func (b bodyGen) doStruct(sw *generator.SnippetWriter) error {
switch key {
case "marshal":
if v == "false" {
options = append(options,
"(gogoproto.marshaler) = false",
"(gogoproto.unmarshaler) = false",
"(gogoproto.sizer) = false",
)
if !b.omitGogo {
options = append(options,
"(gogoproto.marshaler) = false",
"(gogoproto.unmarshaler) = false",
"(gogoproto.sizer) = false",
)
}
}
default:
options = append(options, fmt.Sprintf("%s = %s", key, v))
if !b.omitGogo || !strings.HasPrefix(key, "(gogoproto.") {
options = append(options, fmt.Sprintf("%s = %s", key, v))
}
}
case k == "protobuf.embed":
fields = []protoField{
@ -289,14 +310,19 @@ func (b bodyGen) doStruct(sw *generator.SnippetWriter) error {
}
sw.Do(`$.Type|local$ $.Name$ = $.Tag$`, field)
if len(field.Extras) > 0 {
fmt.Fprintf(out, " [")
extras := []string{}
for k, v := range field.Extras {
if b.omitGogo && strings.HasPrefix(k, "(gogoproto.") {
continue
}
extras = append(extras, fmt.Sprintf("%s = %s", k, v))
}
sort.Sort(sort.StringSlice(extras))
fmt.Fprint(out, strings.Join(extras, ", "))
fmt.Fprintf(out, "]")
if len(extras) > 0 {
fmt.Fprintf(out, " [")
fmt.Fprint(out, strings.Join(extras, ", "))
fmt.Fprintf(out, "]")
}
}
fmt.Fprintf(out, ";\n")
if i != len(fields)-1 {
@ -459,24 +485,19 @@ func protobufTagToField(tag string, field *protoField, m types.Member, t *types.
Kind: typesKindProtobuf,
}
} else {
field.Type = &types.Type{
Name: types.Name{
Name: parts[0],
Package: localPackage.Package,
Path: localPackage.Path,
},
Kind: typesKindProtobuf,
switch parts[0] {
case "varint", "bytes", "fixed64":
default:
field.Type = &types.Type{
Name: types.Name{
Name: parts[0],
Package: localPackage.Package,
Path: localPackage.Path,
},
Kind: typesKindProtobuf,
}
}
}
switch parts[2] {
case "rep":
field.Repeated = true
case "opt":
field.Optional = true
case "req":
default:
return fmt.Errorf("member %q of %q malformed 'protobuf' tag, field mode is %q not recognized\n", m.Name, t.Name, parts[2])
}
field.OptionalSet = true
protoExtra := make(map[string]string)
@ -485,7 +506,11 @@ func protobufTagToField(tag string, field *protoField, m types.Member, t *types.
if len(parts) != 2 {
return fmt.Errorf("member %q of %q malformed 'protobuf' tag, tag %d should be key=value, got %q\n", m.Name, t.Name, i+4, extra)
}
protoExtra[parts[0]] = parts[1]
switch parts[0] {
case "casttype":
parts[0] = fmt.Sprintf("(gogoproto.%s)", parts[0])
protoExtra[parts[0]] = parts[1]
}
}
field.Extras = protoExtra
@ -526,7 +551,7 @@ func membersToFields(locator ProtobufLocator, t *types.Type, localPackage types.
if len(field.Name) == 0 && len(parts[0]) != 0 {
field.Name = parts[0]
}
if field.Name == "-" {
if field.Tag == -1 && field.Name == "-" {
continue
}
}

View File

@ -18,8 +18,13 @@ package protobuf
import (
"fmt"
"log"
"os"
"path/filepath"
"reflect"
"strings"
"k8s.io/kubernetes/third_party/golang/go/ast"
"k8s.io/kubernetes/cmd/libs/go2idl/generator"
"k8s.io/kubernetes/cmd/libs/go2idl/types"
@ -68,12 +73,18 @@ type protobufPackage struct {
// A list of types to filter to; if not specified all types will be included.
FilterTypes map[types.Name]struct{}
// If true, omit any gogoprotobuf extensions not defined as types.
OmitGogo bool
// A list of field types that will be excluded from the output struct
OmitFieldTypes map[types.Name]struct{}
// A list of names that this package exports
LocalNames map[string]struct{}
// A list of struct tags to generate onto named struct fields
StructTags map[string]map[string]string
// An import tracker for this package
Imports *ImportTracker
}
@ -127,6 +138,43 @@ func (p *protobufPackage) HasGoType(name string) bool {
return ok
}
func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool {
if !p.HasGoType(t.Name.Name) {
return false
}
switch s := t.Type.(type) {
case *ast.StructType:
for i, f := range s.Fields.List {
if len(f.Tag.Value) == 0 {
continue
}
tag := strings.Trim(f.Tag.Value, "`")
protobufTag := reflect.StructTag(tag).Get("protobuf")
if len(protobufTag) == 0 {
continue
}
if len(f.Names) > 1 {
log.Printf("WARNING: struct %s field %d %s: defined multiple names but single protobuf tag", t.Name.Name, i, f.Names[0].Name)
// TODO hard error?
}
if p.StructTags == nil {
p.StructTags = make(map[string]map[string]string)
}
m := p.StructTags[t.Name.Name]
if m == nil {
m = make(map[string]string)
p.StructTags[t.Name.Name] = m
}
m[f.Names[0].Name] = tag
}
default:
log.Printf("WARNING: unexpected Go AST type definition: %#v", t)
}
return true
}
func (p *protobufPackage) Generators(c *generator.Context) []generator.Generator {
generators := []generator.Generator{}
@ -140,6 +188,7 @@ func (p *protobufPackage) Generators(c *generator.Context) []generator.Generator
localGoPackage: types.Name{Package: p.PackagePath, Name: p.GoPackageName()},
imports: p.Imports,
generateAll: p.GenerateAll,
omitGogo: p.OmitGogo,
omitFieldTypes: p.OmitFieldTypes,
})
return generators

View File

@ -18,17 +18,26 @@ package protobuf
import (
"bytes"
"errors"
"fmt"
"go/format"
"io/ioutil"
"os"
"reflect"
"strings"
"k8s.io/kubernetes/third_party/golang/go/ast"
"k8s.io/kubernetes/third_party/golang/go/parser"
"k8s.io/kubernetes/third_party/golang/go/printer"
"k8s.io/kubernetes/third_party/golang/go/token"
customreflect "k8s.io/kubernetes/third_party/golang/reflect"
)
func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExistsFn func(string) bool, header []byte) error {
// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be
// removed from the destination file.
type ExtractFunc func(*ast.TypeSpec) bool
func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn ExtractFunc, header []byte) error {
fset := token.NewFileSet()
src, err := ioutil.ReadFile(name)
if err != nil {
@ -43,7 +52,7 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExist
// remove types that are already declared
decls := []ast.Decl{}
for _, d := range file.Decls {
if !dropExistingTypeDeclarations(d, typeExistsFn) {
if !dropExistingTypeDeclarations(d, extractFn) {
decls = append(decls, d)
}
}
@ -74,7 +83,7 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, typeExist
return f.Close()
}
func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) bool {
func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
switch t := decl.(type) {
case *ast.GenDecl:
if t.Tok != token.TYPE {
@ -84,7 +93,7 @@ func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) boo
for _, s := range t.Specs {
switch spec := s.(type) {
case *ast.TypeSpec:
if existsFn(spec.Name.Name) {
if extractFn(spec) {
continue
}
specs = append(specs, spec)
@ -97,3 +106,128 @@ func dropExistingTypeDeclarations(decl ast.Decl, existsFn func(string) bool) boo
}
return false
}
func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error {
fset := token.NewFileSet()
src, err := ioutil.ReadFile(name)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, name, src, parser.DeclarationErrors|parser.ParseComments)
if err != nil {
return err
}
allErrs := []error{}
// set any new struct tags
for _, d := range file.Decls {
if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 {
allErrs = append(allErrs, errs...)
}
}
if len(allErrs) > 0 {
var s string
for _, err := range allErrs {
s += err.Error() + "\n"
}
return errors.New(s)
}
b := &bytes.Buffer{}
if err := printer.Fprint(b, fset, file); err != nil {
return err
}
body, err := format.Source(b.Bytes())
if err != nil {
return fmt.Errorf("%s\n---\nunable to format %q: %v", b, name, err)
}
f, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer f.Close()
if _, err := f.Write(body); err != nil {
return err
}
return f.Close()
}
func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {
var errs []error
t, ok := decl.(*ast.GenDecl)
if !ok {
return nil
}
if t.Tok != token.TYPE {
return nil
}
for _, s := range t.Specs {
spec, ok := s.(*ast.TypeSpec)
if !ok {
continue
}
typeName := spec.Name.Name
fieldTags, ok := structTags[typeName]
if !ok {
continue
}
st, ok := spec.Type.(*ast.StructType)
if !ok {
continue
}
for i := range st.Fields.List {
f := st.Fields.List[i]
var name string
if len(f.Names) == 0 {
switch t := f.Type.(type) {
case *ast.Ident:
name = t.Name
case *ast.SelectorExpr:
name = t.Sel.Name
default:
errs = append(errs, fmt.Errorf("unable to get name for tag from struct %q, field %#v", spec.Name.Name, t))
continue
}
} else {
name = f.Names[0].Name
}
value, ok := fieldTags[name]
if !ok {
continue
}
var tags customreflect.StructTags
if f.Tag != nil {
oldTags, err := customreflect.ParseStructTags(strings.Trim(f.Tag.Value, "`"))
if err != nil {
errs = append(errs, fmt.Errorf("unable to read struct tag from struct %q, field %q: %v", spec.Name.Name, name, err))
continue
}
tags = oldTags
}
for _, name := range toCopy {
// don't overwrite existing tags
if tags.Has(name) {
continue
}
// append new tags
if v := reflect.StructTag(value).Get(name); len(v) > 0 {
tags = append(tags, customreflect.StructTag{Name: name, Value: v})
}
}
if len(tags) == 0 {
continue
}
if f.Tag == nil {
f.Tag = &ast.BasicLit{}
}
f.Tag.Value = tags.String()
}
}
return errs
}

View File

@ -153,6 +153,7 @@ ir-user
jenkins-host
jenkins-jobs
k8s-build-output
keep-gogoproto
km-path
kube-api-burst
kube-api-qps

91
third_party/golang/reflect/type.go vendored Normal file
View File

@ -0,0 +1,91 @@
//This package is copied from Go library reflect/type.go.
//The struct tag library provides no way to extract the list of struct tags, only
//a specific tag
package reflect
import (
"fmt"
"strconv"
"strings"
)
type StructTag struct {
Name string
Value string
}
func (t StructTag) String() string {
return fmt.Sprintf("%s:%q", t.Name, t.Value)
}
type StructTags []StructTag
func (tags StructTags) String() string {
s := make([]string, 0, len(tags))
for _, tag := range tags {
s = append(s, tag.String())
}
return "`" + strings.Join(s, " ") + "`"
}
func (tags StructTags) Has(name string) bool {
for i := range tags {
if tags[i].Name == name {
return true
}
}
return false
}
// ParseStructTags returns the full set of fields in a struct tag in the order they appear in
// the struct tag.
func ParseStructTags(tag string) (StructTags, error) {
tags := StructTags{}
for tag != "" {
// Skip leading space.
i := 0
for i < len(tag) && tag[i] == ' ' {
i++
}
tag = tag[i:]
if tag == "" {
break
}
// Scan to colon. A space, a quote or a control character is a syntax error.
// Strictly speaking, control chars include the range [0x7f, 0x9f], not just
// [0x00, 0x1f], but in practice, we ignore the multi-byte control characters
// as it is simpler to inspect the tag's bytes than the tag's runes.
i = 0
for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
i++
}
if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
break
}
name := string(tag[:i])
tag = tag[i+1:]
// Scan quoted string to find value.
i = 1
for i < len(tag) && tag[i] != '"' {
if tag[i] == '\\' {
i++
}
i++
}
if i >= len(tag) {
break
}
qvalue := string(tag[:i+1])
tag = tag[i+1:]
value, err := strconv.Unquote(qvalue)
if err != nil {
return nil, err
}
tags = append(tags, StructTag{Name: name, Value: value})
}
return tags, nil
}