You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
consul/command/flags/config.go

314 lines
7.1 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package flags
import (
"fmt"
"os"
"path/filepath"
"reflect"
"sort"
"strconv"
"time"
"github.com/mitchellh/mapstructure"
)
// TODO (slackpad) - Trying out a different pattern here for config handling.
// These types support the flag.Value interface but work in a manner where
// we can tell if they have been set. This lets us work with an all-pointer
// config structure and merge it in a clean-ish way. If this ends up being a
// good pattern we should pull this out into a reusable library.

// ConfigDecodeHook should be passed to mapstructure in order to decode into
// the *Value objects here. It is the composition of the four hook
// constructors listed below, one per *Value type declared in this file.
var ConfigDecodeHook = mapstructure.ComposeDecodeHookFunc(
BoolToBoolValueFunc(),
StringToDurationValueFunc(),
StringToStringValueFunc(),
Float64ToUintValueFunc(),
)
// BoolValue provides a flag value that's aware if it has been set. The zero
// value is usable as-is and reports itself as unset.
type BoolValue struct {
	// v is nil until a value has been stored successfully.
	v *bool
}

// IsBoolFlag is an optional method of the flag.Value
// interface which marks this value as boolean when
// the return value is true. See flag.Value for details.
func (b *BoolValue) IsBoolFlag() bool {
	return true
}

// Merge will overlay this value onto the target if it has been set. An unset
// value leaves the target untouched.
func (b *BoolValue) Merge(onto *bool) {
	if b.v != nil {
		*onto = *(b.v)
	}
}

// Set implements the flag.Value interface. The input is parsed with
// strconv.ParseBool; on a parse failure the error is returned and the
// receiver is left exactly as it was, so a rejected input neither marks the
// value as set nor overwrites a previously stored value.
func (b *BoolValue) Set(v string) error {
	parsed, err := strconv.ParseBool(v)
	if err != nil {
		return err
	}
	// Write through the existing pointer when there is one so that copies of
	// this BoolValue sharing the pointer keep observing updates.
	if b.v == nil {
		b.v = new(bool)
	}
	*(b.v) = parsed
	return nil
}

// String implements the flag.Value interface. An unset value renders as
// "false".
func (b *BoolValue) String() string {
	var current bool
	if b.v != nil {
		current = *(b.v)
	}
	return fmt.Sprintf("%v", current)
}
// BoolToBoolValueFunc builds a mapstructure hook that handles an incoming
// bool whose destination type is BoolValue, producing a BoolValue holding
// that bool. Every other combination of source kind and destination type is
// handed back unchanged.
func BoolToBoolValueFunc() mapstructure.DecodeHookFunc {
	wanted := reflect.TypeOf(BoolValue{})
	return func(from, to reflect.Type, data interface{}) (interface{}, error) {
		if from.Kind() != reflect.Bool || to != wanted {
			return data, nil
		}
		value := data.(bool)
		return BoolValue{v: &value}, nil
	}
}
// DurationValue provides a flag value that's aware if it has been set.
type DurationValue struct {
v *time.Duration
}
// Merge will overlay this value if it has been set.
func (d *DurationValue) Merge(onto *time.Duration) {
if d.v != nil {
*onto = *(d.v)
}
}
// Set implements the flag.Value interface.
func (d *DurationValue) Set(v string) error {
if d.v == nil {
d.v = new(time.Duration)
}
var err error
*(d.v), err = time.ParseDuration(v)
return err
}
// String implements the flag.Value interface.
func (d *DurationValue) String() string {
var current time.Duration
if d.v != nil {
current = *(d.v)
}
return current.String()
}
// StringToDurationValueFunc builds a mapstructure hook that handles an
// incoming string whose destination type is DurationValue. The string is fed
// through DurationValue.Set; a failure there is returned as the hook's error.
// Every other combination of source kind and destination type is handed back
// unchanged.
func StringToDurationValueFunc() mapstructure.DecodeHookFunc {
	wanted := reflect.TypeOf(DurationValue{})
	return func(from, to reflect.Type, data interface{}) (interface{}, error) {
		if from.Kind() != reflect.String || to != wanted {
			return data, nil
		}
		var parsed DurationValue
		err := parsed.Set(data.(string))
		if err != nil {
			return nil, err
		}
		return parsed, nil
	}
}
// StringValue is a string flag value that remembers whether it has ever been
// assigned. The zero value is usable as-is and counts as unset.
type StringValue struct {
	// v stays nil until Set has been called.
	v *string
}

// Merge copies the stored string onto the target, but only when a value has
// been set; otherwise the target is left alone.
func (s *StringValue) Merge(onto *string) {
	if s.v == nil {
		return
	}
	*onto = *s.v
}

// Set implements the flag.Value interface. It always succeeds.
func (s *StringValue) Set(v string) error {
	if s.v != nil {
		// Reuse the existing storage so copies sharing it see the update.
		*s.v = v
		return nil
	}
	s.v = &v
	return nil
}

// String implements the flag.Value interface. An unset value renders as the
// empty string.
func (s *StringValue) String() string {
	if s.v == nil {
		return ""
	}
	return *s.v
}
// StringToStringValueFunc builds a mapstructure hook that handles an incoming
// string whose destination type is StringValue, producing a StringValue
// holding that string. Every other combination of source kind and
// destination type is handed back unchanged.
func StringToStringValueFunc() mapstructure.DecodeHookFunc {
	wanted := reflect.TypeOf(StringValue{})
	return func(from, to reflect.Type, data interface{}) (interface{}, error) {
		if from.Kind() != reflect.String || to != wanted {
			return data, nil
		}
		text := data.(string)
		return StringValue{v: &text}, nil
	}
}
// UintValue provides a flag value that's aware if it has been set. The zero
// value is usable as-is and reports itself as unset.
type UintValue struct {
	// v is nil until a value has been stored successfully.
	v *uint
}

// Merge will overlay this value onto the target if it has been set. An unset
// value leaves the target untouched.
func (u *UintValue) Merge(onto *uint) {
	if u.v != nil {
		*onto = *(u.v)
	}
}

// Set implements the flag.Value interface. The input is parsed with
// strconv.ParseUint using base 0, so the base is taken from the string's
// prefix. The result must fit in the platform's uint (strconv.IntSize bits),
// which turns an oversized input into a range error instead of a silently
// truncated value. On any parse failure the error is returned and the
// receiver is left exactly as it was.
func (u *UintValue) Set(v string) error {
	parsed, err := strconv.ParseUint(v, 0, strconv.IntSize)
	if err != nil {
		return err
	}
	// Write through the existing pointer when there is one so that copies of
	// this UintValue sharing the pointer keep observing updates.
	if u.v == nil {
		u.v = new(uint)
	}
	*(u.v) = (uint)(parsed)
	return nil
}

// String implements the flag.Value interface. An unset value renders as "0".
func (u *UintValue) String() string {
	var current uint
	if u.v != nil {
		current = *(u.v)
	}
	return fmt.Sprintf("%v", current)
}
// Float64ToUintValueFunc builds a mapstructure hook that handles an incoming
// float64 whose destination type is UintValue. Negative inputs and inputs
// above 1<<32 - 1 are rejected with an error; anything else is converted to
// uint and returned as a UintValue. Every other combination of source kind
// and destination type is handed back unchanged.
func Float64ToUintValueFunc() mapstructure.DecodeHookFunc {
	wanted := reflect.TypeOf(UintValue{})
	return func(from, to reflect.Type, data interface{}) (interface{}, error) {
		if from.Kind() != reflect.Float64 || to != wanted {
			return data, nil
		}

		number := data.(float64)
		switch {
		case number < 0:
			return nil, fmt.Errorf("value cannot be negative")

		// The standard guarantees at least this, and this is fine for
		// values we expect to use in configs vs. being fancy with the
		// machine's size for uint.
		case number > (1<<32 - 1):
			return nil, fmt.Errorf("value is too large")
		}

		converted := (uint)(number)
		return UintValue{v: &converted}, nil
	}
}
// VisitFn is a callback that gets a chance to visit each file found during a
// traversal with Visit.
type VisitFn func(path string) error

// Visit will call the visitor function on the path if it's a file, or for each
// file in the path if it's a directory. Directories will not be recursed into,
// and files in the directory will be visited in order of their names.
//
// The traversal stops at the first failure. Every returned error wraps its
// cause with %w, so callers can inspect it with errors.Is / errors.As (for
// example to detect os.ErrNotExist or a sentinel returned by the visitor).
func Visit(path string, visitor VisitFn) error {
	f, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("error reading %q: %w", path, err)
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return fmt.Errorf("error checking %q: %w", path, err)
	}

	// A plain file is visited directly.
	if !fi.IsDir() {
		if err := visitor(path); err != nil {
			return fmt.Errorf("error in %q: %w", path, err)
		}
		return nil
	}

	contents, err := f.Readdir(-1)
	if err != nil {
		return fmt.Errorf("error listing %q: %w", path, err)
	}

	// Readdir makes no ordering promise, so sort for a deterministic walk.
	sort.Sort(dirEnts(contents))
	for _, entry := range contents {
		if entry.IsDir() {
			continue
		}

		fullPath := filepath.Join(path, entry.Name())
		if err := visitor(fullPath); err != nil {
			return fmt.Errorf("error in %q: %w", fullPath, err)
		}
	}
	return nil
}

// dirEnts applies sort.Interface to directory entries for sorting by name.
type dirEnts []os.FileInfo

func (d dirEnts) Len() int           { return len(d) }
func (d dirEnts) Less(i, j int) bool { return d[i].Name() < d[j].Name() }
func (d dirEnts) Swap(i, j int)      { d[i], d[j] = d[j], d[i] }