mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
581 lines
13 KiB
581 lines
13 KiB
// Copyright (c) HashiCorp, Inc. |
|
// SPDX-License-Identifier: BUSL-1.1 |
|
|
|
package main |
|
|
|
import ( |
|
"bytes" |
|
"errors" |
|
"flag" |
|
"fmt" |
|
"go/ast" |
|
"go/parser" |
|
"go/token" |
|
"log" |
|
"os" |
|
"os/exec" |
|
"strings" |
|
) |
|
|
|
// Command-line flags.
var (
	// flagPath is the value of -path: the file to load. It is required and
	// must name a file ending in ".pb.go".
	flagPath = flag.String("path", "", "path of file to load")

	// verbose is the value of -v; when set, progress and the discovered
	// annotations are logged.
	verbose = flag.Bool("v", false, "verbose output")
)
|
|
|
const (
	// annotationPrefix marks a doc-comment line that carries glue
	// annotations; the text after it is a comma-separated list of parts.
	annotationPrefix = "@consul-rpc-glue:"

	// outputFileSuffix replaces the ".pb.go" suffix of the input file to form
	// the name of the generated file.
	outputFileSuffix = ".rpcglue.pb.go"
)
|
|
|
func main() { |
|
flag.Parse() |
|
|
|
log.SetFlags(0) |
|
|
|
if *flagPath == "" { |
|
log.Fatal("missing required -path argument") |
|
} |
|
|
|
if err := run(*flagPath); err != nil { |
|
log.Fatal(err) |
|
} |
|
} |
|
|
|
func run(path string) error { |
|
fi, err := os.Stat(path) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if fi.IsDir() { |
|
return fmt.Errorf("argument must be a file: %s", path) |
|
} |
|
|
|
if !strings.HasSuffix(path, ".pb.go") { |
|
return fmt.Errorf("file must end with .pb.go: %s", path) |
|
} |
|
|
|
if err := processFile(path); err != nil { |
|
return fmt.Errorf("error processing file %q: %v", path, err) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func processFile(path string) error { |
|
if *verbose { |
|
log.Printf("visiting file %q", path) |
|
} |
|
|
|
fset := token.NewFileSet() |
|
tree, err := parser.ParseFile(fset, path, nil, parser.ParseComments) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
v := visitor{} |
|
ast.Walk(&v, tree) |
|
if err := v.Err(); err != nil { |
|
return err |
|
} |
|
|
|
if len(v.Types) == 0 { |
|
return nil |
|
} |
|
|
|
if *verbose { |
|
log.Printf("Package: %s", v.Package) |
|
log.Printf("BuildTags: %v", v.BuildTags) |
|
log.Println() |
|
for _, typ := range v.Types { |
|
log.Printf("Type: %s", typ.Name) |
|
ann := typ.Annotation |
|
if ann.ReadRequest != "" { |
|
log.Printf(" ReadRequest from %s", ann.ReadRequest) |
|
} |
|
if ann.WriteRequest != "" { |
|
log.Printf(" WriteRequest from %s", ann.WriteRequest) |
|
} |
|
if ann.TargetDatacenter != "" { |
|
log.Printf(" TargetDatacenter from %s", ann.TargetDatacenter) |
|
} |
|
if ann.QueryOptions != "" { |
|
log.Printf(" QueryOptions from %s", ann.QueryOptions) |
|
} |
|
if ann.QueryMeta != "" { |
|
log.Printf(" QueryMeta from %s", ann.QueryMeta) |
|
} |
|
if ann.Datacenter != "" { |
|
log.Printf(" Datacenter from %s", ann.Datacenter) |
|
} |
|
} |
|
} |
|
|
|
// generate output |
|
|
|
var buf bytes.Buffer |
|
|
|
if len(v.BuildTags) > 0 { |
|
for _, line := range v.BuildTags { |
|
buf.WriteString(line + "\n") |
|
} |
|
buf.WriteString("\n") |
|
} |
|
buf.WriteString("// Code generated by proto-gen-rpc-glue. DO NOT EDIT.\n\n") |
|
buf.WriteString("package " + v.Package + "\n") |
|
buf.WriteString(` |
|
import ( |
|
"time" |
|
|
|
"github.com/hashicorp/consul/agent/structs" |
|
) |
|
|
|
// Reference imports to suppress errors if they are not otherwise used. |
|
var _ structs.RPCInfo |
|
var _ time.Month |
|
|
|
`) |
|
for _, typ := range v.Types { |
|
if typ.Annotation.WriteRequest != "" { |
|
buf.WriteString(fmt.Sprintf(tmplWriteRequest, typ.Name, typ.Annotation.WriteRequest)) |
|
} |
|
if typ.Annotation.ReadRequest != "" { |
|
buf.WriteString(fmt.Sprintf(tmplReadRequest, typ.Name, typ.Annotation.ReadRequest)) |
|
} |
|
if typ.Annotation.TargetDatacenter != "" { |
|
buf.WriteString(fmt.Sprintf(tmplTargetDatacenter, typ.Name, typ.Annotation.TargetDatacenter)) |
|
} |
|
if typ.Annotation.QueryOptions != "" { |
|
buf.WriteString(fmt.Sprintf(tmplQueryOptions, typ.Name, typ.Annotation.QueryOptions)) |
|
} |
|
if typ.Annotation.QueryMeta != "" { |
|
buf.WriteString(fmt.Sprintf(tmplQueryMeta, typ.Name, typ.Annotation.QueryMeta)) |
|
} |
|
if typ.Annotation.Datacenter != "" { |
|
buf.WriteString(fmt.Sprintf(tmplDatacenter, typ.Name, typ.Annotation.Datacenter)) |
|
} |
|
} |
|
|
|
// write to disk |
|
outFile := strings.TrimSuffix(path, ".pb.go") + outputFileSuffix |
|
if err := os.WriteFile(outFile, buf.Bytes(), 0644); err != nil { |
|
return err |
|
} |
|
|
|
// clean up |
|
cmd := exec.Command("gofmt", "-s", "-w", outFile) |
|
cmd.Stdout = nil |
|
cmd.Stderr = os.Stderr |
|
cmd.Stdin = nil |
|
if err := cmd.Run(); err != nil { |
|
return fmt.Errorf("error running 'gofmt -s -w %q': %v", outFile, err) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
type TypeInfo struct { |
|
Name string |
|
Annotation Annotation |
|
} |
|
|
|
type visitor struct { |
|
Package string |
|
BuildTags []string |
|
Types []TypeInfo |
|
Errs []error |
|
} |
|
|
|
func (v *visitor) Err() error { |
|
switch len(v.Errs) { |
|
case 0: |
|
return nil |
|
case 1: |
|
return v.Errs[0] |
|
default: |
|
// |
|
var s []string |
|
for _, e := range v.Errs { |
|
s = append(s, e.Error()) |
|
} |
|
return errors.New(strings.Join(s, "; ")) |
|
} |
|
} |
|
|
|
var _ ast.Visitor = (*visitor)(nil) |
|
|
|
func (v *visitor) Visit(node ast.Node) ast.Visitor { |
|
if node == nil { |
|
return v |
|
} |
|
|
|
switch x := node.(type) { |
|
case *ast.File: |
|
v.Package = x.Name.Name |
|
v.BuildTags = getRawBuildTags(x) |
|
for _, d := range x.Decls { |
|
gd, ok := d.(*ast.GenDecl) |
|
if !ok { |
|
continue |
|
} |
|
|
|
if gd.Doc == nil { |
|
continue |
|
} else if len(gd.Specs) != 1 { |
|
continue |
|
} |
|
spec := gd.Specs[0] |
|
|
|
typeSpec, ok := spec.(*ast.TypeSpec) |
|
if !ok { |
|
continue |
|
} |
|
|
|
ann, err := getAnnotation(gd.Doc.List) |
|
if err != nil { |
|
v.Errs = append(v.Errs, err) |
|
continue |
|
} else if ann.IsZero() { |
|
continue |
|
} |
|
|
|
v.Types = append(v.Types, TypeInfo{ |
|
Name: typeSpec.Name.Name, |
|
Annotation: ann, |
|
}) |
|
|
|
} |
|
} |
|
return v |
|
} |
|
|
|
// Annotation is the parsed form of a glue annotation. Each field that
// getAnnotation populates holds the name of the message field the generated
// methods should use; an empty string means that part was not requested.
// The struct contains only strings, so values can be compared with ==.
type Annotation struct {
	// QueryMeta selects the query-meta template.
	QueryMeta string
	// QueryOptions selects the query-options template.
	QueryOptions string
	// ReadRequest selects the read-request template.
	ReadRequest string
	// WriteRequest selects the write-request template.
	WriteRequest string
	// TargetDatacenter selects the target-datacenter template.
	TargetDatacenter string
	// Datacenter selects the datacenter template.
	Datacenter string

	// The fields below are not set by getAnnotation.
	ReadTODO       string
	LeaderReadTODO string
	WriteTODO      string
}
|
|
|
func (a Annotation) IsZero() bool { |
|
return a == Annotation{} |
|
} |
|
|
|
func getAnnotation(doc []*ast.Comment) (Annotation, error) { |
|
raw, ok := getRawStructAnnotation(doc) |
|
if !ok { |
|
return Annotation{}, nil |
|
} |
|
|
|
var ann Annotation |
|
|
|
parts := strings.Split(raw, ",") |
|
for _, part := range parts { |
|
part = strings.TrimSpace(part) |
|
switch { |
|
case part == "ReadRequest": |
|
ann.ReadRequest = "ReadRequest" |
|
case strings.HasPrefix(part, "ReadRequest="): |
|
ann.ReadRequest = strings.TrimPrefix(part, "ReadRequest=") |
|
|
|
case part == "WriteRequest": |
|
ann.WriteRequest = "WriteRequest" |
|
case strings.HasPrefix(part, "WriteRequest="): |
|
ann.WriteRequest = strings.TrimPrefix(part, "WriteRequest=") |
|
|
|
case part == "TargetDatacenter": |
|
ann.TargetDatacenter = "TargetDatacenter" |
|
case strings.HasPrefix(part, "TargetDatacenter="): |
|
ann.TargetDatacenter = strings.TrimPrefix(part, "TargetDatacenter=") |
|
|
|
case part == "QueryOptions": |
|
ann.QueryOptions = "QueryOptions" |
|
case strings.HasPrefix(part, "QueryOptions="): |
|
ann.QueryOptions = strings.TrimPrefix(part, "QueryOptions=") |
|
|
|
case part == "QueryMeta": |
|
ann.QueryMeta = "QueryMeta" |
|
case strings.HasPrefix(part, "QueryMeta="): |
|
ann.QueryMeta = strings.TrimPrefix(part, "QueryMeta=") |
|
|
|
case part == "Datacenter": |
|
ann.Datacenter = "Datacenter" |
|
case strings.HasPrefix(part, "Datacenter="): |
|
ann.Datacenter = strings.TrimPrefix(part, "Datacenter=") |
|
|
|
default: |
|
return Annotation{}, fmt.Errorf("unexpected annotation part: %s", part) |
|
} |
|
} |
|
|
|
return ann, nil |
|
} |
|
|
|
func getRawStructAnnotation(doc []*ast.Comment) (string, bool) { |
|
for _, line := range doc { |
|
text := strings.TrimSpace(strings.TrimLeft(line.Text, "/")) |
|
|
|
ann := strings.TrimSpace(strings.TrimPrefix(text, annotationPrefix)) |
|
|
|
if text != ann { |
|
return ann, true |
|
} |
|
} |
|
return "", false |
|
} |
|
|
|
// getRawBuildTags returns the raw text of the build-constraint comment lines
// at the start of the file's first comment group. A line counts as a build
// constraint when, after stripping leading '/' characters and surrounding
// whitespace, it starts with "go:build " or "+build". Collection stops at the
// first line that does not. The result is nil when the file has no comments
// or the first group does not begin with a constraint.
func getRawBuildTags(file *ast.File) []string {
	// build tags are always the first group, at the very top
	if len(file.Comments) == 0 {
		return nil
	}

	var tags []string
	for _, c := range file.Comments[0].List {
		trimmed := strings.TrimSpace(strings.TrimLeft(c.Text, "/"))

		isTag := strings.HasPrefix(trimmed, "go:build ") || strings.HasPrefix(trimmed, "+build")
		if !isTag {
			break // stop at first non-build-tag
		}
		tags = append(tags, c.Text)
	}
	return tags
}
|
|
|
// tmplWriteRequest is the fmt template emitted for a WriteRequest
// annotation. %[1]s is the message type name and %[2]s is the name of the
// pointer field on that message that the generated methods delegate to.
//
// The generated Token method guards against a nil receiver in the same way
// as the generated TokenSecret method, so calling it on a nil message
// returns "" instead of panicking.
const tmplWriteRequest = `
// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return false
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return false
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
// tmplReadRequest is the fmt template emitted for a ReadRequest annotation.
// %[1]s is the message type name and %[2]s is the name of the pointer field
// on that message that the generated methods delegate to.
//
// The generated Token method guards against a nil receiver in the same way
// as the generated TokenSecret method, so calling it on a nil message
// returns "" instead of panicking.
const tmplReadRequest = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	// TODO: initialize if nil
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
// tmplTargetDatacenter is the fmt template emitted for a TargetDatacenter
// annotation. %[1]s is the message type name and %[2]s is the name of the
// pointer field whose GetDatacenter method supplies the result; a nil
// message or nil field yields "".
const tmplTargetDatacenter = `
// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.GetDatacenter()
}
`
|
|
|
// tmplDatacenter is the fmt template emitted for a Datacenter annotation.
// %[1]s is the message type name and %[2]s is the name of the string field
// on the message that holds the datacenter.
//
// The field name comes from the annotation rather than being fixed, so
// "Datacenter=Other" reads msg.Other. The bare "Datacenter" form still reads
// msg.Datacenter.
const tmplDatacenter = `
// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil {
		return ""
	}
	return msg.%[2]s
}
`
|
|
|
// tmplQueryOptions is the fmt template emitted for a QueryOptions
// annotation. %[1]s is the message type name and %[2]s is the name of the
// pointer field on that message that the generated methods delegate to.
//
// The generated Token method guards against a nil receiver in the same way
// as the generated TokenSecret and GetToken methods, so calling it on a nil
// message returns "" instead of panicking.
const tmplQueryOptions = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return msg.%[2]s.AllowStaleRead()
}

// BlockingTimeout implements pool.BlockableQuery
func (msg *%[1]s) BlockingTimeout(maxQueryTime, defaultQueryTime time.Duration) time.Duration {
	maxTime := structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime())
	o := structs.QueryOptions{
		MaxQueryTime: maxTime,
		MinQueryIndex: msg.%[2]s.GetMinQueryIndex(),
	}
	return o.BlockingTimeout(maxQueryTime, defaultQueryTime)
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
// GetToken is required to implement blockingQueryOptions
func (msg *%[1]s) GetToken() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}

	return msg.%[2]s.GetToken()
}
// GetMinQueryIndex is required to implement blockingQueryOptions
func (msg *%[1]s) GetMinQueryIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}

	return msg.%[2]s.GetMinQueryIndex()
}
// GetMaxQueryTime is required to implement blockingQueryOptions
func (msg *%[1]s) GetMaxQueryTime() (time.Duration, error) {
	if msg == nil || msg.%[2]s == nil {
		return 0, nil
	}

	return structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime()), nil
}

// GetRequireConsistent is required to implement blockingQueryOptions
func (msg *%[1]s) GetRequireConsistent() bool {
	if msg == nil || msg.%[2]s == nil {
		return false
	}
	return msg.%[2]s.RequireConsistent
}
`
|
|
|
// tmplQueryMeta is the fmt template emitted for a QueryMeta annotation.
// %[1]s is the message type name and %[2]s is the name of the pointer field
// on that message that the generated methods delegate to. Every generated
// method is a no-op (or returns zero) when the message or that field is nil.
const tmplQueryMeta = `
// SetLastContact is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetLastContact(d time.Duration) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetLastContact(d)
}

// SetKnownLeader is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetKnownLeader(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetKnownLeader(b)
}

// GetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) GetIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}
	return msg.%[2]s.GetIndex()
}

// SetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetIndex(i uint64) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetIndex(i)
}

// SetResultsFilteredByACLs is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetResultsFilteredByACLs(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetResultsFilteredByACLs(b)
}
`
|
|
|