consul/internal/tools/protoc-gen-consul-rate-limit/postprocess/main.go

196 lines
4.5 KiB
Go

package main
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"go/format"
"os"
"path/filepath"
"sort"
"strings"
)
const (
	// usage is printed to stderr when a required flag is missing; the %s verb
	// receives the program name (os.Args[0]).
	usage = "Usage: %s -input=/proto-dir-1 -input=/proto-dir-2 -output=/mappings.go\n"

	// fileHeader opens every generated file. Its first line follows the
	// standard "Code generated ... DO NOT EDIT." form so Go tooling treats the
	// file as generated. The blank line after it keeps the marker from being
	// attached to the package clause as package documentation.
	fileHeader = `// Code generated by protoc-gen-consul-rate-limit; DO NOT EDIT.

package middleware

import "github.com/hashicorp/consul/agent/consul/rate"

`

	// entTags restricts the enterprise file to consulent builds. It ends with
	// a blank line because build constraints must be separated from the rest
	// of the file. Without it, the "// +build" line would run straight into
	// fileHeader's first line.
	entTags = `//go:build consulent
// +build consulent

`
)
// main reads the repeatable -input flag and the -output flag. If no input is
// given or the output path is empty, it prints the usage line and exits with
// status 1. Otherwise it calls run and, on error, reports it on stderr and
// exits with status 1.
func main() {
	var inputs sliceFlags
	flag.Var(&inputs, "input", "")
	output := flag.String("output", "", "")
	flag.Parse()

	missingArgs := len(inputs) == 0 || *output == ""
	if missingArgs {
		fmt.Fprintf(os.Stderr, usage, os.Args[0])
		os.Exit(1)
	}

	err := run(inputs, *output)
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)
	os.Exit(1)
}
// run collects the rate-limit specs under inputPaths and writes the generated
// sources. The OSS mappings are always written to outputPath, which must end
// in ".go". Enterprise mappings, when any exist, are written to the path
// derived by enterpriseFileName(outputPath).
func run(inputPaths []string, outputPath string) error {
	if !strings.HasSuffix(outputPath, ".go") {
		return errors.New("-output path must end in .go")
	}

	oss, ent, err := collectSpecs(inputPaths)
	if err != nil {
		return err
	}

	ossSource, err := generateOSS(oss)
	if err != nil {
		return err
	}
	if err := os.WriteFile(outputPath, ossSource, 0666); err != nil {
		return fmt.Errorf("failed to write output file: %s - %w", outputPath, err)
	}

	// ent should only be non-empty in the enterprise repository.
	if len(ent) == 0 {
		return nil
	}
	entSource, err := generateENT(ent)
	if err != nil {
		return err
	}
	// Report the file that actually failed to be written, not the OSS path.
	entPath := enterpriseFileName(outputPath)
	if err := os.WriteFile(entPath, entSource, 0666); err != nil {
		return fmt.Errorf("failed to write output file: %s - %w", entPath, err)
	}
	return nil
}
// enterpriseFileName inserts the "_ent" suffix before the extension of
// filename's base name. The extension is everything from the first "." in the
// base name onward. A base name with no "." simply gets "_ent" appended
// instead of panicking on an out-of-range slice.
//
// Example:
//
//	enterpriseFileName("bar/baz/foo.gen.go") => "bar/baz/foo_ent.gen.go"
func enterpriseFileName(filename string) string {
	base := filepath.Base(filename)
	stem, ext := base, ""
	if i := strings.Index(base, "."); i >= 0 {
		stem, ext = base[:i], base[i:]
	}
	return filepath.Join(filepath.Dir(filename), stem+"_ent"+ext)
}
// spec is a single rate-limit entry as decoded from a spec file's JSON array.
// Entries with Enterprise set are emitted only into the enterprise file.
type spec struct {
	MethodName    string
	OperationType string
	Enterprise    bool
}

// goOperationTypes maps each recognized OperationType value to the Go
// expression written into the generated source.
var goOperationTypes = map[string]string{
	"OPERATION_TYPE_WRITE":  "rate.OperationTypeWrite",
	"OPERATION_TYPE_READ":   "rate.OperationTypeRead",
	"OPERATION_TYPE_EXEMPT": "rate.OperationTypeExempt",
}

// GoOperationType returns the Go expression for s.OperationType. It panics
// if the operation type is not one of the recognized values.
func (s spec) GoOperationType() string {
	expr, known := goOperationTypes[s.OperationType]
	if !known {
		panic(fmt.Sprintf("unknown rate limit operation type: %s", s.OperationType))
	}
	return expr
}
// collectSpecs reads every file matching <input>/*/.ratelimit.tmp for each
// input path and decodes each one as a JSON array of specs. It returns the
// specs sorted by method name and split into OSS (first result) and
// enterprise (second result) sets.
func collectSpecs(inputPaths []string) ([]spec, []spec, error) {
	var specs []spec
	for _, protoPath := range inputPaths {
		pattern := filepath.Join(protoPath, "*", ".ratelimit.tmp")
		specFiles, err := filepath.Glob(pattern)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to glob directory: %s - %w", protoPath, err)
		}
		for _, file := range specFiles {
			b, err := os.ReadFile(file)
			if err != nil {
				return nil, nil, fmt.Errorf("failed to read ratelimit file: %w", err)
			}
			var fileSpecs []spec
			if err := json.Unmarshal(b, &fileSpecs); err != nil {
				return nil, nil, fmt.Errorf("failed to unmarshal ratelimit file %s - %w", file, err)
			}
			specs = append(specs, fileSpecs...)
		}
	}

	// A stable sort keeps entries that share a method name in the order they
	// were read. Their relative position in the generated code then does not
	// depend on the sort algorithm.
	sort.SliceStable(specs, func(a, b int) bool {
		return specs[a].MethodName < specs[b].MethodName
	})

	var oss, ent []spec
	for _, s := range specs {
		if s.Enterprise {
			ent = append(ent, s)
		} else {
			oss = append(oss, s)
		}
	}
	return oss, ent, nil
}
// generateOSS renders the OSS mapping source. The output is fileHeader
// followed by a rpcRateLimitSpecs map literal with one entry per spec, in the
// order given. The result is passed through go/format; a formatting failure
// is returned as an error.
func generateOSS(specs []spec) ([]byte, error) {
	var output bytes.Buffer
	output.WriteString(fileHeader)
	fmt.Fprintln(&output, `var rpcRateLimitSpecs = map[string]rate.OperationType{`)
	for _, s := range specs {
		// %q always yields a valid Go string literal, even if the method name
		// contains quotes or backslashes.
		fmt.Fprintf(&output, "%q: %s,\n", s.MethodName, s.GoOperationType())
	}
	output.WriteString("}\n")
	formatted, err := format.Source(output.Bytes())
	if err != nil {
		return nil, fmt.Errorf("failed to format source: %w", err)
	}
	return formatted, nil
}
// generateENT renders the enterprise mapping source. The output is the
// entTags build constraints, then fileHeader, then an init function that
// assigns one rpcRateLimitSpecs entry per spec, in the order given. The
// result is passed through go/format; a formatting failure is returned as an
// error.
func generateENT(specs []spec) ([]byte, error) {
	var output bytes.Buffer
	output.WriteString(entTags)
	// Build constraints must be followed by a blank line before anything
	// else. Writing one explicitly keeps the output correct regardless of how
	// entTags ends; gofmt collapses any surplus blank lines.
	output.WriteString("\n\n")
	output.WriteString(fileHeader)
	output.WriteString("func init() {\n")
	for _, s := range specs {
		// %q always yields a valid Go string literal for the map key.
		fmt.Fprintf(&output, "rpcRateLimitSpecs[%q] = %s\n", s.MethodName, s.GoOperationType())
	}
	output.WriteString("}\n")
	formatted, err := format.Source(output.Bytes())
	if err != nil {
		return nil, fmt.Errorf("failed to format source: %w", err)
	}
	return formatted, nil
}
// sliceFlags is a flag.Value that collects every occurrence of a repeatable
// string flag, in command-line order.
type sliceFlags []string

// Set appends one more occurrence of the flag; it never returns an error.
func (f *sliceFlags) Set(value string) error {
	updated := append(*f, value)
	*f = updated
	return nil
}

// String always reports an empty string, so no default value is shown for
// the flag. It does not dereference the receiver, so a nil *sliceFlags is safe.
func (f *sliceFlags) String() string {
	return ""
}