diff --git a/command/agent/config.go b/command/agent/config.go index 6a80ea46f4..4d038df892 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -217,6 +217,9 @@ func DefaultConfig() *Config { SerfWan: consul.DefaultWANSerfPort, Server: 8300, }, + DNSConfig: DNSConfig{ + MaxStale: 5 * time.Second, + }, Protocol: consul.ProtocolVersionMax, AEInterval: time.Minute, } @@ -276,6 +279,36 @@ func DecodeConfig(r io.Reader) (*Config, error) { return nil, err } + // Handle time conversions + if raw := result.DNSConfig.NodeTTLRaw; raw != "" { + dur, err := time.ParseDuration(raw) + if err != nil { + return nil, fmt.Errorf("NodeTTL invalid: %v", err) + } + result.DNSConfig.NodeTTL = dur + } + + if raw := result.DNSConfig.MaxStaleRaw; raw != "" { + dur, err := time.ParseDuration(raw) + if err != nil { + return nil, fmt.Errorf("MaxStale invalid: %v", err) + } + result.DNSConfig.MaxStale = dur + } + + if len(result.DNSConfig.ServiceTTLRaw) != 0 { + if result.DNSConfig.ServiceTTL == nil { + result.DNSConfig.ServiceTTL = make(map[string]time.Duration) + } + for service, raw := range result.DNSConfig.ServiceTTLRaw { + dur, err := time.ParseDuration(raw) + if err != nil { + return nil, fmt.Errorf("ServiceTTL %s invalid: %v", service, err) + } + result.DNSConfig.ServiceTTL[service] = dur + } + } + return &result, nil } @@ -486,6 +519,23 @@ func MergeConfig(a, b *Config) *Config { if b.RejoinAfterLeave { result.RejoinAfterLeave = true } + if b.DNSConfig.NodeTTL != 0 { + result.DNSConfig.NodeTTL = b.DNSConfig.NodeTTL + } + if len(b.DNSConfig.ServiceTTL) != 0 { + if result.DNSConfig.ServiceTTL == nil { + result.DNSConfig.ServiceTTL = make(map[string]time.Duration) + } + for service, dur := range b.DNSConfig.ServiceTTL { + result.DNSConfig.ServiceTTL[service] = dur + } + } + if b.DNSConfig.AllowStale { + result.DNSConfig.AllowStale = true + } + if b.DNSConfig.MaxStale != 0 { + result.DNSConfig.MaxStale = b.DNSConfig.MaxStale + } // Copy the start join addresses result.StartJoin = make([]string, 0, len(a.StartJoin)+len(b.StartJoin)) diff --git a/command/agent/config_test.go b/command/agent/config_test.go index 288e6fd196..b4ab10bbcd 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -290,6 +290,40 @@ func TestDecodeConfig(t *testing.T) { if !config.RejoinAfterLeave { t.Fatalf("bad: %#v", config) } + + // DNS node ttl, max stale + input = `{"dns_config": {"node_ttl": "5s", "max_stale": "15s", "allow_stale": true}}` + config, err = DecodeConfig(bytes.NewReader([]byte(input))) + if err != nil { + t.Fatalf("err: %s", err) + } + + if config.DNSConfig.NodeTTL != 5*time.Second { + t.Fatalf("bad: %#v", config) + } + if config.DNSConfig.MaxStale != 15*time.Second { + t.Fatalf("bad: %#v", config) + } + if !config.DNSConfig.AllowStale { + t.Fatalf("bad: %#v", config) + } + + // DNS service ttl + input = `{"dns_config": {"service_ttl": {"*": "1s", "api": "10s", "web": "30s"}}}` + config, err = DecodeConfig(bytes.NewReader([]byte(input))) + if err != nil { + t.Fatalf("err: %s", err) + } + + if config.DNSConfig.ServiceTTL["*"] != time.Second { + t.Fatalf("bad: %#v", config) + } + if config.DNSConfig.ServiceTTL["api"] != 10*time.Second { + t.Fatalf("bad: %#v", config) + } + if config.DNSConfig.ServiceTTL["web"] != 30*time.Second { + t.Fatalf("bad: %#v", config) + } } func TestDecodeConfig_Service(t *testing.T) {