mirror of https://github.com/hashicorp/consul
542 lines
12 KiB
Go
542 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
)
|
|
|
|
// Command-line flags, registered on the default flag set.
var (
	// flagPath is the -path argument: the file to scan. It is required and
	// must name a file ending in ".pb.go".
	flagPath = flag.String("path", "", "path of file to load")

	// verbose is the -v argument; when set, the visited file and every
	// annotation found in it are logged.
	verbose = flag.Bool("v", false, "verbose output")
)
|
|
|
|
const (
	// annotationPrefix introduces the doc-comment line that carries the glue
	// directives for a type.
	annotationPrefix = "@consul-rpc-glue:"

	// outputFileSuffix replaces the ".pb.go" suffix of the input path to form
	// the name of the generated file.
	outputFileSuffix = ".rpcglue.pb.go"
)
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
log.SetFlags(0)
|
|
|
|
if *flagPath == "" {
|
|
log.Fatal("missing required -path argument")
|
|
}
|
|
|
|
if err := run(*flagPath); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func run(path string) error {
|
|
fi, err := os.Stat(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if fi.IsDir() {
|
|
return fmt.Errorf("argument must be a file: %s", path)
|
|
}
|
|
|
|
if !strings.HasSuffix(path, ".pb.go") {
|
|
return fmt.Errorf("file must end with .pb.go: %s", path)
|
|
}
|
|
|
|
if err := processFile(path); err != nil {
|
|
return fmt.Errorf("error processing file %q: %v", path, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func processFile(path string) error {
|
|
if *verbose {
|
|
log.Printf("visiting file %q", path)
|
|
}
|
|
|
|
fset := token.NewFileSet()
|
|
tree, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
v := visitor{}
|
|
ast.Walk(&v, tree)
|
|
if err := v.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(v.Types) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if *verbose {
|
|
log.Printf("Package: %s", v.Package)
|
|
log.Printf("BuildTags: %v", v.BuildTags)
|
|
log.Println()
|
|
for _, typ := range v.Types {
|
|
log.Printf("Type: %s", typ.Name)
|
|
ann := typ.Annotation
|
|
if ann.ReadRequest != "" {
|
|
log.Printf(" ReadRequest from %s", ann.ReadRequest)
|
|
}
|
|
if ann.WriteRequest != "" {
|
|
log.Printf(" WriteRequest from %s", ann.WriteRequest)
|
|
}
|
|
if ann.TargetDatacenter != "" {
|
|
log.Printf(" TargetDatacenter from %s", ann.TargetDatacenter)
|
|
}
|
|
if ann.QueryOptions != "" {
|
|
log.Printf(" QueryOptions from %s", ann.QueryOptions)
|
|
}
|
|
if ann.QueryMeta != "" {
|
|
log.Printf(" QueryMeta from %s", ann.QueryMeta)
|
|
}
|
|
}
|
|
}
|
|
|
|
// generate output
|
|
|
|
var buf bytes.Buffer
|
|
|
|
if len(v.BuildTags) > 0 {
|
|
for _, line := range v.BuildTags {
|
|
buf.WriteString(line + "\n")
|
|
}
|
|
buf.WriteString("\n")
|
|
}
|
|
buf.WriteString("// Code generated by proto-gen-rpc-glue. DO NOT EDIT.\n\n")
|
|
buf.WriteString("package " + v.Package + "\n")
|
|
buf.WriteString(`
|
|
import (
|
|
"time"
|
|
|
|
"github.com/hashicorp/consul/agent/structs"
|
|
)
|
|
|
|
// Reference imports to suppress errors if they are not otherwise used.
|
|
var _ structs.RPCInfo
|
|
|
|
`)
|
|
for _, typ := range v.Types {
|
|
if typ.Annotation.WriteRequest != "" {
|
|
buf.WriteString(fmt.Sprintf(tmplWriteRequest, typ.Name, typ.Annotation.WriteRequest))
|
|
}
|
|
if typ.Annotation.ReadRequest != "" {
|
|
buf.WriteString(fmt.Sprintf(tmplReadRequest, typ.Name, typ.Annotation.ReadRequest))
|
|
}
|
|
if typ.Annotation.TargetDatacenter != "" {
|
|
buf.WriteString(fmt.Sprintf(tmplTargetDatacenter, typ.Name, typ.Annotation.TargetDatacenter))
|
|
}
|
|
if typ.Annotation.QueryOptions != "" {
|
|
buf.WriteString(fmt.Sprintf(tmplQueryOptions, typ.Name, typ.Annotation.QueryOptions))
|
|
}
|
|
if typ.Annotation.QueryMeta != "" {
|
|
buf.WriteString(fmt.Sprintf(tmplQueryMeta, typ.Name, typ.Annotation.QueryMeta))
|
|
}
|
|
}
|
|
|
|
// write to disk
|
|
outFile := strings.TrimSuffix(path, ".pb.go") + outputFileSuffix
|
|
if err := os.WriteFile(outFile, buf.Bytes(), 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
// clean up
|
|
cmd := exec.Command("gofmt", "-s", "-w", outFile)
|
|
cmd.Stdout = nil
|
|
cmd.Stderr = os.Stderr
|
|
cmd.Stdin = nil
|
|
if err := cmd.Run(); err != nil {
|
|
return fmt.Errorf("error running 'gofmt -s -w %q': %v", outFile, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type TypeInfo struct {
|
|
Name string
|
|
Annotation Annotation
|
|
}
|
|
|
|
type visitor struct {
|
|
Package string
|
|
BuildTags []string
|
|
Types []TypeInfo
|
|
Errs []error
|
|
}
|
|
|
|
func (v *visitor) Err() error {
|
|
switch len(v.Errs) {
|
|
case 0:
|
|
return nil
|
|
case 1:
|
|
return v.Errs[0]
|
|
default:
|
|
//
|
|
var s []string
|
|
for _, e := range v.Errs {
|
|
s = append(s, e.Error())
|
|
}
|
|
return errors.New(strings.Join(s, "; "))
|
|
}
|
|
}
|
|
|
|
var _ ast.Visitor = (*visitor)(nil)
|
|
|
|
func (v *visitor) Visit(node ast.Node) ast.Visitor {
|
|
if node == nil {
|
|
return v
|
|
}
|
|
|
|
switch x := node.(type) {
|
|
case *ast.File:
|
|
v.Package = x.Name.Name
|
|
v.BuildTags = getRawBuildTags(x)
|
|
for _, d := range x.Decls {
|
|
gd, ok := d.(*ast.GenDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if gd.Doc == nil {
|
|
continue
|
|
} else if len(gd.Specs) != 1 {
|
|
continue
|
|
}
|
|
spec := gd.Specs[0]
|
|
|
|
typeSpec, ok := spec.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
ann, err := getAnnotation(gd.Doc.List)
|
|
if err != nil {
|
|
v.Errs = append(v.Errs, err)
|
|
continue
|
|
} else if ann.IsZero() {
|
|
continue
|
|
}
|
|
|
|
v.Types = append(v.Types, TypeInfo{
|
|
Name: typeSpec.Name.Name,
|
|
Annotation: ann,
|
|
})
|
|
|
|
}
|
|
}
|
|
return v
|
|
}
|
|
|
|
// Annotation is the parsed form of a glue annotation line. Each field names
// the struct field of the annotated type that the correspondingly named glue
// template delegates to; an empty string means that directive was not given.
type Annotation struct {
	// QueryMeta is the field name substituted into the QueryMeta template.
	QueryMeta string
	// QueryOptions is the field name substituted into the QueryOptions template.
	QueryOptions string
	// ReadRequest is the field name substituted into the ReadRequest template.
	ReadRequest string
	// WriteRequest is the field name substituted into the WriteRequest template.
	WriteRequest string
	// TargetDatacenter is the field name substituted into the TargetDatacenter template.
	TargetDatacenter string
}

// IsZero reports whether no directive is set, i.e. every field is empty.
func (a Annotation) IsZero() bool {
	var zero Annotation
	return a == zero
}
|
|
|
|
func getAnnotation(doc []*ast.Comment) (Annotation, error) {
|
|
raw, ok := getRawStructAnnotation(doc)
|
|
if !ok {
|
|
return Annotation{}, nil
|
|
}
|
|
|
|
var ann Annotation
|
|
|
|
parts := strings.Split(raw, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
switch {
|
|
case part == "ReadRequest":
|
|
ann.ReadRequest = "ReadRequest"
|
|
case strings.HasPrefix(part, "ReadRequest="):
|
|
ann.ReadRequest = strings.TrimPrefix(part, "ReadRequest=")
|
|
|
|
case part == "WriteRequest":
|
|
ann.WriteRequest = "WriteRequest"
|
|
case strings.HasPrefix(part, "WriteRequest="):
|
|
ann.WriteRequest = strings.TrimPrefix(part, "WriteRequest=")
|
|
|
|
case part == "TargetDatacenter":
|
|
ann.TargetDatacenter = "TargetDatacenter"
|
|
case strings.HasPrefix(part, "TargetDatacenter="):
|
|
ann.TargetDatacenter = strings.TrimPrefix(part, "TargetDatacenter=")
|
|
|
|
case part == "QueryOptions":
|
|
ann.QueryOptions = "QueryOptions"
|
|
case strings.HasPrefix(part, "QueryOptions="):
|
|
ann.QueryOptions = strings.TrimPrefix(part, "QueryOptions=")
|
|
|
|
case part == "QueryMeta":
|
|
ann.QueryMeta = "QueryMeta"
|
|
case strings.HasPrefix(part, "QueryMeta="):
|
|
ann.QueryMeta = strings.TrimPrefix(part, "QueryMeta=")
|
|
|
|
default:
|
|
return Annotation{}, fmt.Errorf("unexpected annotation part: %s", part)
|
|
}
|
|
}
|
|
|
|
return ann, nil
|
|
}
|
|
|
|
func getRawStructAnnotation(doc []*ast.Comment) (string, bool) {
|
|
for _, line := range doc {
|
|
text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))
|
|
|
|
ann := strings.TrimSpace(strings.TrimPrefix(text, annotationPrefix))
|
|
|
|
if text != ann {
|
|
return ann, true
|
|
}
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
// getRawBuildTags returns the raw text of every build-constraint comment line
// ("//go:build ..." or "// +build ...") that appears ahead of the package
// clause, in source order.
//
// Build constraints may be preceded by blank lines and other line comments
// (for example a license or "Code generated" header), so every comment group
// before the package clause is inspected rather than only the first one.
// Comments located after the package clause are never treated as constraints.
// nil is returned when no constraint line is found.
func getRawBuildTags(file *ast.File) []string {
	var out []string
	for _, cg := range file.Comments {
		// file.Package is only meaningful for parsed files; a hand-built
		// ast.File without positions is scanned in full.
		if file.Package.IsValid() && cg.Pos() >= file.Package {
			break
		}

		for _, line := range cg.List {
			text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))
			if strings.HasPrefix(text, "go:build ") || strings.HasPrefix(text, "+build") {
				out = append(out, line.Text)
			}
		}
	}

	return out
}
|
|
|
|
// tmplWriteRequest is the fmt template for the structs.RPCInfo glue emitted
// for a WriteRequest directive. %[1]s is the annotated message type and %[2]s
// the name of its field that the methods delegate to.
//
// Every generated accessor that reads through the delegate field guards
// against a nil receiver and a nil field; Token does so too, matching
// TokenSecret, so calling it on a nil message returns "" instead of panicking.
const tmplWriteRequest = `
// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return false
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return false
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
|
// tmplReadRequest is the fmt template for the structs.RPCInfo glue emitted
// for a ReadRequest directive. %[1]s is the annotated message type and %[2]s
// the name of its field that the methods delegate to.
//
// Token guards against a nil receiver as well as a nil field, matching
// TokenSecret, so calling it on a nil message returns "" instead of panicking.
const tmplReadRequest = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	// TODO: initialize if nil
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
|
// tmplTargetDatacenter is the fmt template for the RequestDatacenter method
// emitted for a TargetDatacenter directive. %[1]s is the annotated message
// type and %[2]s the name of its field whose GetDatacenter result is returned;
// a nil message or nil field yields "".
const tmplTargetDatacenter = `
// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.GetDatacenter()
}
`
|
|
|
|
// tmplQueryOptions is the fmt template for the glue emitted for a
// QueryOptions directive: the structs.RPCInfo methods plus the getters the
// generated comments describe as required by blockingQueryOptions. %[1]s is
// the annotated message type and %[2]s the name of its field that the methods
// delegate to.
//
// Token guards against a nil receiver as well as a nil field, matching
// TokenSecret and GetToken, so calling it on a nil message returns "" instead
// of panicking.
const tmplQueryOptions = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}
// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
// GetToken is required to implement blockingQueryOptions
func (msg *%[1]s) GetToken() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}

	return msg.%[2]s.GetToken()
}
// GetMinQueryIndex is required to implement blockingQueryOptions
func (msg *%[1]s) GetMinQueryIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}

	return msg.%[2]s.GetMinQueryIndex()
}
// GetMaxQueryTime is required to implement blockingQueryOptions
func (msg *%[1]s) GetMaxQueryTime() (time.Duration, error) {
	if msg == nil || msg.%[2]s == nil {
		return 0, nil
	}

	return structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime()), nil
}

// GetRequireConsistent is required to implement blockingQueryOptions
func (msg *%[1]s) GetRequireConsistent() bool {
	if msg == nil || msg.%[2]s == nil {
		return false
	}
	return msg.%[2]s.RequireConsistent
}
`
|
|
|
|
// tmplQueryMeta is the fmt template for the methods emitted for a QueryMeta
// directive; the generated comments describe them as required by
// blockingQueryResponseMeta. %[1]s is the annotated message type and %[2]s the
// name of its field that the methods delegate to. Every method is a no-op (or
// returns 0) when the message or the field is nil.
const tmplQueryMeta = `
// SetLastContact is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetLastContact(d time.Duration) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetLastContact(d)
}

// SetKnownLeader is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetKnownLeader(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetKnownLeader(b)
}

// GetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) GetIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}
	return msg.%[2]s.GetIndex()
}

// SetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetIndex(i uint64) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetIndex(i)
}

// SetResultsFilteredByACLs is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetResultsFilteredByACLs(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetResultsFilteredByACLs(b)
}
`
|