Merge branch 'master' into vault-ca-renew-token

pull/8560/head
Kyle Havlovitz 2020-09-15 14:39:04 -07:00
commit b1b21139ca
761 changed files with 32016 additions and 9497 deletions

.changelog/8458.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:improvement
connect: Add support for http2 and grpc to ingress gateways
```
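
As a rough illustration of the new listener protocols through the Go `api` package; the gateway name, port, and service below are invented for the example:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	// Hypothetical ingress-gateway config entry: with this release a
	// listener's protocol may be "http2" or "grpc" in addition to tcp/http.
	entry := &api.IngressGatewayConfigEntry{
		Kind: api.IngressGateway,
		Name: "ingress-gateway",
		Listeners: []api.IngressListener{{
			Port:     8443,
			Protocol: "grpc",
			Services: []api.IngressService{{Name: "billing-grpc"}},
		}},
	}
	if _, _, err := client.ConfigEntries().Set(entry, nil); err != nil {
		panic(err)
	}
}
```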

.changelog/8537.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
api: Fixed a panic caused by an api request with Connect=null
```

.changelog/8545.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:feature
agent: expose the list of supported envoy versions on /v1/agent/self
```
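
A small Go sketch of reading the new field; the `xDS` and `SupportedProxies` keys are taken from the `agent_endpoint.go` hunk later in this diff, while the client setup is assumed:

```go
package main

import (
	"fmt"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	self, err := client.Agent().Self()
	if err != nil {
		panic(err)
	}
	// The new section is serialized under "xDS" and is omitted entirely
	// when the agent's gRPC (xDS) port is disabled.
	if xds, ok := self["xDS"]; ok {
		fmt.Println("supported proxies:", xds["SupportedProxies"])
	}
}
```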

.changelog/8547.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
agent: ensure that we normalize bootstrapped config entries
```

.changelog/8552.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:feature
cache: Config parameters for cache throttling are now reloaded automatically on agent reload. Restarting the agent is no longer needed.
```

.changelog/8569.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:feature
xds: use envoy's rbac filter to handle intentions entirely within envoy
```
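
For context, a minimal Go sketch of creating a deny intention, which this change compiles into an Envoy RBAC filter instead of enforcing per-connection in the agent; the service names are placeholders, and `IntentionCreate` is the api method current at the time of this commit:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	ixn := &api.Intention{
		SourceName:      "web",
		DestinationName: "db",
		Action:          api.IntentionActionDeny,
	}
	if _, _, err := client.Connect().IntentionCreate(ixn, nil); err != nil {
		panic(err)
	}
}
```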

.changelog/8575.txt (new file, +11)

@ -0,0 +1,11 @@
```release-note:improvement
api: Added constants for common tag keys and values in the `Tags` field of the `AgentMember` struct.
```
```release-note:improvement
api: Added `IsConsulServer` method to the `AgentMember` type to easily determine whether the agent is a server.
```
```release-note:improvement
api: Added `ACLMode` method to the `AgentMember` type to determine what ACL mode the agent is operating in.
```
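
A short Go sketch exercising the new helpers; it assumes a reachable local agent:

```go
package main

import (
	"fmt"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	members, err := client.Agent().Members(false) // LAN members
	if err != nil {
		panic(err)
	}
	for _, m := range members {
		// IsConsulServer and ACLMode are the helpers added by GH-8575;
		// they interpret the well-known keys in m.Tags.
		fmt.Printf("%s server=%v acl=%v\n", m.Name, m.IsConsulServer(), m.ACLMode())
	}
}
```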

.changelog/8585.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:improvement
connect: add support for specifying load balancing policy in service-resolver
```
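
A hedged Go sketch of what the new policy might look like via the `api` package; the `LoadBalancer`/`RingHashConfig` field names and the `ring_hash` value are assumptions based on this note, not confirmed by this diff:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	entry := &api.ServiceResolverConfigEntry{
		Kind: api.ServiceResolver,
		Name: "web",
		// Assumed shape of the new load balancing policy configuration.
		LoadBalancer: &api.LoadBalancer{
			Policy:         "ring_hash",
			RingHashConfig: &api.RingHashConfig{MinimumRingSize: 64},
		},
	}
	if _, _, err := client.ConfigEntries().Set(entry, nil); err != nil {
		panic(err)
	}
}
```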

.changelog/8588.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
connect: fix renewing secondary intermediate certificates
```

.changelog/8596.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:feature
connect: all config entries pick up a meta field
```
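
A minimal Go sketch of the new field, assuming the `Meta` map is exposed on the `api` config entry types as this note suggests:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	entry := &api.ServiceConfigEntry{
		Kind:     api.ServiceDefaults,
		Name:     "web",
		Protocol: "http",
		// Meta is the new user-defined metadata field on config entries.
		Meta: map[string]string{"owner": "team-frontend"},
	}
	if _, _, err := client.ConfigEntries().Set(entry, nil); err != nil {
		panic(err)
	}
}
```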

.changelog/8601.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
connect: fix bug in preventing some namespaced config entry modifications
```

.changelog/8602.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:improvement
api: Allow for the client to use TLS over a Unix domain socket.
```
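
A minimal sketch of the combination this enables; the socket and certificate paths are placeholders:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	cfg := api.DefaultConfig()
	// Placeholder paths: a unix domain socket address can now be combined
	// with the https scheme and TLS settings instead of plaintext only.
	cfg.Address = "unix:///var/run/consul/consul_https.sock"
	cfg.Scheme = "https"
	cfg.TLSConfig = api.TLSConfig{CAFile: "/etc/consul.d/ca.pem"}

	client, err := api.NewClient(cfg)
	if err != nil {
		panic(err)
	}
	_ = client
}
```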

.changelog/8603.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:feature
telemetry: track node and service counts and emit them as metrics
```
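
A Go sketch of reading the new gauges from `/v1/agent/metrics`; the `state.nodes`/`state.services` metric names are an assumption, not confirmed by this diff:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	info, err := client.Agent().Metrics()
	if err != nil {
		panic(err)
	}
	for _, g := range info.Gauges {
		// Assumed gauge names for the new node and service counts.
		if strings.Contains(g.Name, "state.nodes") || strings.Contains(g.Name, "state.services") {
			fmt.Printf("%s = %v\n", g.Name, g.Value)
		}
	}
}
```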

.changelog/8606.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
connect: `connect envoy` command now respects the `-ca-path` flag
```

.changelog/8646.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:bug
connect: fix Vault provider not respecting IntermediateCertTTL
```
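
Since this branch concerns the Vault CA, a hedged Go sketch of setting the provider config whose `IntermediateCertTTL` is now respected; all values are placeholders:

```go
package main

import "github.com/hashicorp/consul/api"

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		panic(err)
	}
	// Placeholder Vault settings; the fix means IntermediateCertTTL is now
	// honored when the Vault provider issues the intermediate certificate.
	conf := &api.CAConfig{
		Provider: "vault",
		Config: map[string]interface{}{
			"Address":             "https://vault.service.consul:8200",
			"Token":               "s.example",
			"RootPKIPath":         "connect_root",
			"IntermediatePKIPath": "connect_inter",
			"IntermediateCertTTL": "8760h",
		},
	}
	if _, err := client.Connect().CASetConfig(conf, nil); err != nil {
		panic(err)
	}
}
```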

.changelog/_8621.txt (new file, +3)

@ -0,0 +1,3 @@
```release-note:improvement
snapshot agent: Deregister critical snapshotting TTL check if leadership is transferred.
```

.circleci/config.yml

@ -19,7 +19,7 @@ references:
EMAIL: noreply@hashicorp.com
GIT_AUTHOR_NAME: circleci-consul
GIT_COMMITTER_NAME: circleci-consul
S3_ARTIFACT_BUCKET: consul-dev-artifacts
S3_ARTIFACT_BUCKET: consul-dev-artifacts-v2
BASH_ENV: .circleci/bash_env.sh
VAULT_BINARY_VERSION: 1.2.2
@ -33,6 +33,27 @@ steps:
curl -sSL "${url}/v${GOTESTSUM_RELEASE}/gotestsum_${GOTESTSUM_RELEASE}_linux_amd64.tar.gz" | \
sudo tar -xz --overwrite -C /usr/local/bin gotestsum
get-aws-cli: &get-aws-cli
run:
name: download and install AWS CLI
command: |
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
echo -e "${AWS_CLI_GPG_KEY}" | gpg --import
curl -o awscliv2.sig https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip.sig
gpg --verify awscliv2.sig awscliv2.zip
unzip awscliv2.zip
sudo ./aws/install
aws-assume-role: &aws-assume-role
run:
name: assume-role aws creds
command: |
# assume role has duration of 15 min (the minimum allowed)
CREDENTIALS="$(aws sts assume-role --duration-seconds 900 --role-arn ${ROLE_ARN} --role-session-name build-${CIRCLE_SHA1} | jq '.Credentials')"
echo "export AWS_ACCESS_KEY_ID=$(echo $CREDENTIALS | jq -r '.AccessKeyId')" >> $BASH_ENV
echo "export AWS_SECRET_ACCESS_KEY=$(echo $CREDENTIALS | jq -r '.SecretAccessKey')" >> $BASH_ENV
echo "export AWS_SESSION_TOKEN=$(echo $CREDENTIALS | jq -r '.SessionToken')" >> $BASH_ENV
# This step MUST be at the end of any set of steps due to the 'when' condition
notify-slack-failure: &notify-slack-failure
name: notify-slack-failure
@ -389,13 +410,13 @@ jobs:
# upload development build to s3
dev-upload-s3:
docker:
- image: circleci/python:stretch
- image: *GOLANG_IMAGE
environment:
<<: *ENVIRONMENT
steps:
- run:
name: Install awscli
command: sudo pip install awscli
- checkout
- *get-aws-cli
- *aws-assume-role
# get consul binary
- attach_workspace:
at: bin/


@ -124,7 +124,7 @@ The underlying script dumps the full Consul log output to `test.log` in
the directory of the target package. In the example above it would be
located at `consul/connect/proxy/test.log`.
Historically, the defaults for `FLAKE_CPUS` (30) and `FLAKE_N` (0.15) have been
Historically, the defaults for `FLAKE_CPUS` (0.15) and `FLAKE_N` (30) have been
sufficient to surface a flaky test. If a test is run in this environment and
it does not fail after 30 iterations, it should be sufficiently stable.

CHANGELOG.md

@ -1,5 +1,33 @@
## UNRELEASED
## 1.8.4 (September 11, 2020)
FEATURES:
* agent: expose the list of supported envoy versions on /v1/agent/self [[GH-8545](https://github.com/hashicorp/consul/issues/8545)]
* cache: Config parameters for cache throttling are now reloaded automatically on agent reload. Restarting the agent is no longer needed. [[GH-8552](https://github.com/hashicorp/consul/issues/8552)]
* connect: all config entries pick up a meta field [[GH-8596](https://github.com/hashicorp/consul/issues/8596)]
IMPROVEMENTS:
* api: Added `ACLMode` method to the `AgentMember` type to determine what ACL mode the agent is operating in. [[GH-8575](https://github.com/hashicorp/consul/issues/8575)]
* api: Added `IsConsulServer` method to the `AgentMember` type to easily determine whether the agent is a server. [[GH-8575](https://github.com/hashicorp/consul/issues/8575)]
* api: Added constants for common tag keys and values in the `Tags` field of the `AgentMember` struct. [[GH-8575](https://github.com/hashicorp/consul/issues/8575)]
* api: Allow for the client to use TLS over a Unix domain socket. [[GH-8602](https://github.com/hashicorp/consul/issues/8602)]
* api: `GET /v1/operator/keyring` also lists primary keys. [[GH-8522](https://github.com/hashicorp/consul/issues/8522)]
* connect: Add support for http2 and grpc to ingress gateways [[GH-8458](https://github.com/hashicorp/consul/issues/8458)]
* serf: update to `v0.9.4` which supports primary keys in the ListKeys operation. [[GH-8522](https://github.com/hashicorp/consul/issues/8522)]
BUG FIXES:
* connect: use stronger validation that ingress gateways have compatible protocols defined for their upstreams [[GH-8494](https://github.com/hashicorp/consul/issues/8494)]
* agent: ensure that we normalize bootstrapped config entries [[GH-8547](https://github.com/hashicorp/consul/issues/8547)]
* api: Fixed a panic caused by an api request with Connect=null [[GH-8537](https://github.com/hashicorp/consul/issues/8537)]
* connect: `connect envoy` command now respects the `-ca-path` flag [[GH-8606](https://github.com/hashicorp/consul/issues/8606)]
* connect: fix bug in preventing some namespaced config entry modifications [[GH-8601](https://github.com/hashicorp/consul/issues/8601)]
* connect: fix renewing secondary intermediate certificates [[GH-8588](https://github.com/hashicorp/consul/issues/8588)]
* ui: fixed a bug related to in-folder KV creation [[GH-8613](https://github.com/hashicorp/consul/issues/8613)] [[GH-8613](https://github.com/hashicorp/consul/pull/8613)]
## 1.8.3 (August 12, 2020)
BUG FIXES:
@ -116,6 +144,17 @@ BUGFIXES:
* ui: Miscellaneous amends for Safari and Firefox [[GH-7904](https://github.com/hashicorp/consul/issues/7904)] [[GH-7907](https://github.com/hashicorp/consul/pull/7907)]
* ui: Ensure a value is always passed to CONSUL_SSO_ENABLED [[GH-7913](https://github.com/hashicorp/consul/pull/7913)]
## 1.7.8 (September 11, 2020)
FEATURES:
* agent: expose the list of supported envoy versions on /v1/agent/self [[GH-8545](https://github.com/hashicorp/consul/issues/8545)]
BUG FIXES:
* connect: fix bug in preventing some namespaced config entry modifications [[GH-8601](https://github.com/hashicorp/consul/issues/8601)]
* api: fixed a panic caused by an api request with Connect=null [[GH-8537](https://github.com/hashicorp/consul/pull/8537)]
## 1.7.7 (August 12, 2020)
BUG FIXES:
@ -127,7 +166,7 @@ BUGFIXES:
BUG FIXES:
* [backport/1.7.x] xds: revert setting set_node_on_first_message_only to true when generating envoy bootstrap config [[GH-8441](https://github.com/hashicorp/consul/issues/8441)]
* xds: revert setting set_node_on_first_message_only to true when generating envoy bootstrap config [[GH-8441](https://github.com/hashicorp/consul/issues/8441)]
## 1.7.5 (July 30, 2020)
@ -340,6 +379,12 @@ BUGFIXES:
* ui: Discovery-Chain: Improve parsing of redirects [[GH-7174](https://github.com/hashicorp/consul/pull/7174)]
* ui: Fix styling of duplicate intention error message [[GH-6936](https://github.com/hashicorp/consul/pull/6936)]
## 1.6.9 (September 11, 2020)
BUG FIXES:
* api: fixed a panic caused by an api request with Connect=null [[GH-8537](https://github.com/hashicorp/consul/pull/8537)]
## 1.6.8 (August 12, 2020)
BUG FIXES:

README.md

@ -1,7 +1,7 @@
# Consul [![CircleCI](https://circleci.com/gh/hashicorp/consul/tree/master.svg?style=svg)](https://circleci.com/gh/hashicorp/consul/tree/master) [![Discuss](https://img.shields.io/badge/discuss-consul-ca2171.svg?style=flat)](https://discuss.hashicorp.com/c/consul)
* Website: https://www.consul.io
* Tutorials: [https://learn.hashicorp.com](https://learn.hashicorp.com/consul)
* Tutorials: [HashiCorp Learn](https://learn.hashicorp.com/consul)
* Forum: [Discuss](https://discuss.hashicorp.com/c/consul)
Consul is a distributed, highly available, and data center aware solution to connect and configure applications across dynamic, distributed infrastructure.
@ -10,12 +10,12 @@ Consul provides several key features:
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can
support any number of regions without complex configuration.
* **Service Mesh/Service Segmentation** - Consul Connect enables secure service-to-service
communication with automatic TLS encryption and identity-based authorization. Applications
can use sidecar proxies in a service mesh configuration to establish TLS
connections for inbound and outbound connections without being aware of Connect at all.
* **Service Discovery** - Consul makes it simple for services to register
themselves and to discover other services via a DNS or HTTP interface.
External services such as SaaS providers can be registered as well.
@ -41,9 +41,10 @@ contacting us at security@hashicorp.com.
A few quick start guides are available on the Consul website:
* **Standalone binary install:** https://learn.hashicorp.com/consul/getting-started/install
* **Minikube install:** https://learn.hashicorp.com/consul/kubernetes/minikube
* **Kubernetes install:** https://learn.hashicorp.com/consul/kubernetes/kubernetes-deployment-guide
* **Standalone binary install:** https://learn.hashicorp.com/tutorials/consul/get-started-install
* **Minikube install:** https://learn.hashicorp.com/tutorials/consul/kubernetes-minikube
* **Kind install:** https://learn.hashicorp.com/tutorials/consul/kubernetes-kind
* **Kubernetes install:** https://learn.hashicorp.com/tutorials/consul/kubernetes-deployment-guide
## Documentation

agent/acl_test.go

@ -184,7 +184,9 @@ func TestACL_AgentMasterToken(t *testing.T) {
t.Parallel()
a := NewTestACLAgent(t, t.Name(), TestACLConfig(), nil, nil)
a.loadTokens(a.config)
err := a.tokens.Load(a.config.ACLTokens, a.logger)
require.NoError(t, err)
authz, err := a.resolveToken("towel")
require.NotNil(t, authz)
require.Nil(t, err)

agent/agent.go

@ -18,6 +18,8 @@ import (
"time"
"github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/go-connlimit"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
@ -30,7 +32,6 @@ import (
autoconf "github.com/hashicorp/consul/agent/auto-config"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
certmon "github.com/hashicorp/consul/agent/cert-monitor"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/consul"
@ -39,7 +40,6 @@ import (
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/systemd"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/agent/xds"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/api/watch"
@ -67,9 +67,6 @@ const (
checksDir = "checks"
checkStateDir = "checks/state"
// Name of the file tokens will be persisted within
tokensPath = "acl-tokens.json"
// Default reasons for node/service maintenance mode
defaultNodeMaintReason = "Maintenance mode is enabled for this node, " +
"but no reason was provided. This is a default message."
@ -161,8 +158,6 @@ type notifier interface {
type Agent struct {
autoConf *autoconf.AutoConfig
certMonitor *certmon.CertMonitor
// config is the agent configuration.
config *config.RuntimeConfig
@ -261,10 +256,12 @@ type Agent struct {
// dnsServer provides the DNS API
dnsServers []*DNSServer
// httpServers provides the HTTP API on various endpoints
httpServers []*HTTPServer
// apiServers listening for connections. If any of these server goroutines
// fail, the agent will be shut down.
apiServers *apiServers
// wgServers is the wait group for all HTTP and DNS servers
// TODO: remove once dnsServers are handled by apiServers
wgServers sync.WaitGroup
// watchPlans tracks all the currently-running watch plans for the
@ -294,11 +291,6 @@ type Agent struct {
// based on the current consul configuration.
tlsConfigurator *tlsutil.Configurator
// persistedTokensLock is used to synchronize access to the persisted token
// store within the data directory. This will prevent loading while writing as
// well as multiple concurrent writes.
persistedTokensLock sync.RWMutex
// httpConnLimiter is used to limit connections to the HTTP server by client
// IP.
httpConnLimiter connlimit.Limiter
@ -306,6 +298,9 @@ type Agent struct {
// Connection Pool
connPool *pool.ConnPool
// Shared RPC Router
router *router.Router
// enterpriseAgent embeds fields that we only access in consul-enterprise builds
enterpriseAgent
}
@ -351,6 +346,7 @@ func New(bd BaseDeps) (*Agent, error) {
MemSink: bd.MetricsHandler,
connPool: bd.ConnPool,
autoConf: bd.AutoConfig,
router: bd.Router,
}
a.serviceManager = NewServiceManager(&a)
@ -368,6 +364,12 @@ func New(bd BaseDeps) (*Agent, error) {
// pass the agent itself so it's safe to move here.
a.registerCache()
// TODO: why do we ignore failure to load persisted tokens?
_ = a.tokens.Load(bd.RuntimeConfig.ACLTokens, a.logger)
// TODO: pass in a fully populated apiServers into Agent.New
a.apiServers = NewAPIServers(a.logger)
return &a, nil
}
@ -421,11 +423,6 @@ func (a *Agent) Start(ctx context.Context) error {
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
}
// TODO: move to newBaseDeps
// TODO: handle error
a.loadTokens(a.config)
a.loadEnterpriseTokens(a.config)
// create the local state
a.State = local.NewState(LocalConfig(c), a.logger, a.tokens)
@ -462,6 +459,7 @@ func (a *Agent) Start(ctx context.Context) error {
consul.WithTokenStore(a.tokens),
consul.WithTLSConfigurator(a.tlsConfigurator),
consul.WithConnectionPool(a.connPool),
consul.WithRouter(a.router),
}
// Setup either the client or the server.
@ -489,43 +487,6 @@ func (a *Agent) Start(ctx context.Context) error {
a.State.Delegate = a.delegate
a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger
if a.config.AutoEncryptTLS && !a.config.ServerMode {
reply, err := a.autoEncryptInitialCertificate(ctx)
if err != nil {
return fmt.Errorf("AutoEncrypt failed: %s", err)
}
cmConfig := new(certmon.Config).
WithCache(a.cache).
WithLogger(a.logger.Named(logging.AutoEncrypt)).
WithTLSConfigurator(a.tlsConfigurator).
WithTokens(a.tokens).
WithFallback(a.autoEncryptInitialCertificate).
WithDNSSANs(a.config.AutoEncryptDNSSAN).
WithIPSANs(a.config.AutoEncryptIPSAN).
WithDatacenter(a.config.Datacenter).
WithNodeName(a.config.NodeName)
monitor, err := certmon.New(cmConfig)
if err != nil {
return fmt.Errorf("AutoEncrypt failed to setup certificate monitor: %w", err)
}
if err := monitor.Update(reply); err != nil {
return fmt.Errorf("AutoEncrypt failed to setup certificate monitor: %w", err)
}
a.certMonitor = monitor
// we don't need to worry about ever calling Stop as we have tied the go routines
// to the agent's lifetime by using the StopCh. Also the agent itself doesn't have
// a need to ensure that the go routine was stopped before performing any action
// so we can ignore the chan in the return.
if _, err := a.certMonitor.Start(&lib.StopChannelContext{StopCh: a.shutdownCh}); err != nil {
return fmt.Errorf("AutoEncrypt failed to start certificate monitor: %w", err)
}
a.logger.Info("automatically upgraded to TLS")
}
if err := a.autoConf.Start(&lib.StopChannelContext{StopCh: a.shutdownCh}); err != nil {
return fmt.Errorf("AutoConf failed to start certificate monitor: %w", err)
}
@ -542,6 +503,16 @@ func (a *Agent) Start(ctx context.Context) error {
return err
}
var intentionDefaultAllow bool
switch a.config.ACLDefaultPolicy {
case "allow":
intentionDefaultAllow = true
case "deny":
intentionDefaultAllow = false
default:
return fmt.Errorf("unexpected ACL default policy value of %q", a.config.ACLDefaultPolicy)
}
// Start the proxy config manager.
a.proxyConfig, err = proxycfg.NewManager(proxycfg.ManagerConfig{
Cache: a.cache,
@ -556,7 +527,8 @@ func (a *Agent) Start(ctx context.Context) error {
Domain: a.config.DNSDomain,
AltDomain: a.config.DNSAltDomain,
},
TLSConfigurator: a.tlsConfigurator,
TLSConfigurator: a.tlsConfigurator,
IntentionDefaultAllow: intentionDefaultAllow,
})
if err != nil {
return err
@ -603,10 +575,7 @@ func (a *Agent) Start(ctx context.Context) error {
// Start HTTP and HTTPS servers.
for _, srv := range servers {
if err := a.serveHTTP(srv); err != nil {
return err
}
a.httpServers = append(a.httpServers, srv)
a.apiServers.Start(srv)
}
// Start gRPC server.
@ -628,17 +597,10 @@ func (a *Agent) Start(ctx context.Context) error {
return nil
}
func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.SignedResponse, error) {
client := a.delegate.(*consul.Client)
addrs := a.config.StartJoinAddrsLAN
disco, err := newDiscover()
if err != nil && len(addrs) == 0 {
return nil, err
}
addrs = append(addrs, retryJoinAddrs(disco, retryJoinSerfVariant, "LAN", a.config.RetryJoinLAN, a.logger)...)
return client.RequestAutoEncryptCerts(ctx, addrs, a.config.ServerPort, a.tokens.AgentToken(), a.config.AutoEncryptDNSSAN, a.config.AutoEncryptIPSAN)
// Failed returns a channel which is closed when the first server goroutine exits
// with a non-nil error.
func (a *Agent) Failed() <-chan struct{} {
return a.apiServers.failed
}
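
// NOTE: a hypothetical caller-side sketch (not part of this diff) showing one
// way the new Failed() channel might be consumed; runUntilDone and shutdownCh
// are invented names for illustration.
func runUntilDone(a *Agent, shutdownCh <-chan struct{}) {
	select {
	case <-shutdownCh:
		// operator-requested shutdown
	case <-a.Failed():
		// an API server goroutine exited with an error; shut down rather
		// than continue running in a degraded state
	}
	a.ShutdownAgent()
}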
func (a *Agent) listenAndServeGRPC() error {
@ -649,7 +611,6 @@ func (a *Agent) listenAndServeGRPC() error {
xdsServer := &xds.Server{
Logger: a.logger,
CfgMgr: a.proxyConfig,
Authz: a,
ResolveToken: a.resolveToken,
CheckFetcher: a,
CfgFetcher: a,
@ -774,14 +735,16 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) {
//
// This approach should ultimately be refactored to the point where we just
// start the server and any error should trigger a proper shutdown of the agent.
func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
func (a *Agent) listenHTTP() ([]apiServer, error) {
var ln []net.Listener
var servers []*HTTPServer
var servers []apiServer
start := func(proto string, addrs []net.Addr) error {
listeners, err := a.startListeners(addrs)
if err != nil {
return err
}
ln = append(ln, listeners...)
for _, l := range listeners {
var tlscfg *tls.Config
@ -791,18 +754,15 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
l = tls.NewListener(l, tlscfg)
}
srv := &HTTPServer{
agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
}
httpServer := &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
Handler: srv.handler(a.config.EnableDebug),
}
srv := &HTTPServer{
Server: httpServer,
ln: l,
agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
proto: proto,
}
httpServer.Handler = srv.handler(a.config.EnableDebug)
// Load the connlimit helper into the server
connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond)
@ -815,27 +775,39 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
httpServer.ConnState = connLimitFn
}
ln = append(ln, l)
servers = append(servers, srv)
servers = append(servers, apiServer{
Protocol: proto,
Addr: l.Addr(),
Shutdown: httpServer.Shutdown,
Run: func() error {
err := httpServer.Serve(l)
if err == nil || err == http.ErrServerClosed {
return nil
}
return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err)
},
})
}
return nil
}
if err := start("http", a.config.HTTPAddrs); err != nil {
for _, l := range ln {
l.Close()
}
closeListeners(ln)
return nil, err
}
if err := start("https", a.config.HTTPSAddrs); err != nil {
for _, l := range ln {
l.Close()
}
closeListeners(ln)
return nil, err
}
return servers, nil
}
func closeListeners(lns []net.Listener) {
for _, l := range lns {
l.Close()
}
}
// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout
// to the http.Server.
func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error {
@ -897,43 +869,6 @@ func (a *Agent) listenSocket(path string) (net.Listener, error) {
return l, nil
}
func (a *Agent) serveHTTP(srv *HTTPServer) error {
// https://github.com/golang/go/issues/20239
//
// In Go 1.8.1 there is a race between Serve and Shutdown. If
// Shutdown is called before the Serve go routine was scheduled then
// the Serve go routine never returns. This deadlocks the agent
// shutdown for some tests since it will wait forever.
notif := make(chan net.Addr)
a.wgServers.Add(1)
go func() {
defer a.wgServers.Done()
notif <- srv.ln.Addr()
err := srv.Server.Serve(srv.ln)
if err != nil && err != http.ErrServerClosed {
a.logger.Error("error closing server", "error", err)
}
}()
select {
case addr := <-notif:
if srv.proto == "https" {
a.logger.Info("Started HTTPS server",
"address", addr.String(),
"network", addr.Network(),
)
} else {
a.logger.Info("Started HTTP server",
"address", addr.String(),
"network", addr.Network(),
)
}
return nil
case <-time.After(time.Second):
return fmt.Errorf("agent: timeout starting HTTP servers")
}
}
// stopAllWatches stops all the currently running watches
func (a *Agent) stopAllWatches() {
for _, wp := range a.watchPlans {
@ -1364,12 +1299,6 @@ func (a *Agent) ShutdownAgent() error {
// this should help them to be stopped more quickly
a.autoConf.Stop()
if a.certMonitor != nil {
// this would be cancelled anyways (by the closing of the shutdown ch)
// but this should help them to be stopped more quickly
a.certMonitor.Stop()
}
// Stop the service manager (must happen before we take the stateLock to avoid deadlock)
if a.serviceManager != nil {
a.serviceManager.Stop()
@ -1438,13 +1367,12 @@ func (a *Agent) ShutdownAgent() error {
// ShutdownEndpoints terminates the HTTP and DNS servers. Should be
// preceded by ShutdownAgent.
// TODO: remove this method, move to ShutdownAgent
func (a *Agent) ShutdownEndpoints() {
a.shutdownLock.Lock()
defer a.shutdownLock.Unlock()
if len(a.dnsServers) == 0 && len(a.httpServers) == 0 {
return
}
ctx := context.TODO()
for _, srv := range a.dnsServers {
if srv.Server != nil {
@ -1458,27 +1386,11 @@ func (a *Agent) ShutdownEndpoints() {
}
a.dnsServers = nil
for _, srv := range a.httpServers {
a.logger.Info("Stopping server",
"protocol", strings.ToUpper(srv.proto),
"address", srv.ln.Addr().String(),
"network", srv.ln.Addr().Network(),
)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
srv.Server.Shutdown(ctx)
if ctx.Err() == context.DeadlineExceeded {
a.logger.Warn("Timeout stopping server",
"protocol", strings.ToUpper(srv.proto),
"address", srv.ln.Addr().String(),
"network", srv.ln.Addr().Network(),
)
}
}
a.httpServers = nil
a.apiServers.Shutdown(ctx)
a.logger.Info("Waiting for endpoints to shut down")
a.wgServers.Wait()
if err := a.apiServers.WaitForShutdown(); err != nil {
a.logger.Error(err.Error())
}
a.logger.Info("Endpoints down")
}
@ -3430,90 +3342,6 @@ func (a *Agent) unloadChecks() error {
return nil
}
type persistedTokens struct {
Replication string `json:"replication,omitempty"`
AgentMaster string `json:"agent_master,omitempty"`
Default string `json:"default,omitempty"`
Agent string `json:"agent,omitempty"`
}
func (a *Agent) getPersistedTokens() (*persistedTokens, error) {
persistedTokens := &persistedTokens{}
if !a.config.ACLEnableTokenPersistence {
return persistedTokens, nil
}
a.persistedTokensLock.RLock()
defer a.persistedTokensLock.RUnlock()
tokensFullPath := filepath.Join(a.config.DataDir, tokensPath)
buf, err := ioutil.ReadFile(tokensFullPath)
if err != nil {
if os.IsNotExist(err) {
// non-existence is not an error we care about
return persistedTokens, nil
}
return persistedTokens, fmt.Errorf("failed reading tokens file %q: %s", tokensFullPath, err)
}
if err := json.Unmarshal(buf, persistedTokens); err != nil {
return persistedTokens, fmt.Errorf("failed to decode tokens file %q: %s", tokensFullPath, err)
}
return persistedTokens, nil
}
func (a *Agent) loadTokens(conf *config.RuntimeConfig) error {
persistedTokens, persistenceErr := a.getPersistedTokens()
if persistenceErr != nil {
a.logger.Warn("unable to load persisted tokens", "error", persistenceErr)
}
if persistedTokens.Default != "" {
a.tokens.UpdateUserToken(persistedTokens.Default, token.TokenSourceAPI)
if conf.ACLToken != "" {
a.logger.Warn("\"default\" token present in both the configuration and persisted token store, using the persisted token")
}
} else {
a.tokens.UpdateUserToken(conf.ACLToken, token.TokenSourceConfig)
}
if persistedTokens.Agent != "" {
a.tokens.UpdateAgentToken(persistedTokens.Agent, token.TokenSourceAPI)
if conf.ACLAgentToken != "" {
a.logger.Warn("\"agent\" token present in both the configuration and persisted token store, using the persisted token")
}
} else {
a.tokens.UpdateAgentToken(conf.ACLAgentToken, token.TokenSourceConfig)
}
if persistedTokens.AgentMaster != "" {
a.tokens.UpdateAgentMasterToken(persistedTokens.AgentMaster, token.TokenSourceAPI)
if conf.ACLAgentMasterToken != "" {
a.logger.Warn("\"agent_master\" token present in both the configuration and persisted token store, using the persisted token")
}
} else {
a.tokens.UpdateAgentMasterToken(conf.ACLAgentMasterToken, token.TokenSourceConfig)
}
if persistedTokens.Replication != "" {
a.tokens.UpdateReplicationToken(persistedTokens.Replication, token.TokenSourceAPI)
if conf.ACLReplicationToken != "" {
a.logger.Warn("\"replication\" token present in both the configuration and persisted token store, using the persisted token")
}
} else {
a.tokens.UpdateReplicationToken(conf.ACLReplicationToken, token.TokenSourceConfig)
}
return persistenceErr
}
// snapshotCheckState is used to snapshot the current state of the health
// checks. This is done before we reload our checks, so that we can properly
// restore into the same state.
@ -3693,8 +3521,7 @@ func (a *Agent) reloadConfigInternal(newCfg *config.RuntimeConfig) error {
// Reload tokens - should be done before all the other loading
// to ensure the correct tokens are available for attaching to
// the checks and service registrations.
a.loadTokens(newCfg)
a.loadEnterpriseTokens(newCfg)
a.tokens.Load(newCfg.ACLTokens, a.logger)
if err := a.tlsConfigurator.Update(newCfg.ToTLSUtilConfig()); err != nil {
return fmt.Errorf("Failed reloading tls configuration: %s", err)
@ -3748,6 +3575,12 @@ func (a *Agent) reloadConfigInternal(newCfg *config.RuntimeConfig) error {
return err
}
if a.cache.ReloadOptions(newCfg.Cache) {
a.logger.Info("Cache options have been updated")
} else {
a.logger.Debug("Cache options have not been modified")
}
// Update filtered metrics
metrics.UpdateFilter(newCfg.Telemetry.AllowedPrefixes,
newCfg.Telemetry.BlockedPrefixes)

agent/agent_endpoint.go

@ -1,10 +1,8 @@
package agent
import (
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"strconv"
"strings"
@ -17,10 +15,10 @@ import (
"github.com/hashicorp/consul/agent/debug"
"github.com/hashicorp/consul/agent/structs"
token_store "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/agent/xds/proxysupport"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/lib/file"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/logging/monitor"
"github.com/hashicorp/consul/types"
@ -38,6 +36,11 @@ type Self struct {
Member serf.Member
Stats map[string]map[string]string
Meta map[string]string
XDS *xdsSelf `json:"xDS,omitempty"`
}
type xdsSelf struct {
SupportedProxies map[string][]string
}
func (s *HTTPServer) AgentSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
@ -60,6 +63,15 @@ func (s *HTTPServer) AgentSelf(resp http.ResponseWriter, req *http.Request) (int
}
}
var xds *xdsSelf
if s.agent.grpcServer != nil {
xds = &xdsSelf{
SupportedProxies: map[string][]string{
"envoy": proxysupport.EnvoyVersions,
},
}
}
config := struct {
Datacenter string
NodeName string
@ -82,6 +94,7 @@ func (s *HTTPServer) AgentSelf(resp http.ResponseWriter, req *http.Request) (int
Member: s.agent.LocalMember(),
Stats: s.agent.Stats(),
Meta: s.agent.State.Metadata(),
XDS: xds,
}, nil
}
@ -1217,79 +1230,42 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
return nil, nil
}
if s.agent.config.ACLEnableTokenPersistence {
// we hold the lock around updating the internal token store
// as well as persisting the tokens because we don't want to write
// into the store to have something else wipe it out before we can
// persist everything (like an agent config reload). The token store
// lock is only held for those operations so other go routines that
// just need to read some token out of the store will not be impacted
// any more than they would be without token persistence.
s.agent.persistedTokensLock.Lock()
defer s.agent.persistedTokensLock.Unlock()
}
// Figure out the target token.
target := strings.TrimPrefix(req.URL.Path, "/v1/agent/token/")
triggerAntiEntropySync := false
switch target {
case "acl_token", "default":
changed := s.agent.tokens.UpdateUserToken(args.Token, token_store.TokenSourceAPI)
if changed {
triggerAntiEntropySync = true
err = s.agent.tokens.WithPersistenceLock(func() error {
triggerAntiEntropySync := false
switch target {
case "acl_token", "default":
changed := s.agent.tokens.UpdateUserToken(args.Token, token_store.TokenSourceAPI)
if changed {
triggerAntiEntropySync = true
}
case "acl_agent_token", "agent":
changed := s.agent.tokens.UpdateAgentToken(args.Token, token_store.TokenSourceAPI)
if changed {
triggerAntiEntropySync = true
}
case "acl_agent_master_token", "agent_master":
s.agent.tokens.UpdateAgentMasterToken(args.Token, token_store.TokenSourceAPI)
case "acl_replication_token", "replication":
s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI)
default:
return NotFoundError{Reason: fmt.Sprintf("Token %q is unknown", target)}
}
case "acl_agent_token", "agent":
changed := s.agent.tokens.UpdateAgentToken(args.Token, token_store.TokenSourceAPI)
if changed {
triggerAntiEntropySync = true
}
case "acl_agent_master_token", "agent_master":
s.agent.tokens.UpdateAgentMasterToken(args.Token, token_store.TokenSourceAPI)
case "acl_replication_token", "replication":
s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI)
default:
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "Token %q is unknown", target)
return nil, nil
}
if triggerAntiEntropySync {
s.agent.sync.SyncFull.Trigger()
}
if s.agent.config.ACLEnableTokenPersistence {
tokens := persistedTokens{}
if tok, source := s.agent.tokens.UserTokenAndSource(); tok != "" && source == token_store.TokenSourceAPI {
tokens.Default = tok
}
if tok, source := s.agent.tokens.AgentTokenAndSource(); tok != "" && source == token_store.TokenSourceAPI {
tokens.Agent = tok
}
if tok, source := s.agent.tokens.AgentMasterTokenAndSource(); tok != "" && source == token_store.TokenSourceAPI {
tokens.AgentMaster = tok
}
if tok, source := s.agent.tokens.ReplicationTokenAndSource(); tok != "" && source == token_store.TokenSourceAPI {
tokens.Replication = tok
}
data, err := json.Marshal(tokens)
if err != nil {
s.agent.logger.Warn("failed to persist tokens", "error", err)
return nil, fmt.Errorf("Failed to marshal tokens for persistence: %v", err)
}
if err := file.WriteAtomicWithPerms(filepath.Join(s.agent.config.DataDir, tokensPath), data, 0700, 0600); err != nil {
s.agent.logger.Warn("failed to persist tokens", "error", err)
return nil, fmt.Errorf("Failed to persist tokens - %v", err)
// TODO: is it safe to move this out of WithPersistenceLock?
if triggerAntiEntropySync {
s.agent.sync.SyncFull.Trigger()
}
return nil
})
if err != nil {
return nil, err
}
s.agent.logger.Info("Updated agent's ACL token", "token", target)

agent/agent_endpoint_test.go

@ -13,7 +13,6 @@ import (
"net/http/httptest"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"testing"
@ -27,6 +26,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/agent/xds/proxysupport"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil"
@ -1209,39 +1209,65 @@ func TestAgent_Checks_ACLFilter(t *testing.T) {
func TestAgent_Self(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, `
node_meta {
somekey = "somevalue"
}
`)
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
req, _ := http.NewRequest("GET", "/v1/agent/self", nil)
obj, err := a.srv.AgentSelf(nil, req)
if err != nil {
t.Fatalf("err: %v", err)
cases := map[string]struct {
hcl string
expectXDS bool
}{
"normal": {
hcl: `
node_meta {
somekey = "somevalue"
}
`,
expectXDS: true,
},
"no grpc": {
hcl: `
node_meta {
somekey = "somevalue"
}
ports = {
grpc = -1
}
`,
expectXDS: false,
},
}
val := obj.(Self)
if int(val.Member.Port) != a.Config.SerfPortLAN {
t.Fatalf("incorrect port: %v", obj)
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, tc.hcl)
defer a.Shutdown()
if val.DebugConfig["SerfPortLAN"].(int) != a.Config.SerfPortLAN {
t.Fatalf("incorrect port: %v", obj)
}
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
req, _ := http.NewRequest("GET", "/v1/agent/self", nil)
obj, err := a.srv.AgentSelf(nil, req)
require.NoError(t, err)
cs, err := a.GetLANCoordinate()
if err != nil {
t.Fatalf("err: %v", err)
}
if c := cs[a.config.SegmentName]; !reflect.DeepEqual(c, val.Coord) {
t.Fatalf("coordinates are not equal: %v != %v", c, val.Coord)
}
delete(val.Meta, structs.MetaSegmentKey) // Added later, not in config.
if !reflect.DeepEqual(a.config.NodeMeta, val.Meta) {
t.Fatalf("meta fields are not equal: %v != %v", a.config.NodeMeta, val.Meta)
val := obj.(Self)
require.Equal(t, a.Config.SerfPortLAN, int(val.Member.Port))
require.Equal(t, a.Config.SerfPortLAN, val.DebugConfig["SerfPortLAN"].(int))
cs, err := a.GetLANCoordinate()
require.NoError(t, err)
require.Equal(t, cs[a.config.SegmentName], val.Coord)
delete(val.Meta, structs.MetaSegmentKey) // Added later, not in config.
require.Equal(t, a.config.NodeMeta, val.Meta)
if tc.expectXDS {
require.NotNil(t, val.XDS, "xds component missing when gRPC is enabled")
require.Equal(t,
map[string][]string{"envoy": proxysupport.EnvoyVersions},
val.XDS.SupportedProxies,
)
} else {
require.Nil(t, val.XDS, "xds component should be missing when gRPC is disabled")
}
})
}
}
@ -4748,13 +4774,14 @@ func TestAgent_Token(t *testing.T) {
init tokens
raw tokens
effective tokens
expectedErr error
}{
{
name: "bad token name",
method: "PUT",
url: "nope?token=root",
body: body("X"),
code: http.StatusNotFound,
name: "bad token name",
method: "PUT",
url: "nope?token=root",
body: body("X"),
expectedErr: NotFoundError{Reason: `Token "nope" is unknown`},
},
{
name: "bad JSON",
@ -4916,7 +4943,12 @@ func TestAgent_Token(t *testing.T) {
url := fmt.Sprintf("/v1/agent/token/%s", tt.url)
resp := httptest.NewRecorder()
req, _ := http.NewRequest(tt.method, url, tt.body)
_, err := a.srv.AgentToken(resp, req)
if tt.expectedErr != nil {
require.Equal(t, tt.expectedErr, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.code, resp.Code)
require.Equal(t, tt.effective.user, a.tokens.UserToken())

agent/agent_oss.go

@ -23,10 +23,6 @@ func (a *Agent) initEnterprise(consulCfg *consul.Config) error {
return nil
}
// loadEnterpriseTokens is a noop stub for the func defined agent_ent.go
func (a *Agent) loadEnterpriseTokens(conf *config.RuntimeConfig) {
}
// reloadEnterprise is a noop stub for the func defined agent_ent.go
func (a *Agent) reloadEnterprise(conf *config.RuntimeConfig) error {
return nil

agent/agent_test.go

@ -43,6 +43,7 @@ import (
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -765,10 +766,18 @@ func TestCacheRateLimit(test *testing.T) {
test.Run(fmt.Sprintf("rate_limit_at_%v", currentTest.rateLimit), func(t *testing.T) {
tt := currentTest
t.Parallel()
a := NewTestAgent(t, fmt.Sprintf("cache = { entry_fetch_rate = %v, entry_fetch_max_burst = 1 }", tt.rateLimit))
a := NewTestAgent(t, "cache = { entry_fetch_rate = 1, entry_fetch_max_burst = 100 }")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
cfg := a.config
require.Equal(t, rate.Limit(1), a.config.Cache.EntryFetchRate)
require.Equal(t, 100, a.config.Cache.EntryFetchMaxBurst)
cfg.Cache.EntryFetchRate = rate.Limit(tt.rateLimit)
cfg.Cache.EntryFetchMaxBurst = 1
a.reloadConfigInternal(cfg)
require.Equal(t, rate.Limit(tt.rateLimit), a.config.Cache.EntryFetchRate)
require.Equal(t, 1, a.config.Cache.EntryFetchMaxBurst)
var wg sync.WaitGroup
stillProcessing := true
@ -1908,7 +1917,7 @@ func TestAgent_HTTPCheck_EnableAgentTLSForChecks(t *testing.T) {
Status: api.HealthCritical,
}
url := fmt.Sprintf("https://%s/v1/agent/self", a.srv.ln.Addr().String())
url := fmt.Sprintf("https://%s/v1/agent/self", a.HTTPAddr())
chk := &structs.CheckType{
HTTP: url,
Interval: 20 * time.Millisecond,
@ -3336,163 +3345,6 @@ func TestAgent_reloadWatchesHTTPS(t *testing.T) {
}
}
func TestAgent_loadTokens(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, `
acl = {
enabled = true
tokens = {
agent = "alfa"
agent_master = "bravo",
default = "charlie"
replication = "delta"
}
}
`)
defer a.Shutdown()
require := require.New(t)
tokensFullPath := filepath.Join(a.config.DataDir, tokensPath)
t.Run("original-configuration", func(t *testing.T) {
require.Equal("alfa", a.tokens.AgentToken())
require.Equal("bravo", a.tokens.AgentMasterToken())
require.Equal("charlie", a.tokens.UserToken())
require.Equal("delta", a.tokens.ReplicationToken())
})
t.Run("updated-configuration", func(t *testing.T) {
cfg := &config.RuntimeConfig{
ACLToken: "echo",
ACLAgentToken: "foxtrot",
ACLAgentMasterToken: "golf",
ACLReplicationToken: "hotel",
}
// ensures no error for missing persisted tokens file
require.NoError(a.loadTokens(cfg))
require.Equal("echo", a.tokens.UserToken())
require.Equal("foxtrot", a.tokens.AgentToken())
require.Equal("golf", a.tokens.AgentMasterToken())
require.Equal("hotel", a.tokens.ReplicationToken())
})
t.Run("persisted-tokens", func(t *testing.T) {
cfg := &config.RuntimeConfig{
ACLToken: "echo",
ACLAgentToken: "foxtrot",
ACLAgentMasterToken: "golf",
ACLReplicationToken: "hotel",
}
tokens := `{
"agent" : "india",
"agent_master" : "juliett",
"default": "kilo",
"replication" : "lima"
}`
require.NoError(ioutil.WriteFile(tokensFullPath, []byte(tokens), 0600))
require.NoError(a.loadTokens(cfg))
// no updates since token persistence is not enabled
require.Equal("echo", a.tokens.UserToken())
require.Equal("foxtrot", a.tokens.AgentToken())
require.Equal("golf", a.tokens.AgentMasterToken())
require.Equal("hotel", a.tokens.ReplicationToken())
a.config.ACLEnableTokenPersistence = true
require.NoError(a.loadTokens(cfg))
require.Equal("india", a.tokens.AgentToken())
require.Equal("juliett", a.tokens.AgentMasterToken())
require.Equal("kilo", a.tokens.UserToken())
require.Equal("lima", a.tokens.ReplicationToken())
})
t.Run("persisted-tokens-override", func(t *testing.T) {
tokens := `{
"agent" : "mike",
"agent_master" : "november",
"default": "oscar",
"replication" : "papa"
}`
cfg := &config.RuntimeConfig{
ACLToken: "quebec",
ACLAgentToken: "romeo",
ACLAgentMasterToken: "sierra",
ACLReplicationToken: "tango",
}
require.NoError(ioutil.WriteFile(tokensFullPath, []byte(tokens), 0600))
require.NoError(a.loadTokens(cfg))
require.Equal("mike", a.tokens.AgentToken())
require.Equal("november", a.tokens.AgentMasterToken())
require.Equal("oscar", a.tokens.UserToken())
require.Equal("papa", a.tokens.ReplicationToken())
})
t.Run("partial-persisted", func(t *testing.T) {
tokens := `{
"agent" : "uniform",
"agent_master" : "victor"
}`
cfg := &config.RuntimeConfig{
ACLToken: "whiskey",
ACLAgentToken: "xray",
ACLAgentMasterToken: "yankee",
ACLReplicationToken: "zulu",
}
require.NoError(ioutil.WriteFile(tokensFullPath, []byte(tokens), 0600))
require.NoError(a.loadTokens(cfg))
require.Equal("uniform", a.tokens.AgentToken())
require.Equal("victor", a.tokens.AgentMasterToken())
require.Equal("whiskey", a.tokens.UserToken())
require.Equal("zulu", a.tokens.ReplicationToken())
})
t.Run("persistence-error-not-json", func(t *testing.T) {
cfg := &config.RuntimeConfig{
ACLToken: "one",
ACLAgentToken: "two",
ACLAgentMasterToken: "three",
ACLReplicationToken: "four",
}
require.NoError(ioutil.WriteFile(tokensFullPath, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, 0600))
err := a.loadTokens(cfg)
require.Error(err)
require.Equal("one", a.tokens.UserToken())
require.Equal("two", a.tokens.AgentToken())
require.Equal("three", a.tokens.AgentMasterToken())
require.Equal("four", a.tokens.ReplicationToken())
})
t.Run("persistence-error-wrong-top-level", func(t *testing.T) {
cfg := &config.RuntimeConfig{
ACLToken: "alfa",
ACLAgentToken: "bravo",
ACLAgentMasterToken: "charlie",
ACLReplicationToken: "foxtrot",
}
require.NoError(ioutil.WriteFile(tokensFullPath, []byte("[1,2,3]"), 0600))
err := a.loadTokens(cfg)
require.Error(err)
require.Equal("alfa", a.tokens.UserToken())
require.Equal("bravo", a.tokens.AgentToken())
require.Equal("charlie", a.tokens.AgentMasterToken())
require.Equal("foxtrot", a.tokens.ReplicationToken())
})
}
func TestAgent_SecurityChecks(t *testing.T) {
t.Parallel()
hcl := `
@ -4741,3 +4593,33 @@ func TestAgent_AutoEncrypt(t *testing.T) {
require.Len(t, x509Cert.URIs, 1)
require.Equal(t, id.URI(), x509Cert.URIs[0])
}
func TestSharedRPCRouter(t *testing.T) {
// this test runs both a server and client and ensures that the shared
// router is being used. It would be possible for the Client and Server
// types to create and use their own routers and for RPCs such as the
// ones used in WaitForTestAgent to succeed. However accessing the
// router stored on the agent ensures that Serf information from the
// Client/Server types are being set in the same shared rpc router.
srv := NewTestAgent(t, "")
defer srv.Shutdown()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")
mgr, server := srv.Agent.router.FindLANRoute()
require.NotNil(t, mgr)
require.NotNil(t, server)
client := NewTestAgent(t, `
server = false
bootstrap = false
retry_join = ["`+srv.Config.SerfBindAddrLAN.String()+`"]
`)
testrpc.WaitForTestAgent(t, client.RPC, "dc1")
mgr, server = client.Agent.router.FindLANRoute()
require.NotNil(t, mgr)
require.NotNil(t, server)
}

agent/apiserver.go (new file, +94)

@ -0,0 +1,94 @@
package agent
import (
"context"
"net"
"sync"
"time"
"github.com/hashicorp/go-hclog"
"golang.org/x/sync/errgroup"
)
// apiServers is a wrapper around errgroup.Group for managing go routines for
// long running agent components (ex: http server, dns server). If any of the
// servers fail, the failed channel will be closed, which will cause the agent
// to be shut down instead of running in a degraded state.
//
// This struct exists as a shim for using errgroup.Group without making major
// changes to Agent. In the future it may be removed and replaced with more
// direct usage of errgroup.Group.
type apiServers struct {
logger hclog.Logger
group *errgroup.Group
servers []apiServer
// failed channel is closed when the first server goroutine exits with a
// non-nil error.
failed <-chan struct{}
}
type apiServer struct {
// Protocol supported by this server. One of: dns, http, https
Protocol string
// Addr the server is listening on
Addr net.Addr
// Run will be called in a goroutine to run the server. When any Run exits
// with a non-nil error, the failed channel will be closed.
Run func() error
// Shutdown function used to stop the server
Shutdown func(context.Context) error
}
// NewAPIServers returns an empty apiServers that is ready to Start servers.
func NewAPIServers(logger hclog.Logger) *apiServers {
group, ctx := errgroup.WithContext(context.TODO())
return &apiServers{
logger: logger,
group: group,
failed: ctx.Done(),
}
}
func (s *apiServers) Start(srv apiServer) {
srv.logger(s.logger).Info("Starting server")
s.servers = append(s.servers, srv)
s.group.Go(srv.Run)
}
func (s apiServer) logger(base hclog.Logger) hclog.Logger {
return base.With(
"protocol", s.Protocol,
"address", s.Addr.String(),
"network", s.Addr.Network())
}
// Shutdown all the servers and log any errors as warning. Each server is given
// 1 second, or until ctx is cancelled, to shutdown gracefully.
func (s *apiServers) Shutdown(ctx context.Context) {
shutdownGroup := new(sync.WaitGroup)
for i := range s.servers {
server := s.servers[i]
shutdownGroup.Add(1)
go func() {
defer shutdownGroup.Done()
logger := server.logger(s.logger)
logger.Info("Stopping server")
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
logger.Warn("Failed to stop server")
}
}()
}
s.servers = nil
shutdownGroup.Wait()
}
// WaitForShutdown waits until all server goroutines have exited. Shutdown
// must be called before WaitForShutdown, otherwise it will block forever.
func (s *apiServers) WaitForShutdown() error {
return s.group.Wait()
}
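
For reference, a standalone sketch of the underlying `errgroup` pattern this file wraps; the server names and errors are illustrative, not from the diff:

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	group, ctx := errgroup.WithContext(context.Background())
	failed := ctx.Done() // closed once any goroutine returns a non-nil error

	group.Go(func() error { return errors.New("http server failed") })
	group.Go(func() error {
		<-failed // a well-behaved sibling server stops when one fails
		return nil
	})

	<-failed
	fmt.Println("first failure:", group.Wait()) // Wait returns the first error
}
```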

agent/apiserver_test.go (new file, +65)

@ -0,0 +1,65 @@
package agent
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
func TestAPIServers_WithServiceRunError(t *testing.T) {
servers := NewAPIServers(hclog.New(nil))
server1, chErr1 := newAPIServerStub()
server2, _ := newAPIServerStub()
t.Run("Start", func(t *testing.T) {
servers.Start(server1)
servers.Start(server2)
select {
case <-servers.failed:
t.Fatalf("expected servers to still be running")
case <-time.After(5 * time.Millisecond):
}
})
err := fmt.Errorf("oops, I broke")
t.Run("server exit non-nil error", func(t *testing.T) {
chErr1 <- err
select {
case <-servers.failed:
case <-time.After(time.Second):
t.Fatalf("expected failed channel to be closed")
}
})
t.Run("shutdown remaining services", func(t *testing.T) {
servers.Shutdown(context.Background())
require.Equal(t, err, servers.WaitForShutdown())
})
}
func newAPIServerStub() (apiServer, chan error) {
chErr := make(chan error)
return apiServer{
Protocol: "http",
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.11"),
Port: 5505,
},
Run: func() error {
return <-chErr
},
Shutdown: func(ctx context.Context) error {
close(chErr)
return nil
},
}, chErr
}

agent/auto-config/auto_config.go

@ -4,62 +4,54 @@ import (
"context"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/proto/pbautoconf"
"github.com/hashicorp/go-discover"
discoverk8s "github.com/hashicorp/go-discover/provider/k8s"
"github.com/hashicorp/go-hclog"
"github.com/golang/protobuf/jsonpb"
)
const (
// autoConfigFileName is the name of the file that the agent auto-config settings are
// stored in within the data directory
autoConfigFileName = "auto-config.json"
dummyTrustDomain = "dummytrustdomain"
)
var (
pbMarshaler = &jsonpb.Marshaler{
OrigName: false,
EnumsAsInts: false,
Indent: " ",
EmitDefaults: true,
}
pbUnmarshaler = &jsonpb.Unmarshaler{
AllowUnknownFields: false,
}
)
// AutoConfig is all the state necessary for being able to parse a configuration
// as well as perform the necessary RPCs to perform Agent Auto Configuration.
//
// NOTE: This struct and methods on it are not currently thread/goroutine safe.
// However it doesn't spawn any of its own go routines yet and is used in a
// synchronous fashion. In the future if either of those two conditions change
// then we will need to add some locking here. I am deferring that for now
// to help ease the review of this already large PR.
type AutoConfig struct {
sync.Mutex
acConfig Config
logger hclog.Logger
certMonitor CertMonitor
cache Cache
waiter *lib.RetryWaiter
config *config.RuntimeConfig
autoConfigResponse *pbautoconf.AutoConfigResponse
autoConfigSource config.Source
running bool
done chan struct{}
// cancel is used to cancel the entire AutoConfig
// go routine. This is the main field protected
// by the mutex as it being non-nil indicates that
// the go routine has been started and is stoppable.
// note that it doesn't indcate that the go routine
// is currently running.
cancel context.CancelFunc
// cancelWatches is used to cancel the existing
// cache watches regarding the agents certificate. This is
// mainly only necessary when the Agent token changes.
cancelWatches context.CancelFunc
// cacheUpdates is the chan used to have the cache
// send us back events
cacheUpdates chan cache.UpdateEvent
// tokenUpdates is the struct used to receive
// events from the token store when the Agent
// token is updated.
tokenUpdates token.Notifier
}
// New creates a new AutoConfig object for providing automatic Consul configuration.
@ -69,6 +61,19 @@ func New(config Config) (*AutoConfig, error) {
return nil, fmt.Errorf("must provide a config loader")
case config.DirectRPC == nil:
return nil, fmt.Errorf("must provide a direct RPC delegate")
case config.Cache == nil:
return nil, fmt.Errorf("must provide a cache")
case config.TLSConfigurator == nil:
return nil, fmt.Errorf("must provide a TLS configurator")
case config.Tokens == nil:
return nil, fmt.Errorf("must provide a token store")
}
if config.FallbackLeeway == 0 {
config.FallbackLeeway = 10 * time.Second
}
if config.FallbackRetry == 0 {
config.FallbackRetry = time.Minute
}
logger := config.Logger
@ -83,15 +88,16 @@ func New(config Config) (*AutoConfig, error) {
}
return &AutoConfig{
acConfig: config,
logger: logger,
certMonitor: config.CertMonitor,
acConfig: config,
logger: logger,
}, nil
}
// ReadConfig will parse the current configuration and inject any
// auto-config sources if present into the correct place in the parsing chain.
func (ac *AutoConfig) ReadConfig() (*config.RuntimeConfig, error) {
ac.Lock()
defer ac.Unlock()
cfg, warnings, err := ac.acConfig.Loader(ac.autoConfigSource)
if err != nil {
return cfg, err
@ -105,46 +111,6 @@ func (ac *AutoConfig) ReadConfig() (*config.RuntimeConfig, error) {
return cfg, nil
}
// restorePersistedAutoConfig will attempt to load the persisted auto-config
// settings from the data directory. It returns true either when there was an
// unrecoverable error or when the configuration was successfully loaded from
// disk. Recoverable errors, such as "file not found" are suppressed and this
// method will return false for the first boolean.
func (ac *AutoConfig) restorePersistedAutoConfig() (bool, error) {
if ac.config.DataDir == "" {
// no data directory means we don't have anything to potentially load
return false, nil
}
path := filepath.Join(ac.config.DataDir, autoConfigFileName)
ac.logger.Debug("attempting to restore any persisted configuration", "path", path)
content, err := ioutil.ReadFile(path)
if err == nil {
rdr := strings.NewReader(string(content))
var resp pbautoconf.AutoConfigResponse
if err := pbUnmarshaler.Unmarshal(rdr, &resp); err != nil {
return false, fmt.Errorf("failed to decode persisted auto-config data: %w", err)
}
if err := ac.update(&resp); err != nil {
return false, fmt.Errorf("error restoring persisted auto-config response: %w", err)
}
ac.logger.Info("restored persisted configuration", "path", path)
return true, nil
}
if !os.IsNotExist(err) {
return true, fmt.Errorf("failed to load %s: %w", path, err)
}
// ignore non-existence errors as that is an indicator that we haven't
// performed the auto configuration before
return false, nil
}
// InitialConfiguration will perform a one-time RPC request to the configured servers
// to retrieve various cluster wide configurations. See the proto/pbautoconf/auto_config.proto
// file for a complete reference of what configurations can be applied in this manner.
@ -164,30 +130,49 @@ func (ac *AutoConfig) InitialConfiguration(ctx context.Context) (*config.Runtime
ac.config = config
}
if !ac.config.AutoConfig.Enabled {
return ac.config, nil
}
ready, err := ac.restorePersistedAutoConfig()
if err != nil {
return nil, err
}
if !ready {
ac.logger.Info("retrieving initial agent auto configuration remotely")
if err := ac.getInitialConfiguration(ctx); err != nil {
switch {
case ac.config.AutoConfig.Enabled:
resp, err := ac.readPersistedAutoConfig()
if err != nil {
return nil, err
}
}
// re-read the configuration now that we have our initial auto-config
config, err := ac.ReadConfig()
if err != nil {
return nil, err
}
if resp == nil {
ac.logger.Info("retrieving initial agent auto configuration remotely")
resp, err = ac.getInitialConfiguration(ctx)
if err != nil {
return nil, err
}
}
ac.config = config
return ac.config, nil
ac.logger.Debug("updating auto-config settings")
if err = ac.recordInitialConfiguration(resp); err != nil {
return nil, err
}
// re-read the configuration now that we have our initial auto-config
config, err := ac.ReadConfig()
if err != nil {
return nil, err
}
ac.config = config
return ac.config, nil
case ac.config.AutoEncryptTLS:
certs, err := ac.autoEncryptInitialCerts(ctx)
if err != nil {
return nil, err
}
if err := ac.setInitialTLSCertificates(certs); err != nil {
return nil, err
}
ac.logger.Info("automatically upgraded to TLS")
return ac.config, nil
default:
return ac.config, nil
}
}
// introToken is responsible for determining the correct intro token to use
@ -217,118 +202,45 @@ func (ac *AutoConfig) introToken() (string, error) {
return token, nil
}
// serverHosts is responsible for taking the list of server addresses and
// resolving any go-discover provider invocations. It will then return a list
// of hosts. These might be hostnames and is expected that DNS resolution may
// be performed after this function runs. Additionally these may contain ports
// so SplitHostPort could also be necessary.
func (ac *AutoConfig) serverHosts() ([]string, error) {
servers := ac.config.AutoConfig.ServerAddresses
// recordInitialConfiguration is responsible for recording the AutoConfigResponse from
// the AutoConfig.InitialConfiguration RPC. It is an all-in-one function to do the following
// * update the Agent token in the token store
func (ac *AutoConfig) recordInitialConfiguration(resp *pbautoconf.AutoConfigResponse) error {
ac.autoConfigResponse = resp
providers := make(map[string]discover.Provider)
for k, v := range discover.Providers {
providers[k] = v
ac.autoConfigSource = config.LiteralSource{
Name: autoConfigFileName,
Config: translateConfig(resp.Config),
}
providers["k8s"] = &discoverk8s.Provider{}
disco, err := discover.New(
discover.WithUserAgent(lib.UserAgent()),
discover.WithProviders(providers),
)
// we need to re-read the configuration to determine what the correct ACL
// token to push into the token store is. Any user provided token will override
// any AutoConfig generated token.
config, err := ac.ReadConfig()
if err != nil {
return nil, fmt.Errorf("Failed to create go-discover resolver: %w", err)
return fmt.Errorf("failed to fully resolve configuration: %w", err)
}
var addrs []string
for _, addr := range servers {
switch {
case strings.Contains(addr, "provider="):
resolved, err := disco.Addrs(addr, ac.logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}))
if err != nil {
ac.logger.Error("failed to resolve go-discover auto-config servers", "configuration", addr, "err", err)
continue
}
// ignoring the return value which would indicate a change in the token
_ = ac.acConfig.Tokens.UpdateAgentToken(config.ACLTokens.ACLAgentToken, token.TokenSourceConfig)
addrs = append(addrs, resolved...)
ac.logger.Debug("discovered auto-config servers", "servers", resolved)
default:
addrs = append(addrs, addr)
}
}
if len(addrs) == 0 {
return nil, fmt.Errorf("no auto-config server addresses available for use")
}
return addrs, nil
}
// resolveHost will take a single host string and convert it to a list of TCPAddrs
// This will process any port in the input as well as looking up the hostname using
// normal DNS resolution.
func (ac *AutoConfig) resolveHost(hostPort string) []net.TCPAddr {
port := ac.config.ServerPort
host, portStr, err := net.SplitHostPort(hostPort)
// extract a structs.SignedResponse from the AutoConfigResponse for use in cache prepopulation
signed, err := extractSignedResponse(resp)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
host = hostPort
} else {
ac.logger.Warn("error splitting host address into IP and port", "address", hostPort, "error", err)
return nil
}
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
ac.logger.Warn("Parsed port is not an integer", "port", portStr, "error", err)
return nil
}
return fmt.Errorf("failed to extract certificates from the auto-config response: %w", err)
}
// resolve the host to a list of IPs
ips, err := net.LookupIP(host)
if err != nil {
ac.logger.Warn("IP resolution failed", "host", host, "error", err)
return nil
// prepopulate the cache
if err = ac.populateCertificateCache(signed); err != nil {
return fmt.Errorf("failed to populate the cache with certificate responses: %w", err)
}
var addrs []net.TCPAddr
for _, ip := range ips {
addrs = append(addrs, net.TCPAddr{IP: ip, Port: port})
}
return addrs
}
// recordResponse takes an AutoConfig RPC response and records it with the agent
// This will persist the configuration to disk (unless in dev mode running without
// a data dir) and will reload the configuration.
func (ac *AutoConfig) recordResponse(resp *pbautoconf.AutoConfigResponse) error {
serialized, err := pbMarshaler.MarshalToString(resp)
if err != nil {
return fmt.Errorf("failed to encode auto-config response as JSON: %w", err)
}
if err := ac.update(resp); err != nil {
// update the TLS configurator with the latest certificates
if err := ac.updateTLSFromResponse(resp); err != nil {
return err
}
// now that we know the configuration is generally fine, including TLS certs, go ahead and persist it to disk.
if ac.config.DataDir == "" {
ac.logger.Debug("not persisting auto-config settings because there is no data directory")
return nil
}
path := filepath.Join(ac.config.DataDir, autoConfigFileName)
err = ioutil.WriteFile(path, []byte(serialized), 0660)
if err != nil {
return fmt.Errorf("failed to write auto-config configurations: %w", err)
}
ac.logger.Debug("auto-config settings were persisted to disk")
return nil
return ac.persistAutoConfig(resp)
}
// getInitialConfigurationOnce will perform full server to TCPAddr resolution and
@@ -352,7 +264,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context, csr strin
var resp pbautoconf.AutoConfigResponse
servers, err := ac.serverHosts()
servers, err := ac.autoConfigHosts()
if err != nil {
return nil, err
}
@@ -369,6 +281,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context, csr strin
ac.logger.Error("AutoConfig.InitialConfiguration RPC failed", "addr", addr.String(), "error", err)
continue
}
ac.logger.Debug("AutoConfig.InitialConfiguration RPC was successful")
// update the Certificate with the private key we generated locally
if resp.Certificate != nil {
@@ -379,17 +292,17 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context, csr strin
}
}
return nil, ctx.Err()
return nil, fmt.Errorf("No server successfully responded to the auto-config request")
}
// getInitialConfiguration implements a loop to retry calls to getInitialConfigurationOnce.
// It uses the RetryWaiter on the AutoConfig object to control how often to attempt
// the initial configuration process. It can also be cancelled via the provided context.
func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) error {
func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf.AutoConfigResponse, error) {
// generate a CSR
csr, key, err := ac.generateCSR()
if err != nil {
return err
return nil, err
}
// this resets the failures so that we will perform an immediate request
@@ -397,183 +310,95 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) error {
for {
select {
case <-wait:
resp, err := ac.getInitialConfigurationOnce(ctx, csr, key)
if resp != nil {
return ac.recordResponse(resp)
if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil {
return resp, nil
} else if err != nil {
ac.logger.Error(err.Error())
} else {
ac.logger.Error("No error returned when fetching the initial auto-configuration but no response was either")
ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
}
wait = ac.acConfig.Waiter.Failed()
case <-ctx.Done():
ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
return ctx.Err()
return nil, ctx.Err()
}
}
}
// generateCSR will generate a CSR for an Agent certificate. This should
// be sent along with the AutoConfig.InitialConfiguration RPC. The generated
// CSR does NOT have a real trust domain, because we do not yet have the
// CA roots when generating it. The server will update the trust domain
// for us though.
func (ac *AutoConfig) generateCSR() (csr string, key string, err error) {
// We don't provide the correct host here, because we don't know any
// better at this point. Apart from the domain, we would need the
// ClusterID, which we don't have. This is why we go with
// dummyTrustDomain the first time. Subsequent CSRs will have the
// correct TrustDomain.
id := &connect.SpiffeIDAgent{
// will be replaced
Host: dummyTrustDomain,
Datacenter: ac.config.Datacenter,
Agent: ac.config.NodeName,
}
caConfig, err := ac.config.ConnectCAConfiguration()
if err != nil {
return "", "", fmt.Errorf("Cannot generate CSR: %w", err)
}
conf, err := caConfig.GetCommonConfig()
if err != nil {
return "", "", fmt.Errorf("Failed to load common CA configuration: %w", err)
}
if conf.PrivateKeyType == "" {
conf.PrivateKeyType = connect.DefaultPrivateKeyType
}
if conf.PrivateKeyBits == 0 {
conf.PrivateKeyBits = connect.DefaultPrivateKeyBits
}
// Create a new private key
pk, pkPEM, err := connect.GeneratePrivateKeyWithConfig(conf.PrivateKeyType, conf.PrivateKeyBits)
if err != nil {
return "", "", fmt.Errorf("Failed to generate private key: %w", err)
}
dnsNames := append([]string{"localhost"}, ac.config.AutoConfig.DNSSANs...)
ipAddresses := append([]net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::")}, ac.config.AutoConfig.IPSANs...)
// Create a CSR.
//
// The Common Name includes the dummy trust domain for now but Server will
// override this when it is signed anyway so it's OK.
cn := connect.AgentCN(ac.config.NodeName, dummyTrustDomain)
csr, err = connect.CreateCSR(id, cn, pk, dnsNames, ipAddresses)
if err != nil {
return "", "", err
}
return csr, pkPEM, nil
}
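// For reference (derived from the tests added in this change, not new
// behavior): with NodeName "test-node" and Datacenter "dc1" the CSR built
// above carries a placeholder SPIFFE URI SAN of the form
//
//	spiffe://unknown/agent/client/dc/dc1/id/test-node
//
// which the server replaces with the real trust domain when it signs the
// certificate.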
// update will take an AutoConfigResponse and do all things necessary
// to restore those settings. This currently involves updating the
// config data to be used during a call to ReadConfig, updating the
// tls Configurator and prepopulating the cache.
func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error {
ac.autoConfigResponse = resp
ac.autoConfigSource = config.LiteralSource{
Name: autoConfigFileName,
Config: translateConfig(resp.Config),
}
if err := ac.updateTLSFromResponse(resp); err != nil {
return err
}
return nil
}
// updateTLSFromResponse will update the TLS certificate and roots in the shared
// TLS configurator.
func (ac *AutoConfig) updateTLSFromResponse(resp *pbautoconf.AutoConfigResponse) error {
if ac.certMonitor == nil {
return nil
}
roots, err := translateCARootsToStructs(resp.CARoots)
if err != nil {
return err
}
cert, err := translateIssuedCertToStructs(resp.Certificate)
if err != nil {
return err
}
update := &structs.SignedResponse{
IssuedCert: *cert,
ConnectCARoots: *roots,
ManualCARoots: resp.ExtraCACertificates,
}
if resp.Config != nil && resp.Config.TLS != nil {
update.VerifyServerHostname = resp.Config.TLS.VerifyServerHostname
}
if err := ac.certMonitor.Update(update); err != nil {
return fmt.Errorf("failed to update the certificate monitor: %w", err)
}
return nil
}
func (ac *AutoConfig) Start(ctx context.Context) error {
if ac.certMonitor == nil {
ac.Lock()
defer ac.Unlock()
if !ac.config.AutoConfig.Enabled && !ac.config.AutoEncryptTLS {
return nil
}
if !ac.config.AutoConfig.Enabled {
return nil
if ac.running || ac.cancel != nil {
return fmt.Errorf("AutoConfig is already running")
}
_, err := ac.certMonitor.Start(ctx)
return err
// create the top level context to control the go
// routine executing the `run` method
ctx, cancel := context.WithCancel(ctx)
// create the channel to get cache update events through
// really we should only ever get 10 updates
ac.cacheUpdates = make(chan cache.UpdateEvent, 10)
// setup the cache watches
cancelCertWatches, err := ac.setupCertificateCacheWatches(ctx)
if err != nil {
cancel()
return fmt.Errorf("error setting up cache watches: %w", err)
}
// start the token update notifier
ac.tokenUpdates = ac.acConfig.Tokens.Notify(token.TokenKindAgent)
// store the cancel funcs
ac.cancel = cancel
ac.cancelWatches = cancelCertWatches
ac.running = true
ac.done = make(chan struct{})
go ac.run(ctx, ac.done)
ac.logger.Info("auto-config started")
return nil
}
func (ac *AutoConfig) Done() <-chan struct{} {
ac.Lock()
defer ac.Unlock()
if ac.done != nil {
return ac.done
}
// return a closed channel to indicate that we are already done
done := make(chan struct{})
close(done)
return done
}
func (ac *AutoConfig) IsRunning() bool {
ac.Lock()
defer ac.Unlock()
return ac.running
}
func (ac *AutoConfig) Stop() bool {
if ac.certMonitor == nil {
ac.Lock()
defer ac.Unlock()
if !ac.running {
return false
}
if !ac.config.AutoConfig.Enabled {
return false
if ac.cancel != nil {
ac.cancel()
}
return ac.certMonitor.Stop()
}
func (ac *AutoConfig) FallbackTLS(ctx context.Context) (*structs.SignedResponse, error) {
// generate a CSR
csr, key, err := ac.generateCSR()
if err != nil {
return nil, err
}
resp, err := ac.getInitialConfigurationOnce(ctx, csr, key)
if err != nil {
return nil, err
}
return extractSignedResponse(resp)
}
func (ac *AutoConfig) RecordUpdatedCerts(resp *structs.SignedResponse) error {
var err error
ac.autoConfigResponse.ExtraCACertificates = resp.ManualCARoots
ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(&resp.ConnectCARoots)
if err != nil {
return err
}
ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(&resp.IssuedCert)
if err != nil {
return err
}
return ac.recordResponse(ac.autoConfigResponse)
return true
}
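// Illustrative lifecycle sketch (not part of this diff): after the one-time
// InitialConfiguration call, ongoing certificate monitoring is driven by
// Start, Stop and Done:
//
//	if err := ac.Start(ctx); err != nil {
//		return err // already running, or the cache watches failed to start
//	}
//	// ... agent runs ...
//	if ac.Stop() {
//		<-ac.Done() // wait for the run goroutine to finish its cleanup
//	}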

File diff suppressed because it is too large

@@ -0,0 +1,111 @@
package autoconf
import (
"context"
"fmt"
"net"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.SignedResponse, error) {
// generate a CSR
csr, key, err := ac.generateCSR()
if err != nil {
return nil, err
}
// this resets the failures so that we will perform an immediate request
wait := ac.acConfig.Waiter.Success()
for {
select {
case <-wait:
if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil {
return resp, nil
} else if err != nil {
ac.logger.Error(err.Error())
} else {
ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
}
wait = ac.acConfig.Waiter.Failed()
case <-ctx.Done():
ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err())
return nil, ctx.Err()
}
}
}
func (ac *AutoConfig) autoEncryptInitialCertsOnce(ctx context.Context, csr, key string) (*structs.SignedResponse, error) {
request := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: ac.acConfig.Tokens.AgentToken()},
Datacenter: ac.config.Datacenter,
CSR: csr,
}
var resp structs.SignedResponse
servers, err := ac.autoEncryptHosts()
if err != nil {
return nil, err
}
for _, s := range servers {
// try each IP to see if we can successfully make the request
for _, addr := range ac.resolveHost(s) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
ac.logger.Debug("making AutoEncrypt.Sign RPC", "addr", addr.String())
err = ac.acConfig.DirectRPC.RPC(ac.config.Datacenter, ac.config.NodeName, &addr, "AutoEncrypt.Sign", &request, &resp)
if err != nil {
ac.logger.Error("AutoEncrypt.Sign RPC failed", "addr", addr.String(), "error", err)
continue
}
resp.IssuedCert.PrivateKeyPEM = key
return &resp, nil
}
}
return nil, fmt.Errorf("No servers successfully responded to the auto-encrypt request")
}
func (ac *AutoConfig) autoEncryptHosts() ([]string, error) {
// use servers known to gossip if there are any
if ac.acConfig.ServerProvider != nil {
if srv := ac.acConfig.ServerProvider.FindLANServer(); srv != nil {
return []string{srv.Addr.String()}, nil
}
}
hosts, err := ac.discoverServers(ac.config.RetryJoinLAN)
if err != nil {
return nil, err
}
var addrs []string
// The addresses we use for auto-encrypt are the retry join and start join
// addresses. These are for joining serf and therefore we cannot rely on the
// ports for these. This loop strips any port that may have been specified and
// lets subsequent resolveHost calls add on the default RPC port.
for _, addr := range append(ac.config.StartJoinAddrsLAN, hosts...) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
host = addr
} else {
ac.logger.Warn("error splitting host address into IP and port", "address", addr, "error", err)
continue
}
}
addrs = append(addrs, host)
}
if len(addrs) == 0 {
return nil, fmt.Errorf("no auto-encrypt server addresses available for use")
}
return addrs, nil
}
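// Worked example (taken from the "various-addresses" test case below): with
// StartJoinAddrsLAN = ["192.168.1.1:5432"] and
// RetryJoinLAN = ["198.18.0.1", "[2001:db8::1234]:1234"], autoEncryptHosts
// returns ["192.168.1.1", "198.18.0.1", "2001:db8::1234"] - the ports are
// stripped so that resolveHost can apply the default RPC port later.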


@@ -0,0 +1,562 @@
package autoconf
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"net"
"net/url"
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestAutoEncrypt_generateCSR(t *testing.T) {
type testCase struct {
conf *config.RuntimeConfig
// to validate the csr
expectedSubject pkix.Name
expectedSigAlg x509.SignatureAlgorithm
expectedPubAlg x509.PublicKeyAlgorithm
expectedDNSNames []string
expectedIPs []net.IP
expectedURIs []*url.URL
}
cases := map[string]testCase{
"ip-sans": {
conf: &config.RuntimeConfig{
Datacenter: "dc1",
NodeName: "test-node",
AutoEncryptTLS: true,
AutoEncryptIPSAN: []net.IP{net.IPv4(198, 18, 0, 1), net.IPv4(198, 18, 0, 2)},
},
expectedSubject: pkix.Name{
CommonName: connect.AgentCN("test-node", unknownTrustDomain),
Names: []pkix.AttributeTypeAndValue{
{
// 2,5,4,3 is the CommonName type ASN1 identifier
Type: asn1.ObjectIdentifier{2, 5, 4, 3},
Value: "testnode.agnt.unknown.consul",
},
},
},
expectedSigAlg: x509.ECDSAWithSHA256,
expectedPubAlg: x509.ECDSA,
expectedDNSNames: defaultDNSSANs,
expectedIPs: append(defaultIPSANs,
net.IP{198, 18, 0, 1},
net.IP{198, 18, 0, 2},
),
expectedURIs: []*url.URL{
{
Scheme: "spiffe",
Host: unknownTrustDomain,
Path: "/agent/client/dc/dc1/id/test-node",
},
},
},
"dns-sans": {
conf: &config.RuntimeConfig{
Datacenter: "dc1",
NodeName: "test-node",
AutoEncryptTLS: true,
AutoEncryptDNSSAN: []string{"foo.local", "bar.local"},
},
expectedSubject: pkix.Name{
CommonName: connect.AgentCN("test-node", unknownTrustDomain),
Names: []pkix.AttributeTypeAndValue{
{
// 2,5,4,3 is the CommonName type ASN1 identifier
Type: asn1.ObjectIdentifier{2, 5, 4, 3},
Value: "testnode.agnt.unknown.consul",
},
},
},
expectedSigAlg: x509.ECDSAWithSHA256,
expectedPubAlg: x509.ECDSA,
expectedDNSNames: append(defaultDNSSANs, "foo.local", "bar.local"),
expectedIPs: defaultIPSANs,
expectedURIs: []*url.URL{
{
Scheme: "spiffe",
Host: unknownTrustDomain,
Path: "/agent/client/dc/dc1/id/test-node",
},
},
},
}
for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
ac := AutoConfig{config: tcase.conf}
csr, _, err := ac.generateCSR()
require.NoError(t, err)
request, err := connect.ParseCSR(csr)
require.NoError(t, err)
require.NotNil(t, request)
require.Equal(t, tcase.expectedSubject, request.Subject)
require.Equal(t, tcase.expectedSigAlg, request.SignatureAlgorithm)
require.Equal(t, tcase.expectedPubAlg, request.PublicKeyAlgorithm)
require.Equal(t, tcase.expectedDNSNames, request.DNSNames)
require.Equal(t, tcase.expectedIPs, request.IPAddresses)
require.Equal(t, tcase.expectedURIs, request.URIs)
})
}
}
func TestAutoEncrypt_hosts(t *testing.T) {
type testCase struct {
serverProvider ServerProvider
config *config.RuntimeConfig
hosts []string
err string
}
providerNone := newMockServerProvider(t)
providerNone.On("FindLANServer").Return(nil).Times(0)
providerWithServer := newMockServerProvider(t)
providerWithServer.On("FindLANServer").Return(&metadata.Server{Addr: &net.TCPAddr{IP: net.IPv4(198, 18, 0, 1), Port: 1234}}).Times(0)
cases := map[string]testCase{
"router-override": {
serverProvider: providerWithServer,
config: &config.RuntimeConfig{
RetryJoinLAN: []string{"127.0.0.1:9876"},
StartJoinAddrsLAN: []string{"192.168.1.2:4321"},
},
hosts: []string{"198.18.0.1:1234"},
},
"various-addresses": {
serverProvider: providerNone,
config: &config.RuntimeConfig{
RetryJoinLAN: []string{"198.18.0.1", "foo.com", "[2001:db8::1234]:1234", "abc.local:9876"},
StartJoinAddrsLAN: []string{"192.168.1.1:5432", "start.local", "[::ffff:172.16.5.4]", "main.dev:6789"},
},
hosts: []string{
"192.168.1.1",
"start.local",
"[::ffff:172.16.5.4]",
"main.dev",
"198.18.0.1",
"foo.com",
"2001:db8::1234",
"abc.local",
},
},
"split-host-port-error": {
serverProvider: providerNone,
config: &config.RuntimeConfig{
StartJoinAddrsLAN: []string{"this-is-not:a:ip:and_port"},
},
err: "no auto-encrypt server addresses available for use",
},
}
for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
ac := AutoConfig{
config: tcase.config,
logger: testutil.Logger(t),
acConfig: Config{
ServerProvider: tcase.serverProvider,
},
}
hosts, err := ac.autoEncryptHosts()
if tcase.err != "" {
testutil.RequireErrorContains(t, err, tcase.err)
} else {
require.NoError(t, err)
require.Equal(t, tcase.hosts, hosts)
}
})
}
}
func TestAutoEncrypt_InitialCerts(t *testing.T) {
token := "1a148388-3dd7-4db4-9eea-520424b4a86a"
datacenter := "foo"
nodeName := "bar"
mcfg := newMockedConfig(t)
_, indexedRoots, cert := testCerts(t, nodeName, datacenter)
// The following are called once for each round through the auto-encrypt initial certs outer loop
// (not the per-host direct rpc attempts but the one involving the RetryWaiter)
mcfg.tokens.On("AgentToken").Return(token).Times(2)
mcfg.serverProvider.On("FindLANServer").Return(nil).Times(2)
request := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: token},
Datacenter: datacenter,
// this gets removed by the mock code, as what it will be is non-deterministic
CSR: "",
}
// first failure
mcfg.directRPC.On("RPC",
datacenter,
nodeName,
&net.TCPAddr{IP: net.IPv4(198, 18, 0, 1), Port: 8300},
"AutoEncrypt.Sign",
&request,
&structs.SignedResponse{},
).Once().Return(fmt.Errorf("injected error"))
// second failure
mcfg.directRPC.On("RPC",
datacenter,
nodeName,
&net.TCPAddr{IP: net.IPv4(198, 18, 0, 2), Port: 8300},
"AutoEncrypt.Sign",
&request,
&structs.SignedResponse{},
).Once().Return(fmt.Errorf("injected error"))
// third time is successful (second attempt to the first server)
mcfg.directRPC.On("RPC",
datacenter,
nodeName,
&net.TCPAddr{IP: net.IPv4(198, 18, 0, 1), Port: 8300},
"AutoEncrypt.Sign",
&request,
&structs.SignedResponse{},
).Once().Return(nil).Run(func(args mock.Arguments) {
resp, ok := args.Get(5).(*structs.SignedResponse)
require.True(t, ok)
resp.ConnectCARoots = *indexedRoots
resp.IssuedCert = *cert
resp.VerifyServerHostname = true
})
mcfg.Config.Waiter = lib.NewRetryWaiter(2, 0, 1*time.Millisecond, nil)
ac := AutoConfig{
config: &config.RuntimeConfig{
Datacenter: datacenter,
NodeName: nodeName,
RetryJoinLAN: []string{"198.18.0.1:1234", "198.18.0.2:3456"},
ServerPort: 8300,
},
acConfig: mcfg.Config,
logger: testutil.Logger(t),
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
resp, err := ac.autoEncryptInitialCerts(ctx)
require.NoError(t, err)
require.NotNil(t, resp)
require.True(t, resp.VerifyServerHostname)
require.NotEmpty(t, resp.IssuedCert.PrivateKeyPEM)
resp.IssuedCert.PrivateKeyPEM = ""
cert.PrivateKeyPEM = ""
require.Equal(t, cert, &resp.IssuedCert)
require.Equal(t, indexedRoots, &resp.ConnectCARoots)
require.Empty(t, resp.ManualCARoots)
}
func TestAutoEncrypt_InitialConfiguration(t *testing.T) {
token := "010494ae-ee45-4433-903c-a58c91297714"
nodeName := "auto-encrypt"
datacenter := "dc1"
mcfg := newMockedConfig(t)
loader := setupRuntimeConfig(t)
loader.addConfigHCL(`
auto_encrypt {
tls = true
}
`)
loader.opts.Config.NodeName = &nodeName
mcfg.Config.Loader = loader.Load
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, nodeName, datacenter, token)
// prepopulation is going to grab the token to populate the correct cache key
mcfg.tokens.On("AgentToken").Return(token).Times(0)
// the server provider returns a known LAN server for the RPC
mcfg.serverProvider.On("FindLANServer").Return(&metadata.Server{Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8300}}).Times(1)
populateResponse := func(args mock.Arguments) {
resp, ok := args.Get(5).(*structs.SignedResponse)
require.True(t, ok)
*resp = structs.SignedResponse{
VerifyServerHostname: true,
ConnectCARoots: *indexedRoots,
IssuedCert: *cert,
ManualCARoots: extraCerts,
}
}
expectedRequest := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: token},
Datacenter: datacenter,
// TODO (autoconf) Maybe in the future we should populate a CSR
// and do some manual parsing/verification of the contents. The
// bits not having to do with the signing key such as the requested
// SANs and CN. For now though the mockDirectRPC type will empty
// the CSR so we have to pass in an empty string to the expectation.
CSR: "",
}
mcfg.directRPC.On(
"RPC",
datacenter,
nodeName,
&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8300},
"AutoEncrypt.Sign",
&expectedRequest,
&structs.SignedResponse{}).Return(nil).Run(populateResponse)
ac, err := New(mcfg.Config)
require.NoError(t, err)
require.NotNil(t, ac)
cfg, err := ac.InitialConfiguration(context.Background())
require.NoError(t, err)
require.NotNil(t, cfg)
}
func TestAutoEncrypt_TokenUpdate(t *testing.T) {
testAC := startedAutoConfig(t, true)
newToken := "1a4cc445-86ed-46b4-a355-bbf5a11dddb0"
rootsCtx, rootsCancel := context.WithCancel(context.Background())
testAC.mcfg.cache.On("Notify",
mock.Anything,
cachetype.ConnectCARootName,
&structs.DCSpecificRequest{Datacenter: testAC.ac.config.Datacenter},
rootsWatchID,
mock.Anything,
).Return(nil).Once().Run(func(args mock.Arguments) {
rootsCancel()
})
leafCtx, leafCancel := context.WithCancel(context.Background())
testAC.mcfg.cache.On("Notify",
mock.Anything,
cachetype.ConnectCALeafName,
&cachetype.ConnectCALeafRequest{
Datacenter: "dc1",
Agent: "autoconf",
Token: newToken,
DNSSAN: defaultDNSSANs,
IPSAN: defaultIPSANs,
},
leafWatchID,
mock.Anything,
).Return(nil).Once().Run(func(args mock.Arguments) {
leafCancel()
})
// this will be retrieved once when resetting the leaf cert watch
testAC.mcfg.tokens.On("AgentToken").Return(newToken).Once()
// send the notification about the token update
testAC.tokenUpdates <- struct{}{}
// wait for the leaf cert watches
require.True(t, waitForChans(100*time.Millisecond, leafCtx.Done(), rootsCtx.Done()), "New cache watches were not started within 100ms")
}
func TestAutoEncrypt_RootsUpdate(t *testing.T) {
testAC := startedAutoConfig(t, true)
secondCA := connect.TestCA(t, testAC.initialRoots.Roots[0])
secondRoots := structs.IndexedCARoots{
ActiveRootID: secondCA.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
secondCA,
testAC.initialRoots.Roots[0],
},
QueryMeta: structs.QueryMeta{
Index: 99,
},
}
updatedCtx, cancel := context.WithCancel(context.Background())
testAC.mcfg.tlsCfg.On("UpdateAutoTLSCA",
[]string{secondCA.RootCert, testAC.initialRoots.Roots[0].RootCert},
).Return(nil).Once().Run(func(args mock.Arguments) {
cancel()
})
// when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
req := structs.DCSpecificRequest{Datacenter: "dc1"}
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
CorrelationID: rootsWatchID,
Result: &secondRoots,
Meta: cache.ResultMeta{
Index: secondRoots.Index,
},
}))
require.True(t, waitForChans(100*time.Millisecond, updatedCtx.Done()), "TLS certificates were not updated within the allotted time")
}
func TestAutoEncrypt_CertUpdate(t *testing.T) {
testAC := startedAutoConfig(t, true)
secondCert := newLeaf(t, "autoconf", "dc1", testAC.initialRoots.Roots[0], 99, 10*time.Minute)
updatedCtx, cancel := context.WithCancel(context.Background())
testAC.mcfg.tlsCfg.On("UpdateAutoTLSCert",
secondCert.CertPEM,
"redacted",
).Return(nil).Once().Run(func(args mock.Arguments) {
cancel()
})
// when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
req := cachetype.ConnectCALeafRequest{
Datacenter: "dc1",
Agent: "autoconf",
Token: testAC.originalToken,
DNSSAN: defaultDNSSANs,
IPSAN: defaultIPSANs,
}
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
CorrelationID: leafWatchID,
Result: secondCert,
Meta: cache.ResultMeta{
Index: secondCert.ModifyIndex,
},
}))
require.True(t, waitForChans(100*time.Millisecond, updatedCtx.Done()), "TLS certificates were not updated within the allotted time")
}
func TestAutoEncrypt_Fallback(t *testing.T) {
testAC := startedAutoConfig(t, true)
// at this point everything is operating normally and we are just
// waiting for events. We are going to send a new cert that is basically
// already expired and then allow the fallback routine to kick in.
secondCert := newLeaf(t, "autoconf", "dc1", testAC.initialRoots.Roots[0], 100, time.Nanosecond)
secondCA := connect.TestCA(t, testAC.initialRoots.Roots[0])
secondRoots := structs.IndexedCARoots{
ActiveRootID: secondCA.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
secondCA,
testAC.initialRoots.Roots[0],
},
QueryMeta: structs.QueryMeta{
Index: 101,
},
}
thirdCert := newLeaf(t, "autoconf", "dc1", secondCA, 102, 10*time.Minute)
// setup the expectation for when the certs get updated initially
updatedCtx, updateCancel := context.WithCancel(context.Background())
testAC.mcfg.tlsCfg.On("UpdateAutoTLSCert",
secondCert.CertPEM,
"redacted",
).Return(nil).Once().Run(func(args mock.Arguments) {
updateCancel()
})
// when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once()
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
// also testing here that we can change server IPs for ongoing operations
testAC.mcfg.serverProvider.On("FindLANServer").Once().Return(&metadata.Server{
Addr: &net.TCPAddr{IP: net.IPv4(198, 18, 23, 2), Port: 8300},
})
// after sending the notification for the cert update another InitialConfiguration RPC
// will be made to pull down the latest configuration. So we need to set up the response
// for the second RPC
populateResponse := func(args mock.Arguments) {
resp, ok := args.Get(5).(*structs.SignedResponse)
require.True(t, ok)
*resp = structs.SignedResponse{
VerifyServerHostname: true,
ConnectCARoots: secondRoots,
IssuedCert: *thirdCert,
ManualCARoots: testAC.extraCerts,
}
fallbackCancel()
}
expectedRequest := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: testAC.originalToken},
Datacenter: "dc1",
// TODO (autoconf) Maybe in the future we should populate a CSR
// and do some manual parsing/verification of the contents. The
// bits not having to do with the signing key such as the requested
// SANs and CN. For now though the mockDirectRPC type will empty
// the CSR so we have to pass in an empty string to the expectation.
CSR: "",
}
// the fallback routine to perform auto-encrypt again will need to grab this
testAC.mcfg.tokens.On("AgentToken").Return(testAC.originalToken).Once()
testAC.mcfg.directRPC.On(
"RPC",
"dc1",
"autoconf",
&net.TCPAddr{IP: net.IPv4(198, 18, 23, 2), Port: 8300},
"AutoEncrypt.Sign",
&expectedRequest,
&structs.SignedResponse{}).Return(nil).Run(populateResponse).Once()
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
// after the second RPC we now will use the new certs validity period in the next run loop iteration
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
// as a cache update event.
req := cachetype.ConnectCALeafRequest{
Datacenter: "dc1",
Agent: "autoconf",
Token: testAC.originalToken,
DNSSAN: defaultDNSSANs,
IPSAN: defaultIPSANs,
}
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
CorrelationID: leafWatchID,
Result: secondCert,
Meta: cache.ResultMeta{
Index: secondCert.ModifyIndex,
},
}))
// wait for the TLS certificates to get updated
require.True(t, waitForChans(100*time.Millisecond, updatedCtx.Done()), "TLS certificates were not updated within the allotted time")
// now wait for the fallback routine to be invoked
require.True(t, waitForChans(100*time.Millisecond, fallbackCtx.Done()), "fallback routines did not get invoked within the allotted time")
}


@@ -3,9 +3,12 @@ package autoconf
import (
"context"
"net"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/go-hclog"
)
@@ -18,12 +21,35 @@ type DirectRPC interface {
RPC(dc string, node string, addr net.Addr, method string, args interface{}, reply interface{}) error
}
// CertMonitor is the interface that needs to be satisfied for AutoConfig to be able to
// setup monitoring of the Connect TLS certificate after we first get it.
type CertMonitor interface {
Update(*structs.SignedResponse) error
Start(context.Context) (<-chan struct{}, error)
Stop() bool
// Cache is an interface to represent the methods of the
// agent/cache.Cache struct that we care about
type Cache interface {
Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error
Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error
}
// ServerProvider is an interface that can be used to find one server in the local DC known to
// the agent via Gossip
type ServerProvider interface {
FindLANServer() *metadata.Server
}
// TLSConfigurator is an interface of the methods on the tlsutil.Configurator that we will require at
// runtime.
type TLSConfigurator interface {
UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error
UpdateAutoTLSCA([]string) error
UpdateAutoTLSCert(pub, priv string) error
AutoEncryptCertNotAfter() time.Time
AutoEncryptCertExpired() bool
}
// TokenStore is an interface of the methods we will need to use from the token.Store.
type TokenStore interface {
AgentToken() string
UpdateAgentToken(secret string, source token.TokenSource) bool
Notify(kind token.TokenKind) token.Notifier
StopNotify(notifier token.Notifier)
}
// Config contains all the tunables for AutoConfig
@@ -37,6 +63,10 @@ type Config struct {
// configuration. Setting this field is required.
DirectRPC DirectRPC
// ServerProvider is the interface to be used by AutoConfig to find any
// known servers during fallback operations.
ServerProvider ServerProvider
// Waiter is a RetryWaiter to be used during retrieval of the
// initial configuration. When a round of requests fails we will
// wait and eventually make another round of requests (1 round
@@ -49,14 +79,28 @@ type Config struct {
// having the test take minutes/hours to complete.
Waiter *lib.RetryWaiter
// CertMonitor is the Connect TLS Certificate Monitor to be used for ongoing
// certificate renewals and connect CA roots updates. This field is not
// strictly required but if not provided the TLS certificates retrieved
// by the AutoConfig.InitialConfiguration RPC will not be used
// or renewed.
CertMonitor CertMonitor
// Loader merges source with the existing FileSources and returns the complete
// RuntimeConfig.
Loader func(source config.Source) (cfg *config.RuntimeConfig, warnings []string, err error)
// TLSConfigurator is the shared TLS Configurator. AutoConfig will update the
// auto encrypt/auto config certs as they are renewed.
TLSConfigurator TLSConfigurator
// Cache is an object implementing our Cache interface. The Cache
// used at runtime must be able to handle Roots and Leaf Cert watches
Cache Cache
// FallbackLeeway is the amount of time after certificate expiration before
// invoking the fallback routine. If not set this will default to 10s.
FallbackLeeway time.Duration
// FallbackRetry is the duration between Fallback invocations when the configured
// fallback routine returns an error. If not set this will default to 1m.
FallbackRetry time.Duration
// Tokens is the shared token store. It is used to retrieve the current
// agent token as well as getting notifications when that token is updated.
// This field is required.
Tokens TokenStore
}
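// A minimal wiring sketch (illustrative; the names shown mirror the test
// mocks in this change, while production code supplies the real agent
// cache, token store and tlsutil.Configurator):
//
//	cfg := Config{
//		Loader:          loader.Load,
//		DirectRPC:       directRPC,
//		ServerProvider:  serverProvider,
//		Cache:           agentCache,
//		Tokens:          tokenStore,
//		TLSConfigurator: tlsConfigurator,
//		Logger:          logger,
//	}
//	ac, err := New(cfg)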


@@ -22,9 +22,9 @@ import (
// package cannot import the agent/config package without running into import cycles.
func translateConfig(c *pbconfig.Config) config.Config {
result := config.Config{
Datacenter: &c.Datacenter,
PrimaryDatacenter: &c.PrimaryDatacenter,
NodeName: &c.NodeName,
Datacenter: stringPtrOrNil(c.Datacenter),
PrimaryDatacenter: stringPtrOrNil(c.PrimaryDatacenter),
NodeName: stringPtrOrNil(c.NodeName),
// only output the SegmentName in the configuration if it's non-empty
// this will avoid a warning later when parsing the persisted configuration
SegmentName: stringPtrOrNil(c.SegmentName),
@@ -42,13 +42,13 @@ func translateConfig(c *pbconfig.Config) config.Config {
if a := c.ACL; a != nil {
result.ACL = config.ACL{
Enabled: &a.Enabled,
PolicyTTL: &a.PolicyTTL,
RoleTTL: &a.RoleTTL,
TokenTTL: &a.TokenTTL,
DownPolicy: &a.DownPolicy,
DefaultPolicy: &a.DefaultPolicy,
PolicyTTL: stringPtrOrNil(a.PolicyTTL),
RoleTTL: stringPtrOrNil(a.RoleTTL),
TokenTTL: stringPtrOrNil(a.TokenTTL),
DownPolicy: stringPtrOrNil(a.DownPolicy),
DefaultPolicy: stringPtrOrNil(a.DefaultPolicy),
EnableKeyListPolicy: &a.EnableKeyListPolicy,
DisabledTTL: &a.DisabledTTL,
DisabledTTL: stringPtrOrNil(a.DisabledTTL),
EnableTokenPersistence: &a.EnableTokenPersistence,
}
@@ -76,7 +76,7 @@ func translateConfig(c *pbconfig.Config) config.Config {
result.RetryJoinLAN = g.RetryJoinLAN
if e := c.Gossip.Encryption; e != nil {
result.EncryptKey = &e.Key
result.EncryptKey = stringPtrOrNil(e.Key)
result.EncryptVerifyIncoming = &e.VerifyIncoming
result.EncryptVerifyOutgoing = &e.VerifyOutgoing
}
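// stringPtrOrNil itself is not shown in this diff; a sketch of its assumed
// shape, for readers only:
//
//	func stringPtrOrNil(s string) *string {
//		if s == "" {
//			return nil
//		}
//		return &s
//	}
//
// Returning nil for empty values keeps zero-valued protobuf fields from
// overriding user-supplied settings when configuration sources are merged,
// and avoids the parse-time warnings mentioned above.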


@@ -1,10 +1,13 @@
package autoconf
import (
"fmt"
"testing"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
pbconfig "github.com/hashicorp/consul/proto/pbconfig"
"github.com/hashicorp/consul/proto/pbconnect"
"github.com/stretchr/testify/require"
)
@@ -16,6 +19,38 @@ func boolPointer(b bool) *bool {
return &b
}
func translateCARootToProtobuf(in *structs.CARoot) (*pbconnect.CARoot, error) {
var out pbconnect.CARoot
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
}
return &out, nil
}
func mustTranslateCARootToProtobuf(t *testing.T, in *structs.CARoot) *pbconnect.CARoot {
out, err := translateCARootToProtobuf(in)
require.NoError(t, err)
return out
}
func mustTranslateCARootsToStructs(t *testing.T, in *pbconnect.CARoots) *structs.IndexedCARoots {
out, err := translateCARootsToStructs(in)
require.NoError(t, err)
return out
}
func mustTranslateCARootsToProtobuf(t *testing.T, in *structs.IndexedCARoots) *pbconnect.CARoots {
out, err := translateCARootsToProtobuf(in)
require.NoError(t, err)
return out
}
func mustTranslateIssuedCertToProtobuf(t *testing.T, in *structs.IssuedCert) *pbconnect.IssuedCert {
out, err := translateIssuedCertToProtobuf(in)
require.NoError(t, err)
return out
}
func TestTranslateConfig(t *testing.T) {
original := pbconfig.Config{
Datacenter: "abc",
@@ -119,3 +154,9 @@ func TestTranslateConfig(t *testing.T) {
translated := translateConfig(&original)
require.Equal(t, expected, translated)
}
func TestCArootsTranslation(t *testing.T) {
_, indexedRoots, _ := testCerts(t, "autoconf", "dc1")
protoRoots := mustTranslateCARootsToProtobuf(t, indexedRoots)
require.Equal(t, indexedRoots, mustTranslateCARootsToStructs(t, protoRoots))
}


@@ -0,0 +1,337 @@
package autoconf
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/proto/pbautoconf"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/mock"
)
type mockDirectRPC struct {
mock.Mock
}
func newMockDirectRPC(t *testing.T) *mockDirectRPC {
m := mockDirectRPC{}
m.Test(t)
return &m
}
func (m *mockDirectRPC) RPC(dc string, node string, addr net.Addr, method string, args interface{}, reply interface{}) error {
var retValues mock.Arguments
if method == "AutoConfig.InitialConfiguration" {
req := args.(*pbautoconf.AutoConfigRequest)
csr := req.CSR
req.CSR = ""
retValues = m.Called(dc, node, addr, method, args, reply)
req.CSR = csr
} else if method == "AutoEncrypt.Sign" {
req := args.(*structs.CASignRequest)
csr := req.CSR
req.CSR = ""
retValues = m.Called(dc, node, addr, method, args, reply)
req.CSR = csr
} else {
retValues = m.Called(dc, node, addr, method, args, reply)
}
return retValues.Error(0)
}
type mockTLSConfigurator struct {
mock.Mock
}
func newMockTLSConfigurator(t *testing.T) *mockTLSConfigurator {
m := mockTLSConfigurator{}
m.Test(t)
return &m
}
func (m *mockTLSConfigurator) UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error {
if priv != "" {
priv = "redacted"
}
ret := m.Called(manualCAPEMs, connectCAPEMs, pub, priv, verifyServerHostname)
return ret.Error(0)
}
func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error {
ret := m.Called(pems)
return ret.Error(0)
}
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
if priv != "" {
priv = "redacted"
}
ret := m.Called(pub, priv)
return ret.Error(0)
}
func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time {
ret := m.Called()
ts, _ := ret.Get(0).(time.Time)
return ts
}
func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool {
ret := m.Called()
return ret.Bool(0)
}
type mockServerProvider struct {
mock.Mock
}
func newMockServerProvider(t *testing.T) *mockServerProvider {
m := mockServerProvider{}
m.Test(t)
return &m
}
func (m *mockServerProvider) FindLANServer() *metadata.Server {
ret := m.Called()
srv, _ := ret.Get(0).(*metadata.Server)
return srv
}
type mockWatcher struct {
ch chan<- cache.UpdateEvent
done <-chan struct{}
}
type mockCache struct {
mock.Mock
lock sync.Mutex
watchers map[string][]mockWatcher
}
func newMockCache(t *testing.T) *mockCache {
m := mockCache{
watchers: make(map[string][]mockWatcher),
}
m.Test(t)
return &m
}
func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error {
ret := m.Called(ctx, t, r, correlationID, ch)
err := ret.Error(0)
if err == nil {
m.lock.Lock()
key := r.CacheInfo().Key
m.watchers[key] = append(m.watchers[key], mockWatcher{ch: ch, done: ctx.Done()})
m.lock.Unlock()
}
return err
}
func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error {
var restore string
cert, ok := result.Value.(*structs.IssuedCert)
if ok {
// we cannot know what the private key is prior to it being injected into the cache.
// therefore redact it here and all mock expectations should take that into account
restore = cert.PrivateKeyPEM
cert.PrivateKeyPEM = "redacted"
}
ret := m.Called(t, result, dc, token, key)
if ok && restore != "" {
cert.PrivateKeyPEM = restore
}
return ret.Error(0)
}
func (m *mockCache) sendNotification(ctx context.Context, key string, u cache.UpdateEvent) bool {
m.lock.Lock()
defer m.lock.Unlock()
watchers, ok := m.watchers[key]
if !ok || len(watchers) < 1 {
return false
}
var newWatchers []mockWatcher
for _, watcher := range watchers {
select {
case watcher.ch <- u:
newWatchers = append(newWatchers, watcher)
case <-watcher.done:
// do nothing, this watcher will be removed from the list
case <-ctx.Done():
// return doesn't matter here really, the test is being cancelled
return true
}
}
// this removes any already-cancelled watchers so they are not sent to again
m.watchers[key] = newWatchers
return true
}
type mockTokenStore struct {
mock.Mock
}
func newMockTokenStore(t *testing.T) *mockTokenStore {
m := mockTokenStore{}
m.Test(t)
return &m
}
func (m *mockTokenStore) AgentToken() string {
ret := m.Called()
return ret.String(0)
}
func (m *mockTokenStore) UpdateAgentToken(secret string, source token.TokenSource) bool {
return m.Called(secret, source).Bool(0)
}
func (m *mockTokenStore) Notify(kind token.TokenKind) token.Notifier {
ret := m.Called(kind)
n, _ := ret.Get(0).(token.Notifier)
return n
}
func (m *mockTokenStore) StopNotify(notifier token.Notifier) {
m.Called(notifier)
}
type mockedConfig struct {
Config
directRPC *mockDirectRPC
serverProvider *mockServerProvider
cache *mockCache
tokens *mockTokenStore
tlsCfg *mockTLSConfigurator
}
func newMockedConfig(t *testing.T) *mockedConfig {
directRPC := newMockDirectRPC(t)
serverProvider := newMockServerProvider(t)
mcache := newMockCache(t)
tokens := newMockTokenStore(t)
tlsCfg := newMockTLSConfigurator(t)
// It is not clearly documented whether this is well-defined behavior,
// but in testing it does appear that Cleanup functions can fail tests.
// Adding the mock expectation assertions here saves us
// a bunch of code in the other test functions.
t.Cleanup(func() {
if !t.Failed() {
directRPC.AssertExpectations(t)
serverProvider.AssertExpectations(t)
mcache.AssertExpectations(t)
tokens.AssertExpectations(t)
tlsCfg.AssertExpectations(t)
}
})
return &mockedConfig{
Config: Config{
DirectRPC: directRPC,
ServerProvider: serverProvider,
Cache: mcache,
Tokens: tokens,
TLSConfigurator: tlsCfg,
Logger: testutil.Logger(t),
},
directRPC: directRPC,
serverProvider: serverProvider,
cache: mcache,
tokens: tokens,
tlsCfg: tlsCfg,
}
}
func (m *mockedConfig) expectInitialTLS(t *testing.T, agentName, datacenter, token string, ca *structs.CARoot, indexedRoots *structs.IndexedCARoots, cert *structs.IssuedCert, extraCerts []string) {
var pems []string
for _, root := range indexedRoots.Roots {
pems = append(pems, root.RootCert)
}
// we should update the TLS configurator with the proper certs
m.tlsCfg.On("UpdateAutoTLS",
extraCerts,
pems,
cert.CertPEM,
// auto-config handles the CSR and key, so our tests don't have
// a way to know whether the key is correct. We do replace
// a non-empty PEM with "redacted" so we can ensure that some
// private key material is being sent
"redacted",
true,
).Return(nil).Once()
rootRes := cache.FetchResult{Value: indexedRoots, Index: indexedRoots.QueryMeta.Index}
rootsReq := structs.DCSpecificRequest{Datacenter: datacenter}
// we should prepopulate the cache with the CA roots
m.cache.On("Prepopulate",
cachetype.ConnectCARootName,
rootRes,
datacenter,
"",
rootsReq.CacheInfo().Key,
).Return(nil).Once()
leafReq := cachetype.ConnectCALeafRequest{
Token: token,
Agent: agentName,
Datacenter: datacenter,
}
// copy the cert and redact the private key for the mock expectation
// the actual private key will not correspond to the cert, but that's
// because AutoConfig generates a key/CSR internally and sends that
// up with the request.
copy := *cert
copy.PrivateKeyPEM = "redacted"
leafRes := cache.FetchResult{
Value: &copy,
Index: copy.RaftIndex.ModifyIndex,
State: cachetype.ConnectCALeafSuccess(ca.SigningKeyID),
}
// we should prepopulate the cache with the agents cert
m.cache.On("Prepopulate",
cachetype.ConnectCALeafName,
leafRes,
datacenter,
token,
leafReq.Key(),
).Return(nil).Once()
// when prepopulating the cert in the cache we grab the token, so
// we should expect that here
m.tokens.On("AgentToken").Return(token).Once()
}
func (m *mockedConfig) setupInitialTLS(t *testing.T, agentName, datacenter, token string) (*structs.IndexedCARoots, *structs.IssuedCert, []string) {
ca, indexedRoots, cert := testCerts(t, agentName, datacenter)
ca2 := connect.TestCA(t, nil)
extraCerts := []string{ca2.RootCert}
m.expectInitialTLS(t, agentName, datacenter, token, ca, indexedRoots, cert, extraCerts)
return indexedRoots, cert, extraCerts
}


@@ -0,0 +1,86 @@
package autoconf
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/golang/protobuf/jsonpb"
"github.com/hashicorp/consul/proto/pbautoconf"
)
const (
// autoConfigFileName is the name of the file that the agent auto-config settings are
// stored in within the data directory
autoConfigFileName = "auto-config.json"
)
var (
pbMarshaler = &jsonpb.Marshaler{
OrigName: false,
EnumsAsInts: false,
Indent: " ",
EmitDefaults: true,
}
pbUnmarshaler = &jsonpb.Unmarshaler{
AllowUnknownFields: false,
}
)
func (ac *AutoConfig) readPersistedAutoConfig() (*pbautoconf.AutoConfigResponse, error) {
if ac.config.DataDir == "" {
// no data directory means we don't have anything to potentially load
return nil, nil
}
path := filepath.Join(ac.config.DataDir, autoConfigFileName)
ac.logger.Debug("attempting to restore any persisted configuration", "path", path)
content, err := ioutil.ReadFile(path)
if err == nil {
rdr := strings.NewReader(string(content))
var resp pbautoconf.AutoConfigResponse
if err := pbUnmarshaler.Unmarshal(rdr, &resp); err != nil {
return nil, fmt.Errorf("failed to decode persisted auto-config data: %w", err)
}
ac.logger.Info("read persisted configuration", "path", path)
return &resp, nil
}
if !os.IsNotExist(err) {
return nil, fmt.Errorf("failed to load %s: %w", path, err)
}
// ignore non-existence errors as that is an indicator that we haven't
// performed the auto configuration before
return nil, nil
}
func (ac *AutoConfig) persistAutoConfig(resp *pbautoconf.AutoConfigResponse) error {
// now that we know the configuration is generally fine, including TLS certs, go ahead and persist it to disk.
if ac.config.DataDir == "" {
ac.logger.Debug("not persisting auto-config settings because there is no data directory")
return nil
}
serialized, err := pbMarshaler.MarshalToString(resp)
if err != nil {
return fmt.Errorf("failed to encode auto-config response as JSON: %w", err)
}
path := filepath.Join(ac.config.DataDir, autoConfigFileName)
err = ioutil.WriteFile(path, []byte(serialized), 0660)
if err != nil {
return fmt.Errorf("failed to write auto-config configurations: %w", err)
}
ac.logger.Debug("auto-config settings were persisted to disk")
return nil
}
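// Round-trip sketch (illustrative only, using the jsonpb settings above):
//
//	serialized, err := pbMarshaler.MarshalToString(resp) // indented JSON, defaults emitted
//	// ... handle err ...
//	var restored pbautoconf.AutoConfigResponse
//	err = pbUnmarshaler.Unmarshal(strings.NewReader(serialized), &restored)
//
// Because AllowUnknownFields is false, a hand-edited auto-config.json that
// contains unrecognized keys will fail to load rather than silently dropping
// data.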

192
agent/auto-config/run.go Normal file

@@ -0,0 +1,192 @@
package autoconf
import (
"context"
"fmt"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// handleCacheEvent is used to handle event notifications from the cache for the roots
// or leaf cert watches.
func (ac *AutoConfig) handleCacheEvent(u cache.UpdateEvent) error {
switch u.CorrelationID {
case rootsWatchID:
ac.logger.Debug("roots watch fired - updating CA certificates")
if u.Err != nil {
return fmt.Errorf("root watch returned an error: %w", u.Err)
}
roots, ok := u.Result.(*structs.IndexedCARoots)
if !ok {
return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
}
return ac.updateCARoots(roots)
case leafWatchID:
ac.logger.Debug("leaf certificate watch fired - updating TLS certificate")
if u.Err != nil {
return fmt.Errorf("leaf watch returned an error: %w", u.Err)
}
leaf, ok := u.Result.(*structs.IssuedCert)
if !ok {
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
}
return ac.updateLeafCert(leaf)
}
return nil
}
// handleTokenUpdate is used when a notification about the agent token being updated
// is received and various watches need cancelling/restarting to use the new token.
func (ac *AutoConfig) handleTokenUpdate(ctx context.Context) error {
ac.logger.Debug("Agent token updated - resetting watches")
// TODO (autoencrypt) Prepopulate the cache with the new token with
// the existing cache entry with the old token. The certificate doesn't
// need to change just because the token has. However there isn't a
// good way to make that happen and this behavior is benign enough
// that I am going to push off implementing it.
// the agent token has been updated so we must update our leaf cert watch.
// this cancels the current watches before setting up new ones
ac.cancelWatches()
// recreate the chan for cache updates. This is a precautionary measure to ensure
// that we don't accidentally get notified for the new watches being setup before
// a blocking query in the cache returns and sends data to the old chan. In theory
// the code in agent/cache/watch.go should prevent this where we specifically check
// for context cancellation prior to sending the event. However we could cancel
// it after that check and finish setting up the new watches before getting the old
// events. Both the go routine scheduler and the OS thread scheduler would have to
// be acting up for this to happen. Regardless the way to ensure we don't get events
// for the old watches is to simply replace the chan we are expecting them from.
close(ac.cacheUpdates)
ac.cacheUpdates = make(chan cache.UpdateEvent, 10)
// restart watches - this will be done with the correct token
cancelWatches, err := ac.setupCertificateCacheWatches(ctx)
if err != nil {
return fmt.Errorf("failed to restart watches after agent token update: %w", err)
}
ac.cancelWatches = cancelWatches
return nil
}
// handleFallback is used when the current TLS certificate has expired and the normal
// updating mechanisms have failed to renew it quickly enough. This function will
// use the configured fallback mechanism to retrieve a new cert and start monitoring
// that one.
func (ac *AutoConfig) handleFallback(ctx context.Context) error {
ac.logger.Warn("agent's client certificate has expired")
// The context passed in comes from the long-running run loop, so these retrieval loops are cancelled when auto-config stops.
switch {
case ac.config.AutoConfig.Enabled:
resp, err := ac.getInitialConfiguration(ctx)
if err != nil {
return fmt.Errorf("error while retrieving new agent certificates via auto-config: %w", err)
}
return ac.recordInitialConfiguration(resp)
case ac.config.AutoEncryptTLS:
reply, err := ac.autoEncryptInitialCerts(ctx)
if err != nil {
return fmt.Errorf("error while retrieving new agent certificate via auto-encrypt: %w", err)
}
return ac.setInitialTLSCertificates(reply)
default:
return fmt.Errorf("logic error: either auto-encrypt or auto-config must be enabled")
}
}
// run is the private method to be spawn by the Start method for
// executing the main monitoring loop.
func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) {
// The fallbackTimer is used to notify AFTER the agent's
// leaf certificate has expired, at which point we need
// to fall back to the less secure RPC endpoint just as
// if the agent were starting up fresh.
//
// Check 10sec (fallback leeway duration) after cert
// expires. The agent cache should be handling the expiration
// and renew it before then.
//
// If there is no cert, AutoEncryptCertNotAfter returns
// a value in the past which immediately triggers the
// renew, but this case shouldn't happen because at
// this point, auto_encrypt was just set up
// successfully.
calcFallbackInterval := func() time.Duration {
certExpiry := ac.acConfig.TLSConfigurator.AutoEncryptCertNotAfter()
return certExpiry.Add(ac.acConfig.FallbackLeeway).Sub(time.Now())
}
fallbackTimer := time.NewTimer(calcFallbackInterval())
// cleanup for once we are stopped
defer func() {
// cancel the go routines performing the cache watches
ac.cancelWatches()
// ensure we don't leak the timers go routine
fallbackTimer.Stop()
// stop receiving notifications for token updates
ac.acConfig.Tokens.StopNotify(ac.tokenUpdates)
ac.logger.Debug("auto-config has been stopped")
ac.Lock()
ac.cancel = nil
ac.running = false
// this should be the final cleanup task as its what notifies
// the rest of the world that this go routine has exited.
close(exit)
ac.Unlock()
}()
for {
select {
case <-ctx.Done():
ac.logger.Debug("stopping auto-config")
return
case <-ac.tokenUpdates.Ch:
ac.logger.Debug("handling a token update event")
if err := ac.handleTokenUpdate(ctx); err != nil {
ac.logger.Error("error in handling token update event", "error", err)
}
case u := <-ac.cacheUpdates:
ac.logger.Debug("handling a cache update event", "correlation_id", u.CorrelationID)
if err := ac.handleCacheEvent(u); err != nil {
ac.logger.Error("error in handling cache update event", "error", err)
}
// reset the fallback timer as the certificate may have been updated
fallbackTimer.Stop()
fallbackTimer = time.NewTimer(calcFallbackInterval())
case <-fallbackTimer.C:
// This is a safety net in case the cert doesn't get renewed
// in time. The agent would be stuck in that case because the watches
// never use the AutoEncrypt.Sign endpoint.
// check auto encrypt client cert expiration
if ac.acConfig.TLSConfigurator.AutoEncryptCertExpired() {
if err := ac.handleFallback(ctx); err != nil {
ac.logger.Error("error when handling a certificate expiry event", "error", err)
fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry)
} else {
fallbackTimer = time.NewTimer(calcFallbackInterval())
}
} else {
// this shouldn't be possible. We calculate the timer duration to be the certificate
// expiration time + some leeway (10s default). So whenever we get here the certificate
// should be expired. Regardless, it's probably worth resetting the timer.
fallbackTimer = time.NewTimer(calcFallbackInterval())
}
}
}
}


@@ -0,0 +1,111 @@
package autoconf
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/go-discover"
discoverk8s "github.com/hashicorp/go-discover/provider/k8s"
"github.com/hashicorp/go-hclog"
)
func (ac *AutoConfig) discoverServers(servers []string) ([]string, error) {
providers := make(map[string]discover.Provider)
for k, v := range discover.Providers {
providers[k] = v
}
providers["k8s"] = &discoverk8s.Provider{}
disco, err := discover.New(
discover.WithUserAgent(lib.UserAgent()),
discover.WithProviders(providers),
)
if err != nil {
return nil, fmt.Errorf("Failed to create go-discover resolver: %w", err)
}
var addrs []string
for _, addr := range servers {
switch {
case strings.Contains(addr, "provider="):
resolved, err := disco.Addrs(addr, ac.logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}))
if err != nil {
ac.logger.Error("failed to resolve go-discover auto-config servers", "configuration", addr, "err", err)
continue
}
addrs = append(addrs, resolved...)
ac.logger.Debug("discovered auto-config servers", "servers", resolved)
default:
addrs = append(addrs, addr)
}
}
return addrs, nil
}
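go-discover resolves configuration strings of the form `provider=<name> key=value ...`, which is what the strings.Contains check above keys off of. The sketch below makes the same Addrs call directly; the AWS provider and tag key/value are illustrative assumptions, not taken from this change.

```go
package main

import (
	"log"
	"os"

	discover "github.com/hashicorp/go-discover"
)

func main() {
	d, err := discover.New(discover.WithUserAgent("example-agent"))
	if err != nil {
		log.Fatal(err)
	}

	// Hypothetical provider config in the same "provider=..." shape that
	// discoverServers detects.
	cfg := "provider=aws tag_key=consul tag_value=server"
	addrs, err := d.Addrs(cfg, log.New(os.Stderr, "", log.LstdFlags))
	if err != nil {
		log.Fatal(err)
	}
	log.Println("discovered servers:", addrs)
}
```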
// autoConfigHosts is responsible for taking the list of server addresses
// and resolving any go-discover provider invocations. It will then return
// a list of hosts. These might be hostnames, and it is expected that DNS
// resolution will be performed after this function runs. Additionally,
// these may contain ports, so SplitHostPort may also be necessary.
func (ac *AutoConfig) autoConfigHosts() ([]string, error) {
// use servers known to gossip if there are any
if ac.acConfig.ServerProvider != nil {
if srv := ac.acConfig.ServerProvider.FindLANServer(); srv != nil {
return []string{srv.Addr.String()}, nil
}
}
addrs, err := ac.discoverServers(ac.config.AutoConfig.ServerAddresses)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("no auto-config server addresses available for use")
}
return addrs, nil
}
// resolveHost will take a single host string and convert it to a list of TCPAddrs.
// It will process any port in the input as well as look up the hostname using
// normal DNS resolution.
func (ac *AutoConfig) resolveHost(hostPort string) []net.TCPAddr {
port := ac.config.ServerPort
host, portStr, err := net.SplitHostPort(hostPort)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
host = hostPort
} else {
ac.logger.Warn("error splitting host address into IP and port", "address", hostPort, "error", err)
return nil
}
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
ac.logger.Warn("Parsed port is not an integer", "port", portStr, "error", err)
return nil
}
}
// resolve the host to a list of IPs
ips, err := net.LookupIP(host)
if err != nil {
ac.logger.Warn("IP resolution failed", "host", host, "error", err)
return nil
}
var addrs []net.TCPAddr
for _, ip := range ips {
addrs = append(addrs, net.TCPAddr{IP: ip, Port: port})
}
return addrs
}
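To make the port handling concrete, this standalone sketch mirrors resolveHost's SplitHostPort fallback, with 8300 hard-coded as an assumed stand-in for the configured ServerPort:

```go
package main

import (
	"fmt"
	"net"
	"strconv"
	"strings"
)

// resolveWithDefault mirrors resolveHost's port handling; the default port
// of 8300 is an assumption standing in for config.ServerPort.
func resolveWithDefault(hostPort string) (string, int, error) {
	port := 8300
	host, portStr, err := net.SplitHostPort(hostPort)
	if err != nil {
		if !strings.Contains(err.Error(), "missing port in address") {
			return "", 0, err
		}
		host = hostPort // bare host: keep the default port
	} else if port, err = strconv.Atoi(portStr); err != nil {
		return "", 0, err
	}
	return host, port, nil
}

func main() {
	for _, in := range []string{"198.51.100.1:8501", "consul.example.com"} {
		host, port, err := resolveWithDefault(in)
		fmt.Println(in, "->", host, port, err)
	}
}
```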

agent/auto-config/tls.go Normal file

@ -0,0 +1,280 @@
package autoconf
import (
"context"
"fmt"
"net"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbautoconf"
)
const (
// ID of the roots watch
rootsWatchID = "roots"
// ID of the leaf watch
leafWatchID = "leaf"
unknownTrustDomain = "unknown"
)
var (
defaultDNSSANs = []string{"localhost"}
defaultIPSANs = []net.IP{{127, 0, 0, 1}, net.ParseIP("::1")}
)
func extractPEMs(roots *structs.IndexedCARoots) []string {
var pems []string
for _, root := range roots.Roots {
pems = append(pems, root.RootCert)
}
return pems
}
// updateTLSFromResponse will update the TLS certificate and roots in the shared
// TLS configurator.
func (ac *AutoConfig) updateTLSFromResponse(resp *pbautoconf.AutoConfigResponse) error {
var pems []string
for _, root := range resp.GetCARoots().GetRoots() {
pems = append(pems, root.RootCert)
}
err := ac.acConfig.TLSConfigurator.UpdateAutoTLS(
resp.ExtraCACertificates,
pems,
resp.Certificate.GetCertPEM(),
resp.Certificate.GetPrivateKeyPEM(),
resp.Config.GetTLS().GetVerifyServerHostname(),
)
if err != nil {
return fmt.Errorf("Failed to update the TLS configurator with new certificates: %w", err)
}
return nil
}
func (ac *AutoConfig) setInitialTLSCertificates(certs *structs.SignedResponse) error {
if certs == nil {
return nil
}
if err := ac.populateCertificateCache(certs); err != nil {
return fmt.Errorf("error populating cache with certificates: %w", err)
}
connectCAPems := extractPEMs(&certs.ConnectCARoots)
err := ac.acConfig.TLSConfigurator.UpdateAutoTLS(
certs.ManualCARoots,
connectCAPems,
certs.IssuedCert.CertPEM,
certs.IssuedCert.PrivateKeyPEM,
certs.VerifyServerHostname,
)
if err != nil {
return fmt.Errorf("error updating TLS configurator with certificates: %w", err)
}
return nil
}
func (ac *AutoConfig) populateCertificateCache(certs *structs.SignedResponse) error {
cert, err := connect.ParseCert(certs.IssuedCert.CertPEM)
if err != nil {
return fmt.Errorf("Failed to parse certificate: %w", err)
}
// prepopulate the roots cache
rootRes := cache.FetchResult{Value: &certs.ConnectCARoots, Index: certs.ConnectCARoots.QueryMeta.Index}
rootsReq := ac.caRootsRequest()
// getting the roots doesn't require a token, so in order to potentially share the cache entry with other requests an empty token is used here
if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCARootName, rootRes, ac.config.Datacenter, "", rootsReq.CacheInfo().Key); err != nil {
return err
}
leafReq := ac.leafCertRequest()
// prepopulate the leaf cache
certRes := cache.FetchResult{
Value: &certs.IssuedCert,
Index: certs.IssuedCert.RaftIndex.ModifyIndex,
State: cachetype.ConnectCALeafSuccess(connect.EncodeSigningKeyID(cert.AuthorityKeyId)),
}
if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCALeafName, certRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()); err != nil {
return err
}
return nil
}
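Seeding the cache this way lets the first blocking query continue from the certificates the agent already holds instead of refetching them. A hedged sketch of the roots half of that call, with an assumed datacenter:

```go
package example

import (
	"github.com/hashicorp/consul/agent/cache"
	cachetype "github.com/hashicorp/consul/agent/cache-types"
	"github.com/hashicorp/consul/agent/structs"
)

// seedRoots mirrors the roots Prepopulate call above; "dc1" is an assumed
// datacenter, not taken from this change.
func seedRoots(c *cache.Cache, roots *structs.IndexedCARoots) error {
	res := cache.FetchResult{Value: roots, Index: roots.QueryMeta.Index}
	req := structs.DCSpecificRequest{Datacenter: "dc1"}

	// Empty token, as in the code above: fetching roots requires no ACL
	// token, so the cache entry can be shared with other requests.
	return c.Prepopulate(cachetype.ConnectCARootName, res, "dc1", "", req.CacheInfo().Key)
}
```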
func (ac *AutoConfig) setupCertificateCacheWatches(ctx context.Context) (context.CancelFunc, error) {
notificationCtx, cancel := context.WithCancel(ctx)
rootsReq := ac.caRootsRequest()
err := ac.acConfig.Cache.Notify(notificationCtx, cachetype.ConnectCARootName, &rootsReq, rootsWatchID, ac.cacheUpdates)
if err != nil {
cancel()
return nil, err
}
leafReq := ac.leafCertRequest()
err = ac.acConfig.Cache.Notify(notificationCtx, cachetype.ConnectCALeafName, &leafReq, leafWatchID, ac.cacheUpdates)
if err != nil {
cancel()
return nil, err
}
return cancel, nil
}
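Both watches rely on the cache's Notify pattern: a background blocking-query loop pushes UpdateEvents tagged with a correlation ID onto one shared channel until the context is cancelled. A rough sketch of the roots watch in isolation (the datacenter and buffer size are illustrative):

```go
package example

import (
	"context"
	"fmt"

	"github.com/hashicorp/consul/agent/cache"
	cachetype "github.com/hashicorp/consul/agent/cache-types"
	"github.com/hashicorp/consul/agent/structs"
)

// watchRoots sketches the Notify pattern used above: events tagged with the
// "roots" correlation ID arrive on the channel until ctx is cancelled.
func watchRoots(ctx context.Context, c *cache.Cache) error {
	updates := make(chan cache.UpdateEvent, 10)
	req := structs.DCSpecificRequest{Datacenter: "dc1"} // assumed datacenter
	if err := c.Notify(ctx, cachetype.ConnectCARootName, &req, "roots", updates); err != nil {
		return err
	}
	for {
		select {
		case <-ctx.Done():
			return nil
		case u := <-updates:
			fmt.Println("update for watch:", u.CorrelationID)
		}
	}
}
```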
func (ac *AutoConfig) updateCARoots(roots *structs.IndexedCARoots) error {
switch {
case ac.config.AutoConfig.Enabled:
ac.Lock()
defer ac.Unlock()
var err error
ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(roots)
if err != nil {
return err
}
if err := ac.updateTLSFromResponse(ac.autoConfigResponse); err != nil {
return err
}
return ac.persistAutoConfig(ac.autoConfigResponse)
case ac.config.AutoEncryptTLS:
pems := extractPEMs(roots)
if err := ac.acConfig.TLSConfigurator.UpdateAutoTLSCA(pems); err != nil {
return fmt.Errorf("failed to update Connect CA certificates: %w", err)
}
return nil
default:
return nil
}
}
func (ac *AutoConfig) updateLeafCert(cert *structs.IssuedCert) error {
switch {
case ac.config.AutoConfig.Enabled:
ac.Lock()
defer ac.Unlock()
var err error
ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(cert)
if err != nil {
return err
}
if err := ac.updateTLSFromResponse(ac.autoConfigResponse); err != nil {
return err
}
return ac.persistAutoConfig(ac.autoConfigResponse)
case ac.config.AutoEncryptTLS:
if err := ac.acConfig.TLSConfigurator.UpdateAutoTLSCert(cert.CertPEM, cert.PrivateKeyPEM); err != nil {
return fmt.Errorf("failed to update the agent leaf cert: %w", err)
}
return nil
default:
return nil
}
}
func (ac *AutoConfig) caRootsRequest() structs.DCSpecificRequest {
return structs.DCSpecificRequest{Datacenter: ac.config.Datacenter}
}
func (ac *AutoConfig) leafCertRequest() cachetype.ConnectCALeafRequest {
return cachetype.ConnectCALeafRequest{
Datacenter: ac.config.Datacenter,
Agent: ac.config.NodeName,
DNSSAN: ac.getDNSSANs(),
IPSAN: ac.getIPSANs(),
Token: ac.acConfig.Tokens.AgentToken(),
}
}
// generateCSR will generate a CSR for an Agent certificate. This should
// be sent along with the AutoConfig.InitialConfiguration RPC or the
// AutoEncrypt.Sign RPC. The generated CSR does NOT have a real trust domain
// because when generating it we do not yet have the CA roots. The server will
// update the trust domain for us though.
func (ac *AutoConfig) generateCSR() (csr string, key string, err error) {
// We don't provide the correct host here, because we don't know any
// better at this point. Apart from the domain, we would need the
// ClusterID, which we don't have. This is why we go with
// unknownTrustDomain the first time. Subsequent CSRs will have the
// correct TrustDomain.
id := &connect.SpiffeIDAgent{
// will be replaced
Host: unknownTrustDomain,
Datacenter: ac.config.Datacenter,
Agent: ac.config.NodeName,
}
caConfig, err := ac.config.ConnectCAConfiguration()
if err != nil {
return "", "", fmt.Errorf("Cannot generate CSR: %w", err)
}
conf, err := caConfig.GetCommonConfig()
if err != nil {
return "", "", fmt.Errorf("Failed to load common CA configuration: %w", err)
}
if conf.PrivateKeyType == "" {
conf.PrivateKeyType = connect.DefaultPrivateKeyType
}
if conf.PrivateKeyBits == 0 {
conf.PrivateKeyBits = connect.DefaultPrivateKeyBits
}
// Create a new private key
pk, pkPEM, err := connect.GeneratePrivateKeyWithConfig(conf.PrivateKeyType, conf.PrivateKeyBits)
if err != nil {
return "", "", fmt.Errorf("Failed to generate private key: %w", err)
}
dnsNames := ac.getDNSSANs()
ipAddresses := ac.getIPSANs()
// Create a CSR.
//
// The Common Name includes the dummy trust domain for now, but the server
// will override it when the CSR is signed, so that's OK.
cn := connect.AgentCN(ac.config.NodeName, unknownTrustDomain)
csr, err = connect.CreateCSR(id, cn, pk, dnsNames, ipAddresses)
if err != nil {
return "", "", err
}
return csr, pkPEM, nil
}
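As a rough illustration of the SPIFFE ID the CSR starts from, the sketch below builds the same agent identity with hypothetical values; the URI layout in the trailing comment reflects Consul's agent SPIFFE form as understood here and should be treated as an assumption.

```go
package example

import (
	"fmt"

	"github.com/hashicorp/consul/agent/connect"
)

// printAgentID builds an agent SPIFFE ID like generateCSR does; the concrete
// datacenter and node name are hypothetical.
func printAgentID() {
	id := &connect.SpiffeIDAgent{
		Host:       "unknown", // dummy trust domain, replaced when signed
		Datacenter: "dc1",
		Agent:      "node-1",
	}
	fmt.Println(id.URI().String())
	// expected shape: spiffe://unknown/agent/client/dc/dc1/id/node-1
}
```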
func (ac *AutoConfig) getDNSSANs() []string {
sans := defaultDNSSANs
switch {
case ac.config.AutoConfig.Enabled:
sans = append(sans, ac.config.AutoConfig.DNSSANs...)
case ac.config.AutoEncryptTLS:
sans = append(sans, ac.config.AutoEncryptDNSSAN...)
}
return sans
}
func (ac *AutoConfig) getIPSANs() []net.IP {
sans := defaultIPSANs
switch {
case ac.config.AutoConfig.Enabled:
sans = append(sans, ac.config.AutoConfig.IPSANs...)
case ac.config.AutoEncryptTLS:
sans = append(sans, ac.config.AutoEncryptIPSAN...)
}
return sans
}


@ -0,0 +1,56 @@
package autoconf
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
)
func newLeaf(t *testing.T, agentName, datacenter string, ca *structs.CARoot, idx uint64, expiration time.Duration) *structs.IssuedCert {
t.Helper()
pub, priv, err := connect.TestAgentLeaf(t, agentName, datacenter, ca, expiration)
require.NoError(t, err)
cert, err := connect.ParseCert(pub)
require.NoError(t, err)
spiffeID, err := connect.ParseCertURI(cert.URIs[0])
require.NoError(t, err)
agentID, ok := spiffeID.(*connect.SpiffeIDAgent)
require.True(t, ok, "certificate doesn't have an agent leaf cert URI")
return &structs.IssuedCert{
SerialNumber: cert.SerialNumber.String(),
CertPEM: pub,
PrivateKeyPEM: priv,
ValidAfter: cert.NotBefore,
ValidBefore: cert.NotAfter,
Agent: agentID.Agent,
AgentURI: agentID.URI().String(),
EnterpriseMeta: *structs.DefaultEnterpriseMeta(),
RaftIndex: structs.RaftIndex{
CreateIndex: idx,
ModifyIndex: idx,
},
}
}
func testCerts(t *testing.T, agentName, datacenter string) (*structs.CARoot, *structs.IndexedCARoots, *structs.IssuedCert) {
ca := connect.TestCA(t, nil)
ca.IntermediateCerts = make([]string, 0)
cert := newLeaf(t, agentName, datacenter, ca, 1, 10*time.Minute)
indexedRoots := structs.IndexedCARoots{
ActiveRootID: ca.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
ca,
},
QueryMeta: structs.QueryMeta{Index: 1},
}
return ca, &indexedRoots, cert
}

File diff suppressed because one or more lines are too long

agent/cache/cache.go vendored

@ -144,16 +144,26 @@ type Options struct {
EntryFetchRate rate.Limit
}
// New creates a new cache with the given RPC client and reasonable defaults.
// Further settings can be tweaked on the returned value.
func New(options Options) *Cache {
// Equal returns true if both options are equivalent
func (o Options) Equal(other Options) bool {
return o.EntryFetchMaxBurst == other.EntryFetchMaxBurst && o.EntryFetchRate == other.EntryFetchRate
}
// applyDefaultValuesOnOptions sets default values on options and returns the updated value
func applyDefaultValuesOnOptions(options Options) Options {
if options.EntryFetchRate == 0.0 {
options.EntryFetchRate = DefaultEntryFetchRate
}
if options.EntryFetchMaxBurst == 0 {
options.EntryFetchMaxBurst = DefaultEntryFetchMaxBurst
}
return options
}
// New creates a new cache with the given RPC client and reasonable defaults.
// Further settings can be tweaked on the returned value.
func New(options Options) *Cache {
options = applyDefaultValuesOnOptions(options)
// Initialize the heap. The buffer of 1 is really important because
// it's possible for the expiry loop to trigger the heap to update
// itself and it'd block forever otherwise.
@ -234,6 +244,28 @@ func (c *Cache) RegisterType(n string, typ Type) {
c.types[n] = typeEntry{Name: n, Type: typ, Opts: &opts}
}
// ReloadOptions updates the cache with the new options
// It returns true if the Cache was updated, false if it was already up to date.
func (c *Cache) ReloadOptions(options Options) bool {
options = applyDefaultValuesOnOptions(options)
modified := !options.Equal(c.options)
if modified {
c.entriesLock.RLock()
defer c.entriesLock.RUnlock()
for _, entry := range c.entries {
if c.options.EntryFetchRate != options.EntryFetchRate {
entry.FetchRateLimiter.SetLimit(options.EntryFetchRate)
}
if c.options.EntryFetchMaxBurst != options.EntryFetchMaxBurst {
entry.FetchRateLimiter.SetBurst(options.EntryFetchMaxBurst)
}
}
c.options.EntryFetchRate = options.EntryFetchRate
c.options.EntryFetchMaxBurst = options.EntryFetchMaxBurst
}
return modified
}
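A minimal usage sketch of ReloadOptions (values illustrative): reloading identical options is a no-op that returns false, while changed options are pushed down to every existing entry's rate limiter.

```go
package main

import (
	"fmt"

	"github.com/hashicorp/consul/agent/cache"
	"golang.org/x/time/rate"
)

func main() {
	c := cache.New(cache.Options{EntryFetchRate: rate.Limit(10), EntryFetchMaxBurst: 2})

	// Same options: nothing to do, returns false.
	fmt.Println(c.ReloadOptions(cache.Options{EntryFetchRate: rate.Limit(10), EntryFetchMaxBurst: 2}))
	// New rate and burst: existing entries are updated, returns true.
	fmt.Println(c.ReloadOptions(cache.Options{EntryFetchRate: rate.Limit(100), EntryFetchMaxBurst: 4}))
}
```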
// Get loads the data for the given type and request. If data satisfying the
// minimum index is present in the cache, it is returned immediately. Otherwise,
// this will block until the data is available or the request timeout is


@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
// Test a basic Get with no indexes (and therefore no blocking queries).
@ -1220,6 +1221,64 @@ func TestCacheGet_nonBlockingType(t *testing.T) {
typ.AssertExpectations(t)
}
// Test that cache options can be reloaded at runtime and that existing
// entries have their fetch rate limiters updated accordingly.
func TestCacheReload(t *testing.T) {
t.Parallel()
typ1 := TestType(t)
defer typ1.AssertExpectations(t)
c := New(Options{EntryFetchRate: rate.Limit(1), EntryFetchMaxBurst: 1})
c.RegisterType("t1", typ1)
typ1.Mock.On("Fetch", mock.Anything, mock.Anything).Return(FetchResult{Value: 42, Index: 42}, nil).Maybe()
require.False(t, c.ReloadOptions(Options{EntryFetchRate: rate.Limit(1), EntryFetchMaxBurst: 1}), "Value should not be reloaded")
_, meta, err := c.Get(context.Background(), "t1", TestRequest(t, RequestInfo{Key: "hello1", MinIndex: uint64(1)}))
require.NoError(t, err)
require.Equal(t, meta.Index, uint64(42))
testEntry := func(t *testing.T, doTest func(t *testing.T, entry cacheEntry)) {
c.entriesLock.Lock()
tEntry, ok := c.types["t1"]
require.True(t, ok)
keyName := makeEntryKey("t1", "", "", "hello1")
ok, entryValid, entry := c.getEntryLocked(tEntry, keyName, RequestInfo{})
require.True(t, ok)
require.True(t, entryValid)
doTest(t, entry)
c.entriesLock.Unlock()
}
testEntry(t, func(t *testing.T, entry cacheEntry) {
require.Equal(t, entry.FetchRateLimiter.Limit(), rate.Limit(1))
require.Equal(t, entry.FetchRateLimiter.Burst(), 1)
})
// Modify only rateLimit
require.True(t, c.ReloadOptions(Options{EntryFetchRate: rate.Limit(100), EntryFetchMaxBurst: 1}))
testEntry(t, func(t *testing.T, entry cacheEntry) {
require.Equal(t, entry.FetchRateLimiter.Limit(), rate.Limit(100))
require.Equal(t, entry.FetchRateLimiter.Burst(), 1)
})
// Modify only Burst
require.True(t, c.ReloadOptions(Options{EntryFetchRate: rate.Limit(100), EntryFetchMaxBurst: 5}))
testEntry(t, func(t *testing.T, entry cacheEntry) {
require.Equal(t, entry.FetchRateLimiter.Limit(), rate.Limit(100))
require.Equal(t, entry.FetchRateLimiter.Burst(), 5)
})
// Modify only Burst and Limit at the same time
require.True(t, c.ReloadOptions(Options{EntryFetchRate: rate.Limit(1000), EntryFetchMaxBurst: 42}))
testEntry(t, func(t *testing.T, entry cacheEntry) {
require.Equal(t, entry.FetchRateLimiter.Limit(), rate.Limit(1000))
require.Equal(t, entry.FetchRateLimiter.Burst(), 42)
})
}
// TestCacheThrottle checks the assumptions for the cache throttling. It sets
// up a cache with Options{EntryFetchRate: 10.0, EntryFetchMaxBurst: 1}, which
// allows for 10req/s, or one request every 100ms.


@ -60,7 +60,7 @@ func TestCacheNotifyChResult(t testing.T, ch <-chan UpdateEvent, expected ...Upd
}
got := make([]UpdateEvent, 0, expectLen)
timeoutCh := time.After(50 * time.Millisecond)
timeoutCh := time.After(75 * time.Millisecond)
OUT:
for {
@ -74,7 +74,7 @@ OUT:
}
case <-timeoutCh:
t.Fatalf("got %d results on chan in 50ms, want %d", len(got), expectLen)
t.Fatalf("timeout while waiting for result: got %d results on chan, want %d", len(got), expectLen)
}
}


@ -258,7 +258,7 @@ func TestCacheNotifyPolling(t *testing.T) {
}
require.Equal(events[0].Result, 42)
require.Equal(events[0].Meta.Hit, false)
require.Equal(events[0].Meta.Hit && events[1].Meta.Hit, false)
require.Equal(events[0].Meta.Index, uint64(1))
require.True(events[0].Meta.Age < 50*time.Millisecond)
require.NoError(events[0].Err)


@ -1,505 +0,0 @@
package certmon
import (
"context"
"fmt"
"io/ioutil"
"sync"
"time"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
)
const (
// ID of the roots watch
rootsWatchID = "roots"
// ID of the leaf watch
leafWatchID = "leaf"
)
// Cache is an interface to represent the methods of the
// agent/cache.Cache struct that we care about
type Cache interface {
Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error
Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error
}
// CertMonitor will setup the proper watches to ensure that
// the Agent's Connect TLS certificate remains up to date
type CertMonitor struct {
logger hclog.Logger
cache Cache
tlsConfigurator *tlsutil.Configurator
tokens *token.Store
leafReq cachetype.ConnectCALeafRequest
rootsReq structs.DCSpecificRequest
persist PersistFunc
fallback FallbackFunc
fallbackLeeway time.Duration
fallbackRetry time.Duration
l sync.Mutex
running bool
// cancel is used to cancel the entire CertMonitor
// go routine. This is the main field protected
// by the mutex as it being non-nil indicates that
// the go routine has been started and is stoppable.
// Note that it doesn't indicate that the go routine
// is currently running.
cancel context.CancelFunc
// cancelWatches is used to cancel the existing
// cache watches. This is mainly only necessary
// when the Agent token changes
cancelWatches context.CancelFunc
// cacheUpdates is the chan used to have the cache
// send us back events
cacheUpdates chan cache.UpdateEvent
// tokenUpdates is the struct used to receive
// events from the token store when the Agent
// token is updated.
tokenUpdates token.Notifier
// this is used to keep a local copy of the certs
// keys and ca certs. It will be used to persist
// all of the local state at once.
certs structs.SignedResponse
}
// New creates a new CertMonitor for automatically rotating
// an Agent's Connect Certificate
func New(config *Config) (*CertMonitor, error) {
logger := config.Logger
if logger == nil {
logger = hclog.New(&hclog.LoggerOptions{
Level: 0,
Output: ioutil.Discard,
})
}
if config.FallbackLeeway == 0 {
config.FallbackLeeway = 10 * time.Second
}
if config.FallbackRetry == 0 {
config.FallbackRetry = time.Minute
}
if config.Cache == nil {
return nil, fmt.Errorf("CertMonitor creation requires a Cache")
}
if config.TLSConfigurator == nil {
return nil, fmt.Errorf("CertMonitor creation requires a TLS Configurator")
}
if config.Fallback == nil {
return nil, fmt.Errorf("CertMonitor creation requires specifying a FallbackFunc")
}
if config.Datacenter == "" {
return nil, fmt.Errorf("CertMonitor creation requires specifying the datacenter")
}
if config.NodeName == "" {
return nil, fmt.Errorf("CertMonitor creation requires specifying the agent's node name")
}
if config.Tokens == nil {
return nil, fmt.Errorf("CertMonitor creation requires specifying a token store")
}
return &CertMonitor{
logger: logger,
cache: config.Cache,
tokens: config.Tokens,
tlsConfigurator: config.TLSConfigurator,
persist: config.Persist,
fallback: config.Fallback,
fallbackLeeway: config.FallbackLeeway,
fallbackRetry: config.FallbackRetry,
rootsReq: structs.DCSpecificRequest{Datacenter: config.Datacenter},
leafReq: cachetype.ConnectCALeafRequest{
Datacenter: config.Datacenter,
Agent: config.NodeName,
DNSSAN: config.DNSSANs,
IPSAN: config.IPSANs,
},
}, nil
}
// Update is responsible for priming the cache with the certificates
// as well as injecting them into the TLS configurator
func (m *CertMonitor) Update(certs *structs.SignedResponse) error {
if certs == nil {
return nil
}
m.certs = *certs
if err := m.populateCache(certs); err != nil {
return fmt.Errorf("error populating cache with certificates: %w", err)
}
connectCAPems := []string{}
for _, ca := range certs.ConnectCARoots.Roots {
connectCAPems = append(connectCAPems, ca.RootCert)
}
// Note that it's expected that the private key be within the IssuedCert in the
// SignedResponse. This isn't how a server would send back the response and requires
// that the recipient of the response who also has access to the private key will
// have filled it in. The Cache definitely does this but auto-encrypt/auto-config
// will need to ensure the original response is setup this way too.
err := m.tlsConfigurator.UpdateAutoTLS(
certs.ManualCARoots,
connectCAPems,
certs.IssuedCert.CertPEM,
certs.IssuedCert.PrivateKeyPEM,
certs.VerifyServerHostname)
if err != nil {
return fmt.Errorf("error updating TLS configurator with certificates: %w", err)
}
return nil
}
// populateCache is responsible for inserting the certificates into the cache
func (m *CertMonitor) populateCache(resp *structs.SignedResponse) error {
cert, err := connect.ParseCert(resp.IssuedCert.CertPEM)
if err != nil {
return fmt.Errorf("Failed to parse certificate: %w", err)
}
// prepopulate the roots cache
rootRes := cache.FetchResult{Value: &resp.ConnectCARoots, Index: resp.ConnectCARoots.QueryMeta.Index}
// getting the roots doesn't require a token, so in order to potentially share the cache entry with other requests an empty token is used here
if err := m.cache.Prepopulate(cachetype.ConnectCARootName, rootRes, m.rootsReq.Datacenter, "", m.rootsReq.CacheInfo().Key); err != nil {
return err
}
// copy the template and update the token
leafReq := m.leafReq
leafReq.Token = m.tokens.AgentToken()
// prepopulate the leaf cache
certRes := cache.FetchResult{
Value: &resp.IssuedCert,
Index: resp.ConnectCARoots.QueryMeta.Index,
State: cachetype.ConnectCALeafSuccess(connect.EncodeSigningKeyID(cert.AuthorityKeyId)),
}
if err := m.cache.Prepopulate(cachetype.ConnectCALeafName, certRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()); err != nil {
return err
}
return nil
}
// Start spawns the go routine to monitor the certificate and ensure it is
// rotated/renewed as necessary. The chan will indicate once the started
// go routine has exited
func (m *CertMonitor) Start(ctx context.Context) (<-chan struct{}, error) {
m.l.Lock()
defer m.l.Unlock()
if m.running || m.cancel != nil {
return nil, fmt.Errorf("the CertMonitor is already running")
}
// create the top level context to control the go
// routine executing the `run` method
ctx, cancel := context.WithCancel(ctx)
// create the channel to get cache update events through
// really we should only ever get 10 updates
m.cacheUpdates = make(chan cache.UpdateEvent, 10)
// setup the cache watches
cancelWatches, err := m.setupCacheWatches(ctx)
if err != nil {
cancel()
return nil, fmt.Errorf("error setting up cache watches: %w", err)
}
// start the token update notifier
m.tokenUpdates = m.tokens.Notify(token.TokenKindAgent)
// store the cancel funcs
m.cancel = cancel
m.cancelWatches = cancelWatches
m.running = true
exit := make(chan struct{})
go m.run(ctx, exit)
m.logger.Info("certificate monitor started")
return exit, nil
}
// Stop manually stops the go routine spawned by Start and
// returns whether the go routine was still running before
// cancelling.
//
// Note that cancelling the context passed into Start will
// also cause the go routine to stop
func (m *CertMonitor) Stop() bool {
m.l.Lock()
defer m.l.Unlock()
if !m.running {
return false
}
if m.cancel != nil {
m.cancel()
}
return true
}
// IsRunning returns whether the go routine to perform certificate monitoring
// is already running.
func (m *CertMonitor) IsRunning() bool {
m.l.Lock()
defer m.l.Unlock()
return m.running
}
// setupCacheWatches will start both the roots and leaf cert watch with a new child
// context and an up to date ACL token. The watches are started with a new child context
// whose CancelFunc is also returned.
func (m *CertMonitor) setupCacheWatches(ctx context.Context) (context.CancelFunc, error) {
notificationCtx, cancel := context.WithCancel(ctx)
// copy the request
rootsReq := m.rootsReq
err := m.cache.Notify(notificationCtx, cachetype.ConnectCARootName, &rootsReq, rootsWatchID, m.cacheUpdates)
if err != nil {
cancel()
return nil, err
}
// copy the request
leafReq := m.leafReq
leafReq.Token = m.tokens.AgentToken()
err = m.cache.Notify(notificationCtx, cachetype.ConnectCALeafName, &leafReq, leafWatchID, m.cacheUpdates)
if err != nil {
cancel()
return nil, err
}
return cancel, nil
}
// handleCacheEvent is used to handle event notifications from the cache for the roots
// or leaf cert watches.
func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
switch u.CorrelationID {
case rootsWatchID:
m.logger.Debug("roots watch fired - updating CA certificates")
if u.Err != nil {
return fmt.Errorf("root watch returned an error: %w", u.Err)
}
roots, ok := u.Result.(*structs.IndexedCARoots)
if !ok {
return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
}
m.certs.ConnectCARoots = *roots
var pems []string
for _, root := range roots.Roots {
pems = append(pems, root.RootCert)
}
if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil {
return fmt.Errorf("failed to update Connect CA certificates: %w", err)
}
if m.persist != nil {
copy := m.certs
if err := m.persist(&copy); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
case leafWatchID:
m.logger.Debug("leaf certificate watch fired - updating TLS certificate")
if u.Err != nil {
return fmt.Errorf("leaf watch returned an error: %w", u.Err)
}
leaf, ok := u.Result.(*structs.IssuedCert)
if !ok {
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
}
m.certs.IssuedCert = *leaf
if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil {
return fmt.Errorf("failed to update the agent leaf cert: %w", err)
}
if m.persist != nil {
copy := m.certs
if err := m.persist(&copy); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
}
return nil
}
// handleTokenUpdate is used when a notification about the agent token being updated
// is received and various watches need cancelling/restarting to use the new token.
func (m *CertMonitor) handleTokenUpdate(ctx context.Context) error {
m.logger.Debug("Agent token updated - resetting watches")
// TODO (autoencrypt) Prepopulate the cache with the new token with
// the existing cache entry with the old token. The certificate doesn't
// need to change just because the token has. However there isn't a
// good way to make that happen and this behavior is benign enough
// that I am going to push off implementing it.
// the agent token has been updated so we must update our leaf cert watch.
// this cancels the current watches before setting up new ones
m.cancelWatches()
// recreate the chan for cache updates. This is a precautionary measure to ensure
// that we don't accidentally get notified for the new watches being setup before
// a blocking query in the cache returns and sends data to the old chan. In theory
// the code in agent/cache/watch.go should prevent this where we specifically check
// for context cancellation prior to sending the event. However we could cancel
// it after that check and finish setting up the new watches before getting the old
// events. Both the go routine scheduler and the OS thread scheduler would have to
// be acting up for this to happen. Regardless the way to ensure we don't get events
// for the old watches is to simply replace the chan we are expecting them from.
close(m.cacheUpdates)
m.cacheUpdates = make(chan cache.UpdateEvent, 10)
// restart watches - this will be done with the correct token
cancelWatches, err := m.setupCacheWatches(ctx)
if err != nil {
return fmt.Errorf("failed to restart watches after agent token update: %w", err)
}
m.cancelWatches = cancelWatches
return nil
}
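The channel-replacement precaution above reduces to a small generic pattern: once a fresh channel is swapped in, a late send aimed at the old watches can no longer be mistaken for an event from the new ones. A minimal sketch:

```go
package main

import "fmt"

type monitor struct {
	updates chan string
}

// restartWatches drops the old channel and listens on a fresh one, so any
// straggler send from a cancelled watch is isolated from the new watches.
func (m *monitor) restartWatches() {
	close(m.updates)
	m.updates = make(chan string, 10)
}

func main() {
	m := &monitor{updates: make(chan string, 10)}
	m.restartWatches()
	m.updates <- "event from new watch"
	fmt.Println(<-m.updates)
}
```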
// handleFallback is used when the current TLS certificate has expired and the normal
// updating mechanisms have failed to renew it quickly enough. This function will
// use the configured fallback mechanism to retrieve a new cert and start monitoring
// that one.
func (m *CertMonitor) handleFallback(ctx context.Context) error {
m.logger.Warn("agent's client certificate has expired")
// Background because the context is mainly useful when the agent is first starting up.
reply, err := m.fallback(ctx)
if err != nil {
return fmt.Errorf("error when getting new agent certificate: %w", err)
}
if m.persist != nil {
if err := m.persist(reply); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
return m.Update(reply)
}
// run is the private method to be spawned by the Start method for
// executing the main monitoring loop.
func (m *CertMonitor) run(ctx context.Context, exit chan struct{}) {
// The fallbackTimer is used to notify AFTER the agent's
// leaf certificate has expired, at which point we need
// to fall back to the less secure RPC endpoint just as
// if the agent were starting up fresh.
//
// Check 10sec (fallback leeway duration) after the cert
// expires. The agent cache should be handling the expiration
// and renewing it before then.
//
// If there is no cert, AutoEncryptCertNotAfter returns
// a value in the past which immediately triggers the
// renew, but this case shouldn't happen because at
// this point, auto_encrypt was just set up
// successfully.
calcFallbackInterval := func() time.Duration {
certExpiry := m.tlsConfigurator.AutoEncryptCertNotAfter()
return certExpiry.Add(m.fallbackLeeway).Sub(time.Now())
}
fallbackTimer := time.NewTimer(calcFallbackInterval())
// cleanup for once we are stopped
defer func() {
// cancel the go routines performing the cache watches
m.cancelWatches()
// ensure we don't leak the timers go routine
fallbackTimer.Stop()
// stop receiving notifications for token updates
m.tokens.StopNotify(m.tokenUpdates)
m.logger.Debug("certificate monitor has been stopped")
m.l.Lock()
m.cancel = nil
m.running = false
m.l.Unlock()
// this should be the final cleanup task as it's what notifies
// the rest of the world that this go routine has exited.
close(exit)
}()
for {
select {
case <-ctx.Done():
m.logger.Debug("stopping the certificate monitor")
return
case <-m.tokenUpdates.Ch:
m.logger.Debug("handling a token update event")
if err := m.handleTokenUpdate(ctx); err != nil {
m.logger.Error("error in handling token update event", "error", err)
}
case u := <-m.cacheUpdates:
m.logger.Debug("handling a cache update event", "correlation_id", u.CorrelationID)
if err := m.handleCacheEvent(u); err != nil {
m.logger.Error("error in handling cache update event", "error", err)
}
// reset the fallback timer as the certificate may have been updated
fallbackTimer.Stop()
fallbackTimer = time.NewTimer(calcFallbackInterval())
case <-fallbackTimer.C:
// This is a safety net in case the auto_encrypt cert doesn't get renewed
// in time. The agent would be stuck in that case because the watches
// never use the AutoEncrypt.Sign endpoint.
// check auto encrypt client cert expiration
if m.tlsConfigurator.AutoEncryptCertExpired() {
if err := m.handleFallback(ctx); err != nil {
m.logger.Error("error when handling a certificate expiry event", "error", err)
fallbackTimer = time.NewTimer(m.fallbackRetry)
} else {
fallbackTimer = time.NewTimer(calcFallbackInterval())
}
} else {
// this shouldn't be possible. We calculate the timer duration to be the certificate
// expiration time + some leeway (10s default). So whenever we get here the certificate
// should be expired. Regardless, it's probably worth resetting the timer.
fallbackTimer = time.NewTimer(calcFallbackInterval())
}
}
}
}


@ -1,731 +0,0 @@
package certmon
import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockFallback struct {
mock.Mock
}
func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, error) {
ret := m.Called()
resp, _ := ret.Get(0).(*structs.SignedResponse)
return resp, ret.Error(1)
}
type mockPersist struct {
mock.Mock
}
func (m *mockPersist) persist(resp *structs.SignedResponse) error {
return m.Called(resp).Error(0)
}
type mockWatcher struct {
ch chan<- cache.UpdateEvent
done <-chan struct{}
}
type mockCache struct {
mock.Mock
lock sync.Mutex
watchers map[string][]mockWatcher
}
func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error {
m.lock.Lock()
key := r.CacheInfo().Key
m.watchers[key] = append(m.watchers[key], mockWatcher{ch: ch, done: ctx.Done()})
m.lock.Unlock()
ret := m.Called(t, r, correlationID)
return ret.Error(0)
}
func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error {
ret := m.Called(t, result, dc, token, key)
return ret.Error(0)
}
func (m *mockCache) sendNotification(ctx context.Context, key string, u cache.UpdateEvent) bool {
m.lock.Lock()
defer m.lock.Unlock()
watchers, ok := m.watchers[key]
if !ok || len(watchers) < 1 {
return false
}
var newWatchers []mockWatcher
for _, watcher := range watchers {
select {
case watcher.ch <- u:
newWatchers = append(newWatchers, watcher)
case <-watcher.done:
// do nothing, this watcher will be removed from the list
case <-ctx.Done():
// return doesn't matter here really, the test is being cancelled
return true
}
}
// this removes any already cancelled watchers so that nothing further is sent to them
m.watchers[key] = newWatchers
return true
}
func newMockCache(t *testing.T) *mockCache {
mcache := mockCache{watchers: make(map[string][]mockWatcher)}
mcache.Test(t)
return &mcache
}
func waitForChan(timer *time.Timer, ch <-chan struct{}) bool {
select {
case <-timer.C:
return false
case <-ch:
return true
}
}
func waitForChans(timeout time.Duration, chans ...<-chan struct{}) bool {
timer := time.NewTimer(timeout)
defer timer.Stop()
for _, ch := range chans {
if !waitForChan(timer, ch) {
return false
}
}
return true
}
func testTLSConfigurator(t *testing.T) *tlsutil.Configurator {
t.Helper()
logger := testutil.Logger(t)
cfg, err := tlsutil.NewConfigurator(tlsutil.Config{AutoTLS: true}, logger)
require.NoError(t, err)
return cfg
}
func newLeaf(t *testing.T, ca *structs.CARoot, idx uint64, expiration time.Duration) *structs.IssuedCert {
t.Helper()
pub, priv, err := connect.TestAgentLeaf(t, "node", "foo", ca, expiration)
require.NoError(t, err)
cert, err := connect.ParseCert(pub)
require.NoError(t, err)
spiffeID, err := connect.ParseCertURI(cert.URIs[0])
require.NoError(t, err)
agentID, ok := spiffeID.(*connect.SpiffeIDAgent)
require.True(t, ok, "certificate doesn't have an agent leaf cert URI")
return &structs.IssuedCert{
SerialNumber: cert.SerialNumber.String(),
CertPEM: pub,
PrivateKeyPEM: priv,
ValidAfter: cert.NotBefore,
ValidBefore: cert.NotAfter,
Agent: agentID.Agent,
AgentURI: agentID.URI().String(),
EnterpriseMeta: *structs.DefaultEnterpriseMeta(),
RaftIndex: structs.RaftIndex{
CreateIndex: idx,
ModifyIndex: idx,
},
}
}
type testCertMonitor struct {
monitor *CertMonitor
mcache *mockCache
tls *tlsutil.Configurator
tokens *token.Store
fallback *mockFallback
persist *mockPersist
extraCACerts []string
initialCert *structs.IssuedCert
initialRoots *structs.IndexedCARoots
// these are some variables that the CertMonitor was created with
datacenter string
nodeName string
dns []string
ips []net.IP
verifyServerHostname bool
}
func newTestCertMonitor(t *testing.T) testCertMonitor {
t.Helper()
tlsConfigurator := testTLSConfigurator(t)
tokens := new(token.Store)
id, err := uuid.GenerateUUID()
require.NoError(t, err)
tokens.UpdateAgentToken(id, token.TokenSourceConfig)
ca := connect.TestCA(t, nil)
manualCA := connect.TestCA(t, nil)
// this cert is set up to not expire quickly. This will prevent
// the test from accidentally running the fallback routine
// before we want to force that to happen.
issued := newLeaf(t, ca, 1, 10*time.Minute)
indexedRoots := structs.IndexedCARoots{
ActiveRootID: ca.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
ca,
},
QueryMeta: structs.QueryMeta{
Index: 1,
},
}
initialCerts := &structs.SignedResponse{
ConnectCARoots: indexedRoots,
IssuedCert: *issued,
ManualCARoots: []string{manualCA.RootCert},
VerifyServerHostname: true,
}
dnsSANs := []string{"test.dev"}
ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)}
fallback := &mockFallback{}
fallback.Test(t)
persist := &mockPersist{}
persist.Test(t)
mcache := newMockCache(t)
rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1}
rootsReq := structs.DCSpecificRequest{Datacenter: "foo"}
mcache.On("Prepopulate", cachetype.ConnectCARootName, rootRes, "foo", "", rootsReq.CacheInfo().Key).Return(nil).Once()
leafReq := cachetype.ConnectCALeafRequest{
Token: tokens.AgentToken(),
Agent: "node",
Datacenter: "foo",
DNSSAN: dnsSANs,
IPSAN: ipSANs,
}
leafRes := cache.FetchResult{
Value: issued,
Index: 1,
State: cachetype.ConnectCALeafSuccess(ca.SigningKeyID),
}
mcache.On("Prepopulate", cachetype.ConnectCALeafName, leafRes, "foo", tokens.AgentToken(), leafReq.Key()).Return(nil).Once()
// we can assert more later but this should always be done.
defer mcache.AssertExpectations(t)
cfg := new(Config).
WithCache(mcache).
WithLogger(testutil.Logger(t)).
WithTLSConfigurator(tlsConfigurator).
WithTokens(tokens).
WithFallback(fallback.fallback).
WithDNSSANs(dnsSANs).
WithIPSANs(ipSANs).
WithDatacenter("foo").
WithNodeName("node").
WithFallbackLeeway(time.Nanosecond).
WithFallbackRetry(time.Millisecond).
WithPersistence(persist.persist)
monitor, err := New(cfg)
require.NoError(t, err)
require.NotNil(t, monitor)
require.NoError(t, monitor.Update(initialCerts))
return testCertMonitor{
monitor: monitor,
tls: tlsConfigurator,
tokens: tokens,
mcache: mcache,
persist: persist,
fallback: fallback,
extraCACerts: []string{manualCA.RootCert},
initialCert: issued,
initialRoots: &indexedRoots,
datacenter: "foo",
nodeName: "node",
dns: dnsSANs,
ips: ipSANs,
verifyServerHostname: true,
}
}
func tlsCertificateFromIssued(t *testing.T, issued *structs.IssuedCert) *tls.Certificate {
t.Helper()
cert, err := tls.X509KeyPair([]byte(issued.CertPEM), []byte(issued.PrivateKeyPEM))
require.NoError(t, err)
return &cert
}
// convenience method to get a TLS Certificate from the initial issued certificate and priv key
func (cm *testCertMonitor) initialTLSCertificate(t *testing.T) *tls.Certificate {
t.Helper()
return tlsCertificateFromIssued(t, cm.initialCert)
}
// just a convenience method to get a list of all the CA pems that we set up regardless
// of manual vs connect.
func (cm *testCertMonitor) initialCACerts() []string {
pems := cm.extraCACerts
for _, root := range cm.initialRoots.Roots {
pems = append(pems, root.RootCert)
}
return pems
}
func (cm *testCertMonitor) assertExpectations(t *testing.T) {
cm.mcache.AssertExpectations(t)
cm.fallback.AssertExpectations(t)
cm.persist.AssertExpectations(t)
}
func TestCertMonitor_InitialCerts(t *testing.T) {
// this also ensures that the cache was prepopulated properly
cm := newTestCertMonitor(t)
// verify that the certificate was injected into the TLS configurator correctly
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
// verify that the CA certs (both Connect and manual ones) were injected correctly
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
// verify that the auto-tls verify server hostname setting was injected correctly
require.Equal(t, cm.verifyServerHostname, cm.tls.VerifyServerHostname())
}
func TestCertMonitor_GoRoutineManagement(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cm := newTestCertMonitor(t)
// ensure that the monitor is not running
require.False(t, cm.monitor.IsRunning())
// ensure that nothing bad happens and that it reports as stopped
require.False(t, cm.monitor.Stop())
// we will never send notifications so these just ignore everything
cm.mcache.On("Notify", cachetype.ConnectCARootName, &structs.DCSpecificRequest{Datacenter: cm.datacenter}, rootsWatchID).Return(nil).Times(2)
cm.mcache.On("Notify", cachetype.ConnectCALeafName,
&cachetype.ConnectCALeafRequest{
Token: cm.tokens.AgentToken(),
Datacenter: cm.datacenter,
Agent: cm.nodeName,
DNSSAN: cm.dns,
IPSAN: cm.ips,
},
leafWatchID,
).Return(nil).Times(2)
done, err := cm.monitor.Start(ctx)
require.NoError(t, err)
require.True(t, cm.monitor.IsRunning())
_, err = cm.monitor.Start(ctx)
testutil.RequireErrorContains(t, err, "the CertMonitor is already running")
require.True(t, cm.monitor.Stop())
require.True(t, waitForChans(100*time.Millisecond, done), "monitor didn't shut down")
require.False(t, cm.monitor.IsRunning())
done, err = cm.monitor.Start(ctx)
require.NoError(t, err)
// ensure that context cancellation causes us to stop as well
cancel()
require.True(t, waitForChans(100*time.Millisecond, done))
cm.assertExpectations(t)
}
func startedCertMonitor(t *testing.T) (context.Context, testCertMonitor) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cm := newTestCertMonitor(t)
rootsCtx, rootsCancel := context.WithCancel(ctx)
defer rootsCancel()
leafCtx, leafCancel := context.WithCancel(ctx)
defer leafCancel()
// initial roots watch
cm.mcache.On("Notify", cachetype.ConnectCARootName,
&structs.DCSpecificRequest{
Datacenter: cm.datacenter,
},
rootsWatchID).
Return(nil).
Once().
Run(func(_ mock.Arguments) {
rootsCancel()
})
// the initial watch after starting the monitor
cm.mcache.On("Notify", cachetype.ConnectCALeafName,
&cachetype.ConnectCALeafRequest{
Token: cm.tokens.AgentToken(),
Datacenter: cm.datacenter,
Agent: cm.nodeName,
DNSSAN: cm.dns,
IPSAN: cm.ips,
},
leafWatchID).
Return(nil).
Once().
Run(func(_ mock.Arguments) {
leafCancel()
})
done, err := cm.monitor.Start(ctx)
require.NoError(t, err)
// this prevents logs after the test finishes
t.Cleanup(func() {
cm.monitor.Stop()
<-done
})
require.True(t,
waitForChans(100*time.Millisecond, rootsCtx.Done(), leafCtx.Done()),
"not all watches were started within the alotted time")
return ctx, cm
}
// This test ensures that the cache watches are restarted with the updated
// token after receiving a token update
func TestCertMonitor_TokenUpdate(t *testing.T) {
ctx, cm := startedCertMonitor(t)
rootsCtx, rootsCancel := context.WithCancel(ctx)
defer rootsCancel()
leafCtx, leafCancel := context.WithCancel(ctx)
defer leafCancel()
newToken := "8e4fe8db-162d-42d8-81ca-710fb2280ad0"
// we expect a new roots watch because when the leaf cert watch is restarted so is the root cert watch
cm.mcache.On("Notify", cachetype.ConnectCARootName,
&structs.DCSpecificRequest{
Datacenter: cm.datacenter,
},
rootsWatchID).
Return(nil).
Once().
Run(func(_ mock.Arguments) {
rootsCancel()
})
secondWatch := &cachetype.ConnectCALeafRequest{
Token: newToken,
Datacenter: cm.datacenter,
Agent: cm.nodeName,
DNSSAN: cm.dns,
IPSAN: cm.ips,
}
// the new watch after updating the token
cm.mcache.On("Notify", cachetype.ConnectCALeafName, secondWatch, leafWatchID).
Return(nil).
Once().
Run(func(args mock.Arguments) {
leafCancel()
})
cm.tokens.UpdateAgentToken(newToken, token.TokenSourceAPI)
require.True(t,
waitForChans(100*time.Millisecond, rootsCtx.Done(), leafCtx.Done()),
"not all watches were restarted within the alotted time")
cm.assertExpectations(t)
}
func TestCertMonitor_RootsUpdate(t *testing.T) {
ctx, cm := startedCertMonitor(t)
secondCA := connect.TestCA(t, cm.initialRoots.Roots[0])
secondRoots := structs.IndexedCARoots{
ActiveRootID: secondCA.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
secondCA,
cm.initialRoots.Roots[0],
},
QueryMeta: structs.QueryMeta{
Index: 99,
},
}
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *cm.initialCert,
ManualCARoots: cm.extraCACerts,
ConnectCARoots: secondRoots,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
// assert value of the CA certs prior to updating
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
req := structs.DCSpecificRequest{Datacenter: cm.datacenter}
require.True(t, cm.mcache.sendNotification(ctx, req.CacheInfo().Key, cache.UpdateEvent{
CorrelationID: rootsWatchID,
Result: &secondRoots,
Meta: cache.ResultMeta{
Index: secondRoots.Index,
},
}))
expectedCAs := append(cm.extraCACerts, secondCA.RootCert, cm.initialRoots.Roots[0].RootCert)
// this will wait up to 200ms (8 x 25 ms waits between the 9 requests)
retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) {
require.ElementsMatch(r, expectedCAs, cm.tls.CAPems())
})
cm.assertExpectations(t)
}
func TestCertMonitor_CertUpdate(t *testing.T) {
ctx, cm := startedCertMonitor(t)
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute)
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *secondCert,
ManualCARoots: cm.extraCACerts,
ConnectCARoots: *cm.initialRoots,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
// assert value of cert prior to updating the leaf
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
key := cm.monitor.leafReq.CacheInfo().Key
// send the new certificate - this notifies only the watchers utilizing
// the new ACL token
require.True(t, cm.mcache.sendNotification(ctx, key, cache.UpdateEvent{
CorrelationID: leafWatchID,
Result: secondCert,
Meta: cache.ResultMeta{
Index: secondCert.ModifyIndex,
},
}))
tlsCert := tlsCertificateFromIssued(t, secondCert)
// this will wait up to 200ms (8 x 25 ms waits between the 9 requests)
retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) {
require.Equal(r, tlsCert, cm.tls.Cert())
})
cm.assertExpectations(t)
}
func TestCertMonitor_Fallback(t *testing.T) {
ctx, cm := startedCertMonitor(t)
// at this point everything is operating normally and the monitor is just
// waiting for events. We are going to send a new cert that is basically
// already expired and then allow the fallback routine to kick in.
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, time.Nanosecond)
secondCA := connect.TestCA(t, cm.initialRoots.Roots[0])
secondRoots := structs.IndexedCARoots{
ActiveRootID: secondCA.ID,
TrustDomain: connect.TestClusterID,
Roots: []*structs.CARoot{
secondCA,
cm.initialRoots.Roots[0],
},
QueryMeta: structs.QueryMeta{
Index: 101,
},
}
thirdCert := newLeaf(t, secondCA, 102, 10*time.Minute)
// inject a fallback routine error to check that we rerun it quickly
cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once()
fallbackResp := &structs.SignedResponse{
ConnectCARoots: secondRoots,
IssuedCert: *thirdCert,
ManualCARoots: cm.extraCACerts,
VerifyServerHostname: true,
}
// expect the fallback routine to be executed and setup the return
cm.fallback.On("fallback").Return(fallbackResp, nil).Once()
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *secondCert,
ConnectCARoots: *cm.initialRoots,
ManualCARoots: cm.extraCACerts,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
cm.persist.On("persist", fallbackResp).Return(nil).Once()
// Add another roots cache prepopulation expectation which should happen
// in response to executing the fallback mechanism
rootRes := cache.FetchResult{Value: &secondRoots, Index: 101}
rootsReq := structs.DCSpecificRequest{Datacenter: cm.datacenter}
cm.mcache.On("Prepopulate", cachetype.ConnectCARootName, rootRes, cm.datacenter, "", rootsReq.CacheInfo().Key).Return(nil).Once()
// add another leaf cert cache prepopulation expectation which should happen
// in response to executing the fallback mechanism
leafReq := cachetype.ConnectCALeafRequest{
Token: cm.tokens.AgentToken(),
Agent: cm.nodeName,
Datacenter: cm.datacenter,
DNSSAN: cm.dns,
IPSAN: cm.ips,
}
leafRes := cache.FetchResult{
Value: thirdCert,
Index: 101,
State: cachetype.ConnectCALeafSuccess(secondCA.SigningKeyID),
}
cm.mcache.On("Prepopulate", cachetype.ConnectCALeafName, leafRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()).Return(nil).Once()
// nothing in the monitor should be looking at this as it's only done
// in response to sending token updates, no need to synchronize
key := cm.monitor.leafReq.CacheInfo().Key
// send the new certificate - this notifies only the watchers utilizing
// the new ACL token
require.True(t, cm.mcache.sendNotification(ctx, key, cache.UpdateEvent{
CorrelationID: leafWatchID,
Result: secondCert,
Meta: cache.ResultMeta{
Index: secondCert.ModifyIndex,
},
}))
// if all went well we would have updated the first certificate which was pretty much expired
// causing the fallback handler to be invoked almost immediately. The fallback routine will
// return the response containing the third cert and second CA roots so now we should wait
// a little while and ensure they were applied to the TLS Configurator
tlsCert := tlsCertificateFromIssued(t, thirdCert)
expectedCAs := append(cm.extraCACerts, secondCA.RootCert, cm.initialRoots.Roots[0].RootCert)
// this will wait up to 200ms (8 x 25 ms waits between the 9 requests)
retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) {
require.Equal(r, tlsCert, cm.tls.Cert())
require.ElementsMatch(r, expectedCAs, cm.tls.CAPems())
})
cm.assertExpectations(t)
}
func TestCertMonitor_New_Errors(t *testing.T) {
type testCase struct {
cfg Config
err string
}
fallback := func(_ context.Context) (*structs.SignedResponse, error) {
return nil, fmt.Errorf("Unimplemented")
}
tokens := new(token.Store)
cases := map[string]testCase{
"no-cache": {
cfg: Config{
TLSConfigurator: testTLSConfigurator(t),
Fallback: fallback,
Tokens: tokens,
Datacenter: "foo",
NodeName: "bar",
},
err: "CertMonitor creation requires a Cache",
},
"no-tls-configurator": {
cfg: Config{
Cache: cache.New(cache.Options{}),
Fallback: fallback,
Tokens: tokens,
Datacenter: "foo",
NodeName: "bar",
},
err: "CertMonitor creation requires a TLS Configurator",
},
"no-fallback": {
cfg: Config{
Cache: cache.New(cache.Options{}),
TLSConfigurator: testTLSConfigurator(t),
Tokens: tokens,
Datacenter: "foo",
NodeName: "bar",
},
err: "CertMonitor creation requires specifying a FallbackFunc",
},
"no-tokens": {
cfg: Config{
Cache: cache.New(cache.Options{}),
TLSConfigurator: testTLSConfigurator(t),
Fallback: fallback,
Datacenter: "foo",
NodeName: "bar",
},
err: "CertMonitor creation requires specifying a token store",
},
"no-datacenter": {
cfg: Config{
Cache: cache.New(cache.Options{}),
TLSConfigurator: testTLSConfigurator(t),
Fallback: fallback,
Tokens: tokens,
NodeName: "bar",
},
err: "CertMonitor creation requires specifying the datacenter",
},
"no-node-name": {
cfg: Config{
Cache: cache.New(cache.Options{}),
TLSConfigurator: testTLSConfigurator(t),
Fallback: fallback,
Tokens: tokens,
Datacenter: "foo",
},
err: "CertMonitor creation requires specifying the agent's node name",
},
}
for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
monitor, err := New(&tcase.cfg)
testutil.RequireErrorContains(t, err, tcase.err)
require.Nil(t, monitor)
})
}
}


@ -1,150 +0,0 @@
package certmon
import (
"context"
"net"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
)
// FallbackFunc is used when the normal cache watch based Certificate
// updating fails to update the Certificate in time and a different
// method of updating the certificate is required.
type FallbackFunc func(context.Context) (*structs.SignedResponse, error)
// PersistFunc is used to persist the data from a signed response
type PersistFunc func(*structs.SignedResponse) error
type Config struct {
// Logger is the logger to be used while running. If not set
// then no logging will be performed.
Logger hclog.Logger
// TLSConfigurator is where the certificates and roots are set when
// they are updated. This field is required.
TLSConfigurator *tlsutil.Configurator
// Cache is an object implementing our Cache interface. The Cache
// used at runtime must be able to handle Roots and Leaf Cert watches
Cache Cache
// Tokens is the shared token store. It is used to retrieve the current
// agent token as well as getting notifications when that token is updated.
// This field is required.
Tokens *token.Store
// Persist is a function to run when there are new certs or keys
Persist PersistFunc
// Fallback is a function to run when the normal cache updating of the
// agent's certificates has failed to work for one reason or another.
// This field is required.
Fallback FallbackFunc
// FallbackLeeway is the amount of time after certificate expiration before
// invoking the fallback routine. If not set this will default to 10s.
FallbackLeeway time.Duration
// FallbackRetry is the duration between Fallback invocations when the configured
// fallback routine returns an error. If not set this will default to 1m.
FallbackRetry time.Duration
// DNSSANs is a list of DNS SANs that certificate requests should include. This
// field is optional and no extra DNS SANs will be requested if unset. 'localhost'
// is unconditionally requested by the cache implementation.
DNSSANs []string
// IPSANs is a list of IP SANs to include in the certificate signing request. This
// field is optional and no extra IP SANs will be requested if unset. Both '127.0.0.1'
// and '::1' IP SANs are unconditionally requested by the cache implementation.
IPSANs []net.IP
// Datacenter is the datacenter to request certificates within. This field is required.
Datacenter string
// NodeName is the agent's node name to use when requesting certificates. This field
// is required.
NodeName string
}
// WithCache will cause the created CertMonitor type to use the provided Cache
func (cfg *Config) WithCache(cache Cache) *Config {
cfg.Cache = cache
return cfg
}
// WithLogger will cause the created CertMonitor type to use the provided logger
func (cfg *Config) WithLogger(logger hclog.Logger) *Config {
cfg.Logger = logger
return cfg
}
// WithTLSConfigurator will cause the created CertMonitor type to use the provided configurator
func (cfg *Config) WithTLSConfigurator(tlsConfigurator *tlsutil.Configurator) *Config {
cfg.TLSConfigurator = tlsConfigurator
return cfg
}
// WithTokens will cause the created CertMonitor type to use the provided token store
func (cfg *Config) WithTokens(tokens *token.Store) *Config {
cfg.Tokens = tokens
return cfg
}
// WithFallback configures a fallback function to use if the normal update mechanisms
// fail to renew the certificate in time.
func (cfg *Config) WithFallback(fallback FallbackFunc) *Config {
cfg.Fallback = fallback
return cfg
}
// WithDNSSANs configures the CertMonitor to request these DNS SANs when requesting a new
// certificate
func (cfg *Config) WithDNSSANs(sans []string) *Config {
cfg.DNSSANs = sans
return cfg
}
// WithIPSANs configures the CertMonitor to request these IP SANs when requesting a new
// certificate
func (cfg *Config) WithIPSANs(sans []net.IP) *Config {
cfg.IPSANs = sans
return cfg
}
// WithDatacenter configures the CertMonitor to request Certificates in this DC
func (cfg *Config) WithDatacenter(dc string) *Config {
cfg.Datacenter = dc
return cfg
}
// WithNodeName configures the CertMonitor to request certificates with this agent's node name.
func (cfg *Config) WithNodeName(name string) *Config {
cfg.NodeName = name
return cfg
}
// WithFallbackLeeway configures how long to wait after a certificate expires before
// attempting to generate a new certificate using the fallback mechanism. The default is 10s.
func (cfg *Config) WithFallbackLeeway(leeway time.Duration) *Config {
cfg.FallbackLeeway = leeway
return cfg
}
// WithFallbackRetry controls how long to wait between subsequent invocations of
// the fallback func when it returns an error.
func (cfg *Config) WithFallbackRetry(after time.Duration) *Config {
cfg.FallbackRetry = after
return cfg
}
// WithPersistence will configure the CertMonitor to use this callback for persisting
// a new TLS configuration.
func (cfg *Config) WithPersistence(persist PersistFunc) *Config {
cfg.Persist = persist
return cfg
}
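// Example (an illustrative sketch, not part of this change): the With* builder
// methods above return the receiver, so a caller can chain them to assemble a
// Config before constructing the cert monitor. The variable names below are
// hypothetical.
//
//	cfg := new(Config).
//		WithLogger(logger).
//		WithTLSConfigurator(tlsConfigurator).
//		WithCache(cache).
//		WithTokens(tokens).
//		WithFallback(fallbackFn).
//		WithDNSSANs([]string{"agent.example.internal"}).
//		WithDatacenter("dc1").
//		WithNodeName("node-1")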

View File

@ -22,6 +22,7 @@ import (
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
"github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
libtempl "github.com/hashicorp/consul/lib/template"
@ -777,6 +778,9 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
if err != nil {
return RuntimeConfig{}, fmt.Errorf("config_entries.bootstrap[%d]: %s", i, err)
}
if err := entry.Normalize(); err != nil {
return RuntimeConfig{}, fmt.Errorf("config_entries.bootstrap[%d]: %s", i, err)
}
if err := entry.Validate(); err != nil {
return RuntimeConfig{}, fmt.Errorf("config_entries.bootstrap[%d]: %s", i, err)
}
@ -796,6 +800,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
// ----------------------------------------------------------------
// build runtime config
//
dataDir := b.stringVal(c.DataDir)
rt = RuntimeConfig{
// non-user configurable values
ACLDisabledTTL: b.durationVal("acl.disabled_ttl", c.ACL.DisabledTTL),
@ -834,21 +839,25 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
GossipWANRetransmitMult: b.intVal(c.GossipWAN.RetransmitMult),
// ACL
ACLsEnabled: aclsEnabled,
ACLAgentMasterToken: b.stringValWithDefault(c.ACL.Tokens.AgentMaster, b.stringVal(c.ACLAgentMasterToken)),
ACLAgentToken: b.stringValWithDefault(c.ACL.Tokens.Agent, b.stringVal(c.ACLAgentToken)),
ACLDatacenter: primaryDatacenter,
ACLDefaultPolicy: b.stringValWithDefault(c.ACL.DefaultPolicy, b.stringVal(c.ACLDefaultPolicy)),
ACLDownPolicy: b.stringValWithDefault(c.ACL.DownPolicy, b.stringVal(c.ACLDownPolicy)),
ACLEnableKeyListPolicy: b.boolValWithDefault(c.ACL.EnableKeyListPolicy, b.boolVal(c.ACLEnableKeyListPolicy)),
ACLMasterToken: b.stringValWithDefault(c.ACL.Tokens.Master, b.stringVal(c.ACLMasterToken)),
ACLReplicationToken: b.stringValWithDefault(c.ACL.Tokens.Replication, b.stringVal(c.ACLReplicationToken)),
ACLTokenTTL: b.durationValWithDefault("acl.token_ttl", c.ACL.TokenTTL, b.durationVal("acl_ttl", c.ACLTTL)),
ACLPolicyTTL: b.durationVal("acl.policy_ttl", c.ACL.PolicyTTL),
ACLRoleTTL: b.durationVal("acl.role_ttl", c.ACL.RoleTTL),
ACLToken: b.stringValWithDefault(c.ACL.Tokens.Default, b.stringVal(c.ACLToken)),
ACLTokenReplication: b.boolValWithDefault(c.ACL.TokenReplication, b.boolValWithDefault(c.EnableACLReplication, enableTokenReplication)),
ACLEnableTokenPersistence: b.boolValWithDefault(c.ACL.EnableTokenPersistence, false),
ACLsEnabled: aclsEnabled,
ACLDatacenter: primaryDatacenter,
ACLDefaultPolicy: b.stringValWithDefault(c.ACL.DefaultPolicy, b.stringVal(c.ACLDefaultPolicy)),
ACLDownPolicy: b.stringValWithDefault(c.ACL.DownPolicy, b.stringVal(c.ACLDownPolicy)),
ACLEnableKeyListPolicy: b.boolValWithDefault(c.ACL.EnableKeyListPolicy, b.boolVal(c.ACLEnableKeyListPolicy)),
ACLMasterToken: b.stringValWithDefault(c.ACL.Tokens.Master, b.stringVal(c.ACLMasterToken)),
ACLTokenTTL: b.durationValWithDefault("acl.token_ttl", c.ACL.TokenTTL, b.durationVal("acl_ttl", c.ACLTTL)),
ACLPolicyTTL: b.durationVal("acl.policy_ttl", c.ACL.PolicyTTL),
ACLRoleTTL: b.durationVal("acl.role_ttl", c.ACL.RoleTTL),
ACLTokenReplication: b.boolValWithDefault(c.ACL.TokenReplication, b.boolValWithDefault(c.EnableACLReplication, enableTokenReplication)),
ACLTokens: token.Config{
DataDir: dataDir,
EnablePersistence: b.boolValWithDefault(c.ACL.EnableTokenPersistence, false),
ACLDefaultToken: b.stringValWithDefault(c.ACL.Tokens.Default, b.stringVal(c.ACLToken)),
ACLAgentToken: b.stringValWithDefault(c.ACL.Tokens.Agent, b.stringVal(c.ACLAgentToken)),
ACLAgentMasterToken: b.stringValWithDefault(c.ACL.Tokens.AgentMaster, b.stringVal(c.ACLAgentMasterToken)),
ACLReplicationToken: b.stringValWithDefault(c.ACL.Tokens.Replication, b.stringVal(c.ACLReplicationToken)),
},
// Autopilot
AutopilotCleanupDeadServers: b.boolVal(c.Autopilot.CleanupDeadServers),
@ -954,7 +963,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
ConnectTestCALeafRootChangeSpread: b.durationVal("connect.test_ca_leaf_root_change_spread", c.Connect.TestCALeafRootChangeSpread),
ExposeMinPort: exposeMinPort,
ExposeMaxPort: exposeMaxPort,
DataDir: b.stringVal(c.DataDir),
DataDir: dataDir,
Datacenter: datacenter,
DefaultQueryTime: b.durationVal("default_query_time", c.DefaultQueryTime),
DevMode: b.boolVal(b.devMode),
@ -1069,10 +1078,8 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
return RuntimeConfig{}, fmt.Errorf("cache.entry_fetch_rate must be strictly positive, was: %v", rt.Cache.EntryFetchRate)
}
if entCfg, err := b.BuildEnterpriseRuntimeConfig(&c); err != nil {
return RuntimeConfig{}, err
} else {
rt.EnterpriseRuntimeConfig = entCfg
if err := b.BuildEnterpriseRuntimeConfig(&rt, &c); err != nil {
return rt, err
}
if rt.BootstrapExpect == 1 {
@ -1360,7 +1367,8 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
b.warn(err.Error())
}
return nil
err := b.validateEnterpriseConfig(rt)
return err
}
// addrUnique checks if the given address is already in use for another

View File

@ -51,8 +51,12 @@ func (e enterpriseConfigKeyError) Error() string {
return fmt.Sprintf("%q is a Consul Enterprise configuration and will have no effect", e.key)
}
func (_ *Builder) BuildEnterpriseRuntimeConfig(_ *Config) (EnterpriseRuntimeConfig, error) {
return EnterpriseRuntimeConfig{}, nil
func (*Builder) BuildEnterpriseRuntimeConfig(_ *RuntimeConfig, _ *Config) error {
return nil
}
func (*Builder) validateEnterpriseConfig(_ RuntimeConfig) error {
return nil
}
// validateEnterpriseConfig is a function to validate the enterprise specific

View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
@ -63,19 +64,7 @@ type RuntimeConfig struct {
// hcl: acl.enabled = boolean
ACLsEnabled bool
// ACLAgentMasterToken is a special token that has full read and write
// privileges for this agent, and can be used to call agent endpoints
// when no servers are available.
//
// hcl: acl.tokens.agent_master = string
ACLAgentMasterToken string
// ACLAgentToken is the default token used to make requests for the agent
// itself, such as for registering itself with the catalog. If not
// configured, the 'acl_token' will be used.
//
// hcl: acl.tokens.agent = string
ACLAgentToken string
ACLTokens token.Config
// ACLDatacenter is the central datacenter that holds authoritative
// ACL records. This must be the same for the entire cluster.
@ -123,16 +112,6 @@ type RuntimeConfig struct {
// hcl: acl.tokens.master = string
ACLMasterToken string
// ACLReplicationToken is used to replicate data locally from the
// PrimaryDatacenter. Replication is only available on servers in
// datacenters other than the PrimaryDatacenter
//
// DEPRECATED (ACL-Legacy-Compat): Setting this to a non-empty value
// also enables legacy ACL replication if ACLs are enabled and in legacy mode.
//
// hcl: acl.tokens.replication = string
ACLReplicationToken string
// ACLtokenReplication is used to indicate that both tokens and policies
// should be replicated instead of just policies
//
@ -157,16 +136,6 @@ type RuntimeConfig struct {
// hcl: acl.role_ttl = "duration"
ACLRoleTTL time.Duration
// ACLToken is the default token used to make requests if a per-request
// token is not provided. If not configured the 'anonymous' token is used.
//
// hcl: acl.tokens.default = string
ACLToken string
// ACLEnableTokenPersistence determines whether or not tokens set via the agent HTTP API
// should be persisted to disk and reloaded when an agent restarts.
ACLEnableTokenPersistence bool
// AutopilotCleanupDeadServers enables the automatic cleanup of dead servers when new ones
// are added to the peer list. Defaults to true.
//

View File

@ -6,11 +6,9 @@ var entMetaJSON = `{}`
var entRuntimeConfigSanitize = `{}`
var entFullDNSJSONConfig = ``
var entTokenConfigSanitize = `"EnterpriseConfig": {},`
var entFullDNSHCLConfig = ``
var entFullRuntimeConfig = EnterpriseRuntimeConfig{}
func entFullRuntimeConfig(rt *RuntimeConfig) {}
var enterpriseNonVotingServerWarnings []string = []string{enterpriseConfigKeyError{key: "non_voting_server"}.Error()}

View File

@ -21,6 +21,7 @@ import (
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/sdk/testutil"
@ -52,6 +53,8 @@ type configTest struct {
func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
dataDir := testutil.TempDir(t, "consul")
defaultEntMeta := structs.DefaultEnterpriseMeta()
tests := []configTest{
// ------------------------------------------------------------
// cmd line flags
@ -1611,7 +1614,7 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
json: []string{`{ "acl_replication_token": "a" }`},
hcl: []string{`acl_replication_token = "a"`},
patch: func(rt *RuntimeConfig) {
rt.ACLReplicationToken = "a"
rt.ACLTokens.ACLReplicationToken = "a"
rt.ACLTokenReplication = true
rt.DataDir = dataDir
},
@ -3286,17 +3289,15 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
err: "config_entries.bootstrap[0]: invalid config entry kind: foo",
},
{
desc: "ConfigEntry bootstrap invalid",
desc: "ConfigEntry bootstrap invalid service-defaults",
args: []string{`-data-dir=` + dataDir},
json: []string{`{
"config_entries": {
"bootstrap": [
{
"kind": "proxy-defaults",
"name": "invalid-name",
"config": {
"foo": "bar"
}
"kind": "service-defaults",
"name": "web",
"made_up_key": "blah"
}
]
}
@ -3304,14 +3305,12 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
hcl: []string{`
config_entries {
bootstrap {
kind = "proxy-defaults"
name = "invalid-name"
config {
foo = "bar"
}
kind = "service-defaults"
name = "web"
made_up_key = "blah"
}
}`},
err: "config_entries.bootstrap[0]: invalid name (\"invalid-name\"), only \"global\" is supported",
err: "config_entries.bootstrap[0]: 1 error occurred:\n\t* invalid config key \"made_up_key\"\n\n",
},
{
desc: "ConfigEntry bootstrap proxy-defaults (snake-case)",
@ -3355,8 +3354,9 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
rt.ConfigEntryBootstrap = []structs.ConfigEntry{
&structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
EnterpriseMeta: *defaultEntMeta,
Config: map[string]interface{}{
"bar": "abc",
"moreconfig": map[string]interface{}{
@ -3412,8 +3412,9 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
rt.ConfigEntryBootstrap = []structs.ConfigEntry{
&structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
EnterpriseMeta: *defaultEntMeta,
Config: map[string]interface{}{
"bar": "abc",
"moreconfig": map[string]interface{}{
@ -3436,6 +3437,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
{
"kind": "service-defaults",
"name": "web",
"meta" : {
"foo": "bar",
"gir": "zim"
},
"protocol": "http",
"external_sni": "abc-123",
"mesh_gateway": {
@ -3450,6 +3455,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
bootstrap {
kind = "service-defaults"
name = "web"
meta {
"foo" = "bar"
"gir" = "zim"
}
protocol = "http"
external_sni = "abc-123"
mesh_gateway {
@ -3461,10 +3470,15 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
rt.ConfigEntryBootstrap = []structs.ConfigEntry{
&structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "web",
Protocol: "http",
ExternalSNI: "abc-123",
Kind: structs.ServiceDefaults,
Name: "web",
Meta: map[string]string{
"foo": "bar",
"gir": "zim",
},
EnterpriseMeta: *defaultEntMeta,
Protocol: "http",
ExternalSNI: "abc-123",
MeshGateway: structs.MeshGatewayConfig{
Mode: structs.MeshGatewayModeRemote,
},
@ -3481,6 +3495,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
{
"Kind": "service-defaults",
"Name": "web",
"Meta" : {
"foo": "bar",
"gir": "zim"
},
"Protocol": "http",
"ExternalSNI": "abc-123",
"MeshGateway": {
@ -3495,6 +3513,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
bootstrap {
Kind = "service-defaults"
Name = "web"
Meta {
"foo" = "bar"
"gir" = "zim"
}
Protocol = "http"
ExternalSNI = "abc-123"
MeshGateway {
@ -3506,10 +3528,15 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
rt.ConfigEntryBootstrap = []structs.ConfigEntry{
&structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "web",
Protocol: "http",
ExternalSNI: "abc-123",
Kind: structs.ServiceDefaults,
Name: "web",
Meta: map[string]string{
"foo": "bar",
"gir": "zim",
},
EnterpriseMeta: *defaultEntMeta,
Protocol: "http",
ExternalSNI: "abc-123",
MeshGateway: structs.MeshGatewayConfig{
Mode: structs.MeshGatewayModeRemote,
},
@ -3526,6 +3553,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
{
"kind": "service-router",
"name": "main",
"meta" : {
"foo": "bar",
"gir": "zim"
},
"routes": [
{
"match": {
@ -3610,6 +3641,10 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
bootstrap {
kind = "service-router"
name = "main"
meta {
"foo" = "bar"
"gir" = "zim"
}
routes = [
{
match {
@ -3693,6 +3728,11 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
&structs.ServiceRouterConfigEntry{
Kind: structs.ServiceRouter,
Name: "main",
Meta: map[string]string{
"foo": "bar",
"gir": "zim",
},
EnterpriseMeta: *defaultEntMeta,
Routes: []structs.ServiceRoute{
{
Match: &structs.ServiceRouteMatch{
@ -3772,6 +3812,8 @@ func TestBuilder_BuildAndValide_ConfigFlagsAndEdgecases(t *testing.T) {
}
},
},
// TODO(rb): add in missing tests for ingress-gateway (snake + camel)
// TODO(rb): add in missing tests for terminating-gateway (snake + camel)
///////////////////////////////////
// Defaults sanity checks
@ -4345,6 +4387,13 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
if tt.patch != nil {
tt.patch(&expected)
}
// both DataDir fields should always be the same, so test for the
// invariant, and then update the expected config so that every test
// case does not need to set this field.
require.Equal(t, actual.DataDir, actual.ACLTokens.DataDir)
expected.ACLTokens.DataDir = actual.ACLTokens.DataDir
require.Equal(t, expected, actual)
})
}
@ -4380,6 +4429,8 @@ func TestFullConfig(t *testing.T) {
return n
}
defaultEntMeta := structs.DefaultEnterpriseMeta()
flagSrc := []string{`-dev`}
src := map[string]string{
"json": `{
@ -5836,20 +5887,24 @@ func TestFullConfig(t *testing.T) {
// user configurable values
ACLAgentMasterToken: "64fd0e08",
ACLAgentToken: "bed2377c",
ACLTokens: token.Config{
EnablePersistence: true,
DataDir: dataDir,
ACLDefaultToken: "418fdff1",
ACLAgentToken: "bed2377c",
ACLAgentMasterToken: "64fd0e08",
ACLReplicationToken: "5795983a",
},
ACLsEnabled: true,
ACLDatacenter: "ejtmd43d",
ACLDefaultPolicy: "72c2e7a0",
ACLDownPolicy: "03eb2aee",
ACLEnableKeyListPolicy: true,
ACLEnableTokenPersistence: true,
ACLMasterToken: "8a19ac27",
ACLReplicationToken: "5795983a",
ACLTokenTTL: 3321 * time.Second,
ACLPolicyTTL: 1123 * time.Second,
ACLRoleTTL: 9876 * time.Second,
ACLToken: "418fdff1",
ACLTokenReplication: true,
AdvertiseAddrLAN: ipAddr("17.99.29.16"),
AdvertiseAddrWAN: ipAddr("78.63.37.19"),
@ -5953,8 +6008,9 @@ func TestFullConfig(t *testing.T) {
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
ConfigEntryBootstrap: []structs.ConfigEntry{
&structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
EnterpriseMeta: *defaultEntMeta,
Config: map[string]interface{}{
"foo": "bar",
// has to be a float due to being a map[string]interface
@ -6477,9 +6533,10 @@ func TestFullConfig(t *testing.T) {
"args": []interface{}{"dltjDJ2a", "flEa7C2d"},
},
},
EnterpriseRuntimeConfig: entFullRuntimeConfig,
}
entFullRuntimeConfig(&want)
warns := []string{
`The 'acl_datacenter' field is deprecated. Use the 'primary_datacenter' field instead.`,
`bootstrap_expect > 0: expecting 53 servers`,
@ -6796,21 +6853,25 @@ func TestSanitize(t *testing.T) {
}
rtJSON := `{
"ACLAgentMasterToken": "hidden",
"ACLAgentToken": "hidden",
"ACLTokens": {
` + entTokenConfigSanitize + `
"ACLAgentMasterToken": "hidden",
"ACLAgentToken": "hidden",
"ACLDefaultToken": "hidden",
"ACLReplicationToken": "hidden",
"DataDir": "",
"EnablePersistence": false
},
"ACLDatacenter": "",
"ACLDefaultPolicy": "",
"ACLDisabledTTL": "0s",
"ACLDownPolicy": "",
"ACLEnableKeyListPolicy": false,
"ACLEnableTokenPersistence": false,
"ACLMasterToken": "hidden",
"ACLPolicyTTL": "0s",
"ACLReplicationToken": "hidden",
"ACLRoleTTL": "0s",
"ACLTokenReplication": false,
"ACLTokenTTL": "0s",
"ACLToken": "hidden",
"ACLsEnabled": false,
"AEInterval": "0s",
"AdvertiseAddrLAN": "",

View File

@ -44,13 +44,13 @@ func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderC
func defaultConsulCAProviderConfig() structs.ConsulCAProviderConfig {
return structs.ConsulCAProviderConfig{
CommonCAProviderConfig: defaultCommonConfig(),
IntermediateCertTTL: 24 * 365 * time.Hour,
}
}
func defaultCommonConfig() structs.CommonCAProviderConfig {
return structs.CommonCAProviderConfig{
LeafCertTTL: 3 * 24 * time.Hour,
PrivateKeyType: connect.DefaultPrivateKeyType,
PrivateKeyBits: connect.DefaultPrivateKeyBits,
LeafCertTTL: 3 * 24 * time.Hour,
IntermediateCertTTL: 24 * 365 * time.Hour,
PrivateKeyType: connect.DefaultPrivateKeyType,
PrivateKeyBits: connect.DefaultPrivateKeyBits,
}
}

View File

@ -26,12 +26,13 @@ func TestStructs_CAConfiguration_MsgpackEncodeDecode(t *testing.T) {
"PrivateKeyBits": int64(4096),
}
expectCommonBase := &structs.CommonCAProviderConfig{
LeafCertTTL: 30 * time.Hour,
SkipValidate: true,
CSRMaxPerSecond: 5.25,
CSRMaxConcurrent: 55,
PrivateKeyType: "rsa",
PrivateKeyBits: 4096,
LeafCertTTL: 30 * time.Hour,
IntermediateCertTTL: 90 * time.Hour,
SkipValidate: true,
CSRMaxPerSecond: 5.25,
CSRMaxConcurrent: 55,
PrivateKeyType: "rsa",
PrivateKeyBits: 4096,
}
cases := map[string]testcase{
@ -60,7 +61,6 @@ func TestStructs_CAConfiguration_MsgpackEncodeDecode(t *testing.T) {
PrivateKey: "key",
RootCert: "cert",
RotationPeriod: 5 * time.Minute,
IntermediateCertTTL: 90 * time.Hour,
DisableCrossSigning: true,
},
parseFunc: func(t *testing.T, raw map[string]interface{}) interface{} {
@ -86,6 +86,7 @@ func TestStructs_CAConfiguration_MsgpackEncodeDecode(t *testing.T) {
"Token": "token",
"RootPKIPath": "root-pki/",
"IntermediatePKIPath": "im-pki/",
"IntermediateCertTTL": "90h",
"CAFile": "ca-file",
"CAPath": "ca-path",
"CertFile": "cert-file",
@ -126,8 +127,9 @@ func TestStructs_CAConfiguration_MsgpackEncodeDecode(t *testing.T) {
ModifyIndex: 99,
},
Config: map[string]interface{}{
"ExistingARN": "arn://foo",
"DeleteOnExit": true,
"ExistingARN": "arn://foo",
"DeleteOnExit": true,
"IntermediateCertTTL": "90h",
},
},
expectConfig: &structs.AWSCAProviderConfig{

View File

@ -231,7 +231,7 @@ func (v *VaultProvider) setupIntermediatePKIPath() error {
Type: "pki",
Description: "intermediate CA backend for Consul Connect",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "2160h",
MaxLeaseTTL: v.config.IntermediateCertTTL.String(),
},
})

View File

@ -38,9 +38,10 @@ var badParams = []KeyConfig{
func makeConfig(kc KeyConfig) structs.CommonCAProviderConfig {
return structs.CommonCAProviderConfig{
LeafCertTTL: 3 * 24 * time.Hour,
PrivateKeyType: kc.keyType,
PrivateKeyBits: kc.keyBits,
LeafCertTTL: 3 * 24 * time.Hour,
IntermediateCertTTL: 365 * 24 * time.Hour,
PrivateKeyType: kc.keyType,
PrivateKeyBits: kc.keyBits,
}
}

View File

@ -1639,8 +1639,8 @@ func TestACLResolver_Client(t *testing.T) {
// effectively disable caching - so the only way we end up with 1 token read is if they were
// being resolved concurrently
config.Config.ACLTokenTTL = 0 * time.Second
config.Config.ACLPolicyTTL = 30 * time.Millisecond
config.Config.ACLRoleTTL = 30 * time.Millisecond
config.Config.ACLPolicyTTL = 30 * time.Second
config.Config.ACLRoleTTL = 30 * time.Second
config.Config.ACLDownPolicy = "extend-cache"
})

View File

@ -1,239 +0,0 @@
package consul
import (
"context"
"fmt"
"net"
"strings"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
)
const (
dummyTrustDomain = "dummy.trustdomain"
retryJitterWindow = 30 * time.Second
)
func (c *Client) autoEncryptCSR(extraDNSSANs []string, extraIPSANs []net.IP) (string, string, error) {
// We don't provide the correct host here, because we don't know any
// better at this point. Apart from the domain, we would need the
// ClusterID, which we don't have. This is why we go with
// dummyTrustDomain the first time. Subsequent CSRs will have the
// correct TrustDomain.
id := &connect.SpiffeIDAgent{
Host: dummyTrustDomain,
Datacenter: c.config.Datacenter,
Agent: c.config.NodeName,
}
conf, err := c.config.CAConfig.GetCommonConfig()
if err != nil {
return "", "", err
}
if conf.PrivateKeyType == "" {
conf.PrivateKeyType = connect.DefaultPrivateKeyType
}
if conf.PrivateKeyBits == 0 {
conf.PrivateKeyBits = connect.DefaultPrivateKeyBits
}
// Create a new private key
pk, pkPEM, err := connect.GeneratePrivateKeyWithConfig(conf.PrivateKeyType, conf.PrivateKeyBits)
if err != nil {
return "", "", err
}
dnsNames := append([]string{"localhost"}, extraDNSSANs...)
ipAddresses := append([]net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, extraIPSANs...)
// Create a CSR.
//
// The Common Name includes the dummy trust domain for now but Server will
// override this when it is signed anyway so it's OK.
cn := connect.AgentCN(c.config.NodeName, dummyTrustDomain)
csr, err := connect.CreateCSR(id, cn, pk, dnsNames, ipAddresses)
if err != nil {
return "", "", err
}
return pkPEM, csr, nil
}
func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string, extraDNSSANs []string, extraIPSANs []net.IP) (*structs.SignedResponse, error) {
errFn := func(err error) (*structs.SignedResponse, error) {
return nil, err
}
// Check if we know about a server already through gossip. Depending on
// how the agent joined, there might already be one. Also in case this
// gets called because the cert expired.
server := c.routers.FindServer()
if server != nil {
servers = []string{server.Addr.String()}
}
if len(servers) == 0 {
return errFn(fmt.Errorf("No servers to request AutoEncrypt.Sign"))
}
pkPEM, csr, err := c.autoEncryptCSR(extraDNSSANs, extraIPSANs)
if err != nil {
return errFn(err)
}
// Prepare request and response so that it can be passed to
// RPCInsecure.
args := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: token},
Datacenter: c.config.Datacenter,
CSR: csr,
}
var reply structs.SignedResponse
// Retry implementation modeled after https://github.com/hashicorp/consul/pull/5228.
// TLDR; there is a 30s window from which a random time is picked.
// Repeat until the call is successful.
attempts := 0
for {
select {
case <-ctx.Done():
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
default:
}
// Translate host to net.TCPAddr to make life easier for
// RPCInsecure.
for _, s := range servers {
ips, err := resolveAddr(s, c.logger)
if err != nil {
c.logger.Warn("AutoEncrypt resolveAddr failed", "error", err)
continue
}
for _, ip := range ips {
addr := net.TCPAddr{IP: ip, Port: port}
if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, "AutoEncrypt.Sign", &args, &reply); err == nil {
reply.IssuedCert.PrivateKeyPEM = pkPEM
return &reply, nil
} else {
c.logger.Warn("AutoEncrypt failed", "error", err)
}
}
}
attempts++
delay := lib.RandomStagger(retryJitterWindow)
interval := (time.Duration(attempts) * delay) + delay
c.logger.Warn("retrying AutoEncrypt", "retry_interval", interval)
select {
case <-time.After(interval):
continue
case <-ctx.Done():
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
case <-c.shutdownCh:
return errFn(fmt.Errorf("aborting AutoEncrypt because shutting down"))
}
}
}
func missingPortError(host string, err error) bool {
return err != nil && err.Error() == fmt.Sprintf("address %s: missing port in address", host)
}
// resolveAddr is used to resolve the host into IPs and error.
func resolveAddr(rawHost string, logger hclog.Logger) ([]net.IP, error) {
host, _, err := net.SplitHostPort(rawHost)
if err != nil {
// In case we encounter this error, we proceed with the
// rawHost. This is fine since -start-join and -retry-join
// take only hosts anyway, and this is an expected case.
if missingPortError(rawHost, err) {
host = rawHost
} else {
return nil, err
}
}
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
}
// First try TCP so we have the best chance for the largest list of
// hosts to join. If this fails it's not fatal since this isn't a standard
// way to query DNS, and we have a fallback below.
if ips, err := tcpLookupIP(host, logger); err != nil {
logger.Debug("TCP-first lookup failed for host, falling back to UDP", "host", host, "error", err)
} else if len(ips) > 0 {
return ips, nil
}
// If TCP didn't yield anything then use the normal Go resolver which
// will try UDP, then might possibly try TCP again if the UDP response
// indicates it was truncated.
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
return ips, nil
}
// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host.
// The built-in Go resolver will do a UDP lookup first, and will only use TCP if
// the response has the truncate bit set, which isn't common on DNS servers like
// Consul's. By doing the TCP lookup directly, we get the best chance for the
// largest list of hosts to join. Since joins are relatively rare events, it's ok
// to do this rather expensive operation.
func tcpLookupIP(host string, logger hclog.Logger) ([]net.IP, error) {
// Don't attempt any TCP lookups against non-fully qualified domain
// names, since those will likely come from the resolv.conf file.
if !strings.Contains(host, ".") {
return nil, nil
}
// Make sure the domain name is terminated with a dot (we know there's
// at least one character at this point).
dn := host
if dn[len(dn)-1] != '.' {
dn = dn + "."
}
// See if we can find a server to try.
cc, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
if len(cc.Servers) > 0 {
// Do the lookup.
c := new(dns.Client)
c.Net = "tcp"
msg := new(dns.Msg)
msg.SetQuestion(dn, dns.TypeANY)
in, _, err := c.Exchange(msg, cc.Servers[0])
if err != nil {
return nil, err
}
// Handle any IPs we get back that we can attempt to join.
var ips []net.IP
for _, r := range in.Answer {
switch rr := r.(type) {
case (*dns.A):
ips = append(ips, rr.A)
case (*dns.AAAA):
ips = append(ips, rr.AAAA)
case (*dns.CNAME):
logger.Debug("Ignoring CNAME RR in TCP-first answer for host", "host", host)
}
}
return ips, nil
}
return nil, nil
}

View File

@ -1,205 +0,0 @@
package consul
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"net"
"net/url"
"os"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
func TestAutoEncrypt_resolveAddr(t *testing.T) {
type args struct {
rawHost string
logger hclog.Logger
}
logger := testutil.Logger(t)
tests := []struct {
name string
args args
ips []net.IP
wantErr bool
}{
{
name: "host without port",
args: args{
"127.0.0.1",
logger,
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
wantErr: false,
},
{
name: "host with port",
args: args{
"127.0.0.1:1234",
logger,
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
wantErr: false,
},
{
name: "host with broken port",
args: args{
"127.0.0.1:xyz",
logger,
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
wantErr: false,
},
{
name: "not an address",
args: args{
"abc",
logger,
},
ips: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, err := resolveAddr(tt.args.rawHost, tt.args.logger)
if (err != nil) != tt.wantErr {
t.Errorf("resolveAddr error: %v, wantErr: %v", err, tt.wantErr)
return
}
require.Equal(t, tt.ips, ips)
})
}
}
func TestAutoEncrypt_missingPortError(t *testing.T) {
host := "127.0.0.1"
_, _, err := net.SplitHostPort(host)
require.True(t, missingPortError(host, err))
host = "127.0.0.1:1234"
_, _, err = net.SplitHostPort(host)
require.False(t, missingPortError(host, err))
}
func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
dir1, c1 := testClient(t)
defer os.RemoveAll(dir1)
defer c1.Shutdown()
servers := []string{"localhost"}
port := 8301
token := ""
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(75*time.Millisecond))
defer cancel()
doneCh := make(chan struct{})
var err error
go func() {
_, err = c1.RequestAutoEncryptCerts(ctx, servers, port, token, nil, nil)
close(doneCh)
}()
select {
case <-doneCh:
// since there are no servers at this port, we shouldn't be
// done and this should be an error of some sort that happened
// in the setup phase before entering the for loop in
// RequestAutoEncryptCerts.
require.NoError(t, err)
case <-ctx.Done():
// this is the happy case since auto encrypt is in its loop to
// try to request certs.
}
}
func TestAutoEncrypt_autoEncryptCSR(t *testing.T) {
type testCase struct {
conf *Config
extraDNSSANs []string
extraIPSANs []net.IP
err string
// to validate the csr
expectedSubject pkix.Name
expectedSigAlg x509.SignatureAlgorithm
expectedPubAlg x509.PublicKeyAlgorithm
expectedDNSNames []string
expectedIPs []net.IP
expectedURIs []*url.URL
}
cases := map[string]testCase{
"sans": {
conf: &Config{
Datacenter: "dc1",
NodeName: "test-node",
CAConfig: &structs.CAConfiguration{},
},
extraDNSSANs: []string{"foo.local", "bar.local"},
extraIPSANs: []net.IP{net.IPv4(198, 18, 0, 1), net.IPv4(198, 18, 0, 2)},
expectedSubject: pkix.Name{
CommonName: connect.AgentCN("test-node", dummyTrustDomain),
Names: []pkix.AttributeTypeAndValue{
{
// 2,5,4,3 is the CommonName type ASN1 identifier
Type: asn1.ObjectIdentifier{2, 5, 4, 3},
Value: "testnode.agnt.dummy.tr.consul",
},
},
},
expectedSigAlg: x509.ECDSAWithSHA256,
expectedPubAlg: x509.ECDSA,
expectedDNSNames: []string{
"localhost",
"foo.local",
"bar.local",
},
expectedIPs: []net.IP{
{127, 0, 0, 1},
net.ParseIP("::1"),
{198, 18, 0, 1},
{198, 18, 0, 2},
},
expectedURIs: []*url.URL{
{
Scheme: "spiffe",
Host: dummyTrustDomain,
Path: "/agent/client/dc/dc1/id/test-node",
},
},
},
}
for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
client := Client{config: tcase.conf}
_, csr, err := client.autoEncryptCSR(tcase.extraDNSSANs, tcase.extraIPSANs)
if tcase.err == "" {
require.NoError(t, err)
request, err := connect.ParseCSR(csr)
require.NoError(t, err)
require.NotNil(t, request)
require.Equal(t, tcase.expectedSubject, request.Subject)
require.Equal(t, tcase.expectedSigAlg, request.SignatureAlgorithm)
require.Equal(t, tcase.expectedPubAlg, request.PublicKeyAlgorithm)
require.Equal(t, tcase.expectedDNSNames, request.DNSNames)
require.Equal(t, tcase.expectedIPs, request.IPAddresses)
require.Equal(t, tcase.expectedURIs, request.URIs)
} else {
require.Error(t, err)
require.Empty(t, csr)
}
})
}
}

View File

@ -15,6 +15,7 @@ import (
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/serf/serf"
"golang.org/x/time/rate"
@ -59,9 +60,9 @@ type Client struct {
// Connection pool to consul servers
connPool *pool.ConnPool
// routers is responsible for the selection and maintenance of
// router is responsible for the selection and maintenance of
// Consul servers this agent uses for RPC requests
routers *router.Manager
router *router.Router
// rpcLimiter is used to rate limit the total number of RPCs initiated
// from an agent.
@ -120,12 +121,14 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) {
}
}
logger := flat.logger.NamedIntercept(logging.ConsulClient)
// Create client
c := &Client{
config: config,
connPool: connPool,
eventCh: make(chan serf.Event, serfEventBacklog),
logger: flat.logger.NamedIntercept(logging.ConsulClient),
logger: logger,
shutdownCh: make(chan struct{}),
tlsConfigurator: tlsConfigurator,
}
@ -160,15 +163,22 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) {
return nil, fmt.Errorf("Failed to start lan serf: %v", err)
}
// Start maintenance task for servers
c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool, "")
go c.routers.Start()
rpcRouter := flat.router
if rpcRouter == nil {
rpcRouter = router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
}
if err := rpcRouter.AddArea(types.AreaLAN, c.serf, c.connPool); err != nil {
c.Shutdown()
return nil, fmt.Errorf("Failed to add LAN area to the RPC router: %w", err)
}
c.router = rpcRouter
// Start LAN event handlers after the router is complete since the event
// handlers depend on the router and the router depends on Serf.
go c.lanEventHandler()
// This needs to happen after initializing c.routers to prevent a race
// This needs to happen after initializing c.router to prevent a race
// condition where the router manager is used when the pointer is nil
if c.acls.ACLsEnabled() {
go c.monitorACLMode()
@ -276,7 +286,7 @@ func (c *Client) RPC(method string, args interface{}, reply interface{}) error {
firstCheck := time.Now()
TRY:
server := c.routers.FindServer()
manager, server := c.router.FindLANRoute()
if server == nil {
return structs.ErrNoServers
}
@ -301,7 +311,7 @@ TRY:
"error", rpcErr,
)
metrics.IncrCounterWithLabels([]string{"client", "rpc", "failed"}, 1, []metrics.Label{{Name: "server", Value: server.Name}})
c.routers.NotifyFailedServer(server)
manager.NotifyFailedServer(server)
if retry := canRetry(args, rpcErr); !retry {
return rpcErr
}
@ -323,7 +333,7 @@ TRY:
// operation.
func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer,
replyFn structs.SnapshotReplyFn) error {
server := c.routers.FindServer()
manager, server := c.router.FindLANRoute()
if server == nil {
return structs.ErrNoServers
}
@ -339,6 +349,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
var reply structs.SnapshotResponse
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply)
if err != nil {
manager.NotifyFailedServer(server)
return err
}
defer func() {
@ -367,7 +378,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
// Stats is used to return statistics for debugging and insight
// for various sub-systems
func (c *Client) Stats() map[string]map[string]string {
numServers := c.routers.NumServers()
numServers := c.router.GetLANManager().NumServers()
toString := func(v uint64) string {
return strconv.FormatUint(v, 10)

View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/serf/serf"
)
@ -115,7 +116,7 @@ func (c *Client) nodeJoin(me serf.MemberEvent) {
continue
}
c.logger.Info("adding server", "server", parts)
c.routers.AddServer(parts)
c.router.AddServer(types.AreaLAN, parts)
// Trigger the callback
if c.config.ServerUp != nil {
@ -139,7 +140,7 @@ func (c *Client) nodeUpdate(me serf.MemberEvent) {
continue
}
c.logger.Info("updating server", "server", parts.String())
c.routers.AddServer(parts)
c.router.AddServer(types.AreaLAN, parts)
}
}
@ -151,7 +152,7 @@ func (c *Client) nodeFail(me serf.MemberEvent) {
continue
}
c.logger.Info("removing server", "server", parts.String())
c.routers.RemoveServer(parts)
c.router.RemoveServer(types.AreaLAN, parts)
}
}

View File

@ -112,7 +112,7 @@ func TestClient_JoinLAN(t *testing.T) {
joinLAN(t, c1, s1)
testrpc.WaitForTestAgent(t, c1.RPC, "dc1")
retry.Run(t, func(r *retry.R) {
if got, want := c1.routers.NumServers(), 1; got != want {
if got, want := c1.router.GetLANManager().NumServers(), 1; got != want {
r.Fatalf("got %d servers want %d", got, want)
}
if got, want := len(s1.LANMembers()), 2; got != want {
@ -150,7 +150,7 @@ func TestClient_LANReap(t *testing.T) {
// Check the router has both
retry.Run(t, func(r *retry.R) {
server := c1.routers.FindServer()
server := c1.router.FindLANServer()
require.NotNil(t, server)
require.Equal(t, s1.config.NodeName, server.Name)
})
@ -160,7 +160,7 @@ func TestClient_LANReap(t *testing.T) {
retry.Run(t, func(r *retry.R) {
require.Len(r, c1.LANMembers(), 1)
server := c1.routers.FindServer()
server := c1.router.FindLANServer()
require.Nil(t, server)
})
}
@ -390,7 +390,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) {
}
// Sleep to allow Serf to sync, shuffle, and let the shuffle complete
c.routers.ResetRebalanceTimer()
c.router.GetLANManager().ResetRebalanceTimer()
time.Sleep(time.Second)
if len(c.LANMembers()) != numServers+numClients {
@ -406,7 +406,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) {
var pingCount int
for range servers {
time.Sleep(200 * time.Millisecond)
s := c.routers.FindServer()
m, s := c.router.FindLANRoute()
ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr)
if !ok {
t.Errorf("Unable to ping server %v: %s", s.String(), err)
@ -415,7 +415,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) {
// Artificially fail the server in order to rotate the server
// list
c.routers.NotifyFailedServer(s)
m.NotifyFailedServer(s)
}
if pingCount != numServers {
@ -524,7 +524,7 @@ func TestClient_SnapshotRPC(t *testing.T) {
// Wait until we've got a healthy server.
retry.Run(t, func(r *retry.R) {
if got, want := c1.routers.NumServers(), 1; got != want {
if got, want := c1.router.GetLANManager().NumServers(), 1; got != want {
r.Fatalf("got %d servers want %d", got, want)
}
})
@ -559,7 +559,7 @@ func TestClient_SnapshotRPC_RateLimit(t *testing.T) {
joinLAN(t, c1, s1)
retry.Run(t, func(r *retry.R) {
if got, want := c1.routers.NumServers(), 1; got != want {
if got, want := c1.router.GetLANManager().NumServers(), 1; got != want {
r.Fatalf("got %d servers want %d", got, want)
}
})
@ -607,7 +607,7 @@ func TestClient_SnapshotRPC_TLS(t *testing.T) {
}
// Wait until we've got a healthy server.
if got, want := c1.routers.NumServers(), 1; got != want {
if got, want := c1.router.GetLANManager().NumServers(), 1; got != want {
r.Fatalf("got %d servers want %d", got, want)
}
})

View File

@ -443,6 +443,10 @@ type Config struct {
// dead servers.
AutopilotInterval time.Duration
// MetricsReportingInterval is the frequency with which the server will
// report usage metrics to the configured go-metrics Sinks.
MetricsReportingInterval time.Duration
// ConnectEnabled is whether to enable Connect features such as the CA.
ConnectEnabled bool
@ -466,6 +470,9 @@ type Config struct {
// AutoEncrypt.Sign requests.
AutoEncryptAllowTLS bool
// TODO: godoc, set this value from Agent
EnableGRPCServer bool
// Embedded Consul Enterprise specific configuration
*EnterpriseConfig
}
@ -589,11 +596,16 @@ func DefaultConfig() *Config {
},
},
ServerHealthInterval: 2 * time.Second,
AutopilotInterval: 10 * time.Second,
DefaultQueryTime: 300 * time.Second,
MaxQueryTime: 600 * time.Second,
EnterpriseConfig: DefaultEnterpriseConfig(),
// Stay under the 10 second aggregation interval of
// go-metrics. This ensures we always report the
// usage metrics in each cycle.
MetricsReportingInterval: 9 * time.Second,
ServerHealthInterval: 2 * time.Second,
AutopilotInterval: 10 * time.Second,
DefaultQueryTime: 300 * time.Second,
MaxQueryTime: 600 * time.Second,
EnterpriseConfig: DefaultEnterpriseConfig(),
}
// Increase our reap interval to 3 days instead of 24h.

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
)
@ -161,23 +162,32 @@ func (c *Coordinate) Update(args *structs.CoordinateUpdateRequest, reply *struct
// ListDatacenters returns the list of datacenters and their respective nodes
// and the raw coordinates of those nodes (if no coordinates are available for
// any of the nodes, the node list may be empty).
// any of the nodes, the node list may be empty). This endpoint will not return
// information about the LAN network area.
func (c *Coordinate) ListDatacenters(args *struct{}, reply *[]structs.DatacenterMap) error {
maps, err := c.srv.router.GetDatacenterMaps()
if err != nil {
return err
}
var out []structs.DatacenterMap
// Strip the datacenter suffixes from all the node names.
for i := range maps {
suffix := fmt.Sprintf(".%s", maps[i].Datacenter)
for j := range maps[i].Coordinates {
node := maps[i].Coordinates[j].Node
maps[i].Coordinates[j].Node = strings.TrimSuffix(node, suffix)
for _, dcMap := range maps {
if dcMap.AreaID == types.AreaLAN {
continue
}
suffix := fmt.Sprintf(".%s", dcMap.Datacenter)
for j := range dcMap.Coordinates {
node := dcMap.Coordinates[j].Node
dcMap.Coordinates[j].Node = strings.TrimSuffix(node, suffix)
}
out = append(out, dcMap)
}
*reply = maps
*reply = out
return nil
}

View File

@ -707,6 +707,7 @@ func (c *compiler) getSplitterNode(sid structs.ServiceID) (*structs.DiscoveryGra
// sanely if there is some sort of graph loop below.
c.recordNode(splitNode)
var hasLB bool
for _, split := range splitter.Splits {
compiledSplit := &structs.DiscoverySplit{
Weight: split.Weight,
@ -739,6 +740,17 @@ func (c *compiler) getSplitterNode(sid structs.ServiceID) (*structs.DiscoveryGra
return nil, err
}
compiledSplit.NextNode = node.MapKey()
// There exists the possibility that a splitter may split between two distinct service names
// with distinct hash-based load balancer configs specified in their service resolvers.
// We cannot apply multiple hash policies to a splitter node's route action.
// Therefore, we attach the first hash-based load balancer config we encounter.
if !hasLB {
if lb := node.LoadBalancer; lb != nil && lb.IsHashBased() {
splitNode.LoadBalancer = node.LoadBalancer
hasLB = true
}
}
}
c.usesAdvancedRoutingFeatures = true
@ -851,6 +863,7 @@ RESOLVE_AGAIN:
Target: target.ID,
ConnectTimeout: connectTimeout,
},
LoadBalancer: resolver.LoadBalancer,
}
target.Subset = resolver.Subsets[target.ServiceSubset]
@ -1009,10 +1022,5 @@ func defaultIfEmpty(val, defaultVal string) string {
}
func enableAdvancedRoutingForProtocol(protocol string) bool {
switch protocol {
case "http", "http2", "grpc":
return true
default:
return false
}
return structs.IsProtocolHTTPLike(protocol)
}

View File

@ -51,6 +51,8 @@ func TestCompile(t *testing.T) {
"default resolver with external sni": testcase_DefaultResolver_ExternalSNI(),
"resolver with no entries and inferring defaults": testcase_DefaultResolver(),
"default resolver with proxy defaults": testcase_DefaultResolver_WithProxyDefaults(),
"loadbalancer splitter and resolver": testcase_LBSplitterAndResolver(),
"loadbalancer resolver": testcase_LBResolver(),
"service redirect to service with default resolver is not a default chain": testcase_RedirectToDefaultResolverIsNotDefaultChain(),
"all the bells and whistles": testcase_AllBellsAndWhistles(),
@ -1760,6 +1762,17 @@ func testcase_AllBellsAndWhistles() compileTestCase {
"prod": {Filter: "ServiceMeta.env == prod"},
"qa": {Filter: "ServiceMeta.env == qa"},
},
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 100,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
&structs.ServiceResolverConfigEntry{
Kind: "service-resolver",
@ -1821,6 +1834,17 @@ func testcase_AllBellsAndWhistles() compileTestCase {
NextNode: "resolver:v3.main.default.dc1",
},
},
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 100,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
"resolver:prod.redirected.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
@ -1829,6 +1853,17 @@ func testcase_AllBellsAndWhistles() compileTestCase {
ConnectTimeout: 5 * time.Second,
Target: "prod.redirected.default.dc1",
},
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 100,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
"resolver:v1.main.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
@ -2219,6 +2254,231 @@ func testcase_CircularSplit() compileTestCase {
}
}
func testcase_LBSplitterAndResolver() compileTestCase {
entries := newEntries()
setServiceProtocol(entries, "foo", "http")
setServiceProtocol(entries, "bar", "http")
setServiceProtocol(entries, "baz", "http")
entries.AddSplitters(
&structs.ServiceSplitterConfigEntry{
Kind: "service-splitter",
Name: "main",
Splits: []structs.ServiceSplit{
{Weight: 60, Service: "foo"},
{Weight: 20, Service: "bar"},
{Weight: 20, Service: "baz"},
},
},
)
entries.AddResolvers(
&structs.ServiceResolverConfigEntry{
Kind: "service-resolver",
Name: "foo",
LoadBalancer: &structs.LoadBalancer{
Policy: "least_request",
LeastRequestConfig: &structs.LeastRequestConfig{
ChoiceCount: 3,
},
},
},
&structs.ServiceResolverConfigEntry{
Kind: "service-resolver",
Name: "bar",
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 101,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
&structs.ServiceResolverConfigEntry{
Kind: "service-resolver",
Name: "baz",
LoadBalancer: &structs.LoadBalancer{
Policy: "maglev",
HashPolicies: []structs.HashPolicy{
{
Field: "cookie",
FieldValue: "chocolate-chip",
CookieConfig: &structs.CookieConfig{
TTL: 2 * time.Minute,
Path: "/bowl",
},
Terminal: true,
},
},
},
},
)
expect := &structs.CompiledDiscoveryChain{
Protocol: "http",
StartNode: "splitter:main.default",
Nodes: map[string]*structs.DiscoveryGraphNode{
"splitter:main.default": {
Type: structs.DiscoveryGraphNodeTypeSplitter,
Name: "main.default",
Splits: []*structs.DiscoverySplit{
{
Weight: 60,
NextNode: "resolver:foo.default.dc1",
},
{
Weight: 20,
NextNode: "resolver:bar.default.dc1",
},
{
Weight: 20,
NextNode: "resolver:baz.default.dc1",
},
},
// The LB config from bar is attached because splitters only care about hash-based policies,
// and it's the config from bar, not baz, because we pick the first one we encounter in the Splits.
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 101,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
// Each service's LB config is passed down from the service-resolver to the resolver node
"resolver:foo.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
Name: "foo.default.dc1",
Resolver: &structs.DiscoveryResolver{
Default: false,
ConnectTimeout: 5 * time.Second,
Target: "foo.default.dc1",
},
LoadBalancer: &structs.LoadBalancer{
Policy: "least_request",
LeastRequestConfig: &structs.LeastRequestConfig{
ChoiceCount: 3,
},
},
},
"resolver:bar.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
Name: "bar.default.dc1",
Resolver: &structs.DiscoveryResolver{
Default: false,
ConnectTimeout: 5 * time.Second,
Target: "bar.default.dc1",
},
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 101,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
"resolver:baz.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
Name: "baz.default.dc1",
Resolver: &structs.DiscoveryResolver{
Default: false,
ConnectTimeout: 5 * time.Second,
Target: "baz.default.dc1",
},
LoadBalancer: &structs.LoadBalancer{
Policy: "maglev",
HashPolicies: []structs.HashPolicy{
{
Field: "cookie",
FieldValue: "chocolate-chip",
CookieConfig: &structs.CookieConfig{
TTL: 2 * time.Minute,
Path: "/bowl",
},
Terminal: true,
},
},
},
},
},
Targets: map[string]*structs.DiscoveryTarget{
"foo.default.dc1": newTarget("foo", "", "default", "dc1", nil),
"bar.default.dc1": newTarget("bar", "", "default", "dc1", nil),
"baz.default.dc1": newTarget("baz", "", "default", "dc1", nil),
},
}
return compileTestCase{entries: entries, expect: expect}
}
// ensure chain with LB cfg in resolver isn't a default chain (!IsDefault)
func testcase_LBResolver() compileTestCase {
entries := newEntries()
setServiceProtocol(entries, "main", "http")
entries.AddResolvers(
&structs.ServiceResolverConfigEntry{
Kind: "service-resolver",
Name: "main",
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 101,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
)
expect := &structs.CompiledDiscoveryChain{
Protocol: "http",
StartNode: "resolver:main.default.dc1",
Nodes: map[string]*structs.DiscoveryGraphNode{
"resolver:main.default.dc1": {
Type: structs.DiscoveryGraphNodeTypeResolver,
Name: "main.default.dc1",
Resolver: &structs.DiscoveryResolver{
Default: false,
ConnectTimeout: 5 * time.Second,
Target: "main.default.dc1",
},
LoadBalancer: &structs.LoadBalancer{
Policy: "ring_hash",
RingHashConfig: &structs.RingHashConfig{
MaximumRingSize: 101,
},
HashPolicies: []structs.HashPolicy{
{
SourceIP: true,
},
},
},
},
},
Targets: map[string]*structs.DiscoveryTarget{
"main.default.dc1": newTarget("main", "", "default", "dc1", nil),
},
}
return compileTestCase{entries: entries, expect: expect}
}
func newSimpleRoute(name string, muts ...func(*structs.ServiceRoute)) structs.ServiceRoute {
r := structs.ServiceRoute{
Match: &structs.ServiceRouteMatch{

View File

@ -654,6 +654,12 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fedState2, fedStateLoaded2)
// Verify usage data is correctly updated
idx, nodeCount, err := fsm2.state.NodeCount()
require.NoError(t, err)
require.Equal(t, len(nodes), nodeCount)
require.NotZero(t, idx)
// Snapshot
snap, err = fsm2.Snapshot()
require.NoError(t, err)

View File

@ -663,7 +663,7 @@ func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) erro
case <-ctx.Done():
return nil
case <-time.After(structs.IntermediateCertRenewInterval):
retryLoopBackoff(ctx, func() error {
retryLoopBackoffAbortOnSuccess(ctx, func() error {
s.caProviderReconfigurationLock.Lock()
defer s.caProviderReconfigurationLock.Unlock()
@ -845,6 +845,14 @@ func (s *Server) replicateIntentions(ctx context.Context) error {
// retryLoopBackoff loops a given function indefinitely, backing off exponentially
// upon errors up to a maximum of maxRetryBackoff seconds.
func retryLoopBackoff(ctx context.Context, loopFn func() error, errFn func(error)) {
retryLoopBackoffHandleSuccess(ctx, loopFn, errFn, false)
}
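// retryLoopBackoffAbortOnSuccess is like retryLoopBackoff, except that the
// loop exits as soon as loopFn completes without an error.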
func retryLoopBackoffAbortOnSuccess(ctx context.Context, loopFn func() error, errFn func(error)) {
retryLoopBackoffHandleSuccess(ctx, loopFn, errFn, true)
}
func retryLoopBackoffHandleSuccess(ctx context.Context, loopFn func() error, errFn func(error), abortOnSuccess bool) {
var failedAttempts uint
limiter := rate.NewLimiter(loopRateLimit, retryBucketSize)
for {
@ -871,6 +879,8 @@ func retryLoopBackoff(ctx context.Context, loopFn func() error, errFn func(error
case <-timer.C:
continue
}
} else if abortOnSuccess {
return
}
// Reset the failed attempts after a successful run.

View File

@ -1,6 +1,7 @@
package consul
import (
"context"
"crypto/x509"
"fmt"
"io/ioutil"
@ -1442,3 +1443,43 @@ func TestLeader_lessThanHalfTimePassed(t *testing.T) {
require.True(t, lessThanHalfTimePassed(now, now.Add(-10*time.Second), now.Add(20*time.Second)))
}
func TestLeader_retryLoopBackoffHandleSuccess(t *testing.T) {
type test struct {
desc string
loopFn func() error
abort bool
timedOut bool
}
success := func() error {
return nil
}
failure := func() error {
return fmt.Errorf("test error")
}
tests := []test{
{"loop without error and no abortOnSuccess keeps running", success, false, true},
{"loop with error and no abortOnSuccess keeps running", failure, false, true},
{"loop without error and abortOnSuccess is stopped", success, true, false},
{"loop with error and abortOnSuccess keeps running", failure, true, true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
retryLoopBackoffHandleSuccess(ctx, tc.loopFn, func(_ error) {}, tc.abort)
select {
case <-ctx.Done():
if !tc.timedOut {
t.Fatal("should not have timed out")
}
default:
if tc.timedOut {
t.Fatal("should have timed out")
}
}
})
}
}

View File

@ -2,6 +2,7 @@ package consul
import (
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
@ -12,6 +13,7 @@ type consulOptions struct {
tlsConfigurator *tlsutil.Configurator
connPool *pool.ConnPool
tokens *token.Store
router *router.Router
}
type ConsulOption func(*consulOptions)
@ -40,6 +42,12 @@ func WithTokenStore(tokens *token.Store) ConsulOption {
}
}
func WithRouter(router *router.Router) ConsulOption {
return func(opt *consulOptions) {
opt.router = router
}
}
func flattenConsulOptions(options []ConsulOption) consulOptions {
var flat consulOptions
for _, opt := range options {

View File

@ -188,6 +188,9 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
conn = tls.Server(conn, s.tlsConfigurator.IncomingInsecureRPCConfig())
s.handleInsecureConn(conn)
case pool.RPCGRPC:
s.grpcHandler.Handle(conn)
default:
if !s.handleEnterpriseRPCConn(typ, conn, isTLS) {
s.rpcLogger().Error("unrecognized RPC byte",
@ -254,6 +257,9 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
case pool.ALPN_RPCSnapshot:
s.handleSnapshotConn(tlsConn)
case pool.ALPN_RPCGRPC:
s.grpcHandler.Handle(conn)
case pool.ALPN_WANGossipPacket:
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {
s.rpcLogger().Error(

View File

@ -25,6 +25,8 @@ import (
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/fsm"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/usagemetrics"
"github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/router"
@ -238,8 +240,9 @@ type Server struct {
rpcConnLimiter connlimit.Limiter
// Listener is used to listen for incoming connections
Listener net.Listener
rpcServer *rpc.Server
Listener net.Listener
grpcHandler connHandler
rpcServer *rpc.Server
// insecureRPCServer is an RPC server that is configured with
// IncomingInsecureRPCConfig to allow clients to call AutoEncrypt.Sign
@ -313,6 +316,12 @@ type Server struct {
EnterpriseServer
}
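// connHandler is the lifecycle interface satisfied by the gRPC handler (or the
// no-op stand-in used when the gRPC server is disabled), letting the server
// start the handler, dispatch connections to it, and shut it down uniformly.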
type connHandler interface {
Run() error
Handle(conn net.Conn)
Shutdown() error
}
// NewServer is used to construct a new Consul server from the configuration
// and extra options, potentially returning an error.
func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
@ -322,6 +331,7 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
tokens := flat.tokens
tlsConfigurator := flat.tlsConfigurator
connPool := flat.connPool
rpcRouter := flat.router
if err := config.CheckProtocolVersion(); err != nil {
return nil, err
@ -377,6 +387,11 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
serverLogger := logger.NamedIntercept(logging.ConsulServer)
loggers := newLoggerStore(serverLogger)
if rpcRouter == nil {
rpcRouter = router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
}
// Create server.
s := &Server{
config: config,
@ -388,7 +403,7 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
loggers: loggers,
leaveCh: make(chan struct{}),
reconcileCh: make(chan serf.Member, reconcileChSize),
router: router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)),
router: rpcRouter,
rpcServer: rpc.NewServer(),
insecureRPCServer: rpc.NewServer(),
tlsConfigurator: tlsConfigurator,
@ -545,6 +560,11 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
s.Shutdown()
return nil, fmt.Errorf("Failed to start LAN Serf: %v", err)
}
if err := s.router.AddArea(types.AreaLAN, s.serfLAN, s.connPool); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to add LAN serf route: %w", err)
}
go s.lanEventHandler()
// Start the flooders after the LAN event handler is wired up.
@ -578,6 +598,21 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
return nil, err
}
reporter, err := usagemetrics.NewUsageMetricsReporter(
new(usagemetrics.Config).
WithStateProvider(s.fsm).
WithLogger(s.logger).
WithDatacenter(s.config.Datacenter).
WithReportingInterval(s.config.MetricsReportingInterval),
)
if err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to start usage metrics reporter: %v", err)
}
go reporter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
s.grpcHandler = newGRPCHandlerFromConfig(logger, config)
// Initialize Autopilot. This must happen before starting leadership monitoring
// as establishing leadership could attempt to use autopilot and cause a panic.
s.initAutopilot(config)
@ -587,6 +622,11 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
go s.monitorLeadership()
// Start listening for RPC requests.
go func() {
if err := s.grpcHandler.Run(); err != nil {
s.logger.Error("gRPC server failed", "error", err)
}
}()
go s.listen(s.Listener)
// Start listeners for any segments with separate RPC listeners.
@ -600,6 +640,14 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
return s, nil
}
func newGRPCHandlerFromConfig(logger hclog.Logger, config *Config) connHandler {
if !config.EnableGRPCServer {
return grpc.NoOpHandler{Logger: logger}
}
return grpc.NewHandler(config.RPCAddr)
}
func (s *Server) connectCARootsMonitor(ctx context.Context) {
for {
ws := memdb.NewWatchSet()
@ -924,6 +972,12 @@ func (s *Server) Shutdown() error {
s.Listener.Close()
}
if s.grpcHandler != nil {
if err := s.grpcHandler.Shutdown(); err != nil {
s.logger.Warn("failed to stop gRPC server", "error", err)
}
}
// Close the connection pool
if s.connPool != nil {
s.connPool.Shutdown()

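The small connHandler interface introduced above is what lets the server swap the real gRPC handler for a no-op when gRPC is disabled; any type with the same three methods satisfies it. An illustrative implementation, not part of this change:

```go
// loggingConnHandler satisfies connHandler but only logs and closes conns.
type loggingConnHandler struct {
	logger hclog.Logger
}

func (h *loggingConnHandler) Run() error { return nil }

func (h *loggingConnHandler) Handle(conn net.Conn) {
	h.logger.Info("dropping connection", "from", conn.RemoteAddr())
	_ = conn.Close()
}

func (h *loggingConnHandler) Shutdown() error { return nil }
```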
View File

@ -0,0 +1,475 @@
package state
import (
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
memdb "github.com/hashicorp/go-memdb"
)
type changeOp int
const (
OpDelete changeOp = iota
OpCreate
OpUpdate
)
type eventPayload struct {
Op changeOp
Obj interface{}
}
// serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot
// of stream.Events that describe the current state of a service health query.
//
// TODO: no tests for this yet
func serviceHealthSnapshot(s *Store, topic topic) stream.SnapshotFunc {
return func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) {
tx := s.db.Txn(false)
defer tx.Abort()
connect := topic == TopicServiceHealthConnect
// TODO(namespace-streaming): plumb entMeta through from SubscribeRequest
idx, nodes, err := checkServiceNodesTxn(tx, nil, req.Key, connect, nil)
if err != nil {
return 0, err
}
for _, n := range nodes {
event := stream.Event{
Index: idx,
Topic: topic,
Payload: eventPayload{
Op: OpCreate,
Obj: &n,
},
}
if n.Service != nil {
event.Key = n.Service.Service
}
// append each event as a separate item so that they can be serialized
// separately, to prevent the encoding of one massive message.
buf.Append([]stream.Event{event})
}
return idx, err
}
}
type nodeServiceTuple struct {
Node string
ServiceID string
EntMeta structs.EnterpriseMeta
}
func newNodeServiceTupleFromServiceNode(sn *structs.ServiceNode) nodeServiceTuple {
return nodeServiceTuple{
Node: sn.Node,
ServiceID: sn.ServiceID,
EntMeta: sn.EnterpriseMeta,
}
}
func newNodeServiceTupleFromServiceHealthCheck(hc *structs.HealthCheck) nodeServiceTuple {
return nodeServiceTuple{
Node: hc.Node,
ServiceID: hc.ServiceID,
EntMeta: hc.EnterpriseMeta,
}
}
type serviceChange struct {
changeType changeType
change memdb.Change
}
var serviceChangeIndirect = serviceChange{changeType: changeIndirect}
// ServiceHealthEventsFromChanges returns all the service and Connect health
// events that should be emitted given a set of changes to the state store.
func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
var events []stream.Event
var nodeChanges map[string]changeType
var serviceChanges map[nodeServiceTuple]serviceChange
markNode := func(node string, typ changeType) {
if nodeChanges == nil {
nodeChanges = make(map[string]changeType)
}
// If the caller has an actual node mutation ensure we store it even if the
// node is already marked. If the caller is just marking the node dirty
// without a node change, don't overwrite any existing node change we know
// about.
if nodeChanges[node] == changeIndirect {
nodeChanges[node] = typ
}
}
markService := func(key nodeServiceTuple, svcChange serviceChange) {
if serviceChanges == nil {
serviceChanges = make(map[nodeServiceTuple]serviceChange)
}
// If the caller has an actual service mutation ensure we store it even if
// the service is already marked. If the caller is just marking the service
// dirty without a service change, don't overwrite any existing service change we
// know about.
if serviceChanges[key].changeType == changeIndirect {
serviceChanges[key] = svcChange
}
}
for _, change := range changes.Changes {
switch change.Table {
case "nodes":
// Node changed in some way. If it's not a delete, we'll need to
// re-deliver CheckServiceNode results for all services on that node, but
// we mark it anyway because if it _is_ a delete then we need to know that
// later to avoid trying to deliver events when node-level checks mark the
// node as "changed".
n := changeObject(change).(*structs.Node)
markNode(n.Node, changeTypeFromChange(change))
case "services":
sn := changeObject(change).(*structs.ServiceNode)
srvChange := serviceChange{changeType: changeTypeFromChange(change), change: change}
markService(newNodeServiceTupleFromServiceNode(sn), srvChange)
case "checks":
// For health we only care about the scope for now to know if it's just
// affecting a single service or every service on a node. There is a
// subtle edge case where the check with same ID changes from being node
// scoped to service scoped or vice versa, in either case we need to treat
// it as affecting all services on the node.
switch {
case change.Updated():
before := change.Before.(*structs.HealthCheck)
after := change.After.(*structs.HealthCheck)
if after.ServiceID == "" || before.ServiceID == "" {
// check before and/or after is node-scoped
markNode(after.Node, changeIndirect)
} else {
// Check changed which means we just need to emit for the linked
// service.
markService(newNodeServiceTupleFromServiceHealthCheck(after), serviceChangeIndirect)
// Edge case - if the check with same ID was updated to link to a
// different service ID but the old service with old ID still exists,
// then the old service instance needs updating too as it has one
// fewer checks now.
if before.ServiceID != after.ServiceID {
markService(newNodeServiceTupleFromServiceHealthCheck(before), serviceChangeIndirect)
}
}
case change.Deleted(), change.Created():
obj := changeObject(change).(*structs.HealthCheck)
if obj.ServiceID == "" {
// Node level check
markNode(obj.Node, changeIndirect)
} else {
markService(newNodeServiceTupleFromServiceHealthCheck(obj), serviceChangeIndirect)
}
}
}
}
// Now act on those marked nodes/services
for node, changeType := range nodeChanges {
if changeType == changeDelete {
// Node deletions are a no-op here since the state store transaction will
// have also removed all the service instances which will be handled in
// the loop below.
continue
}
// Rebuild events for all services on this node
es, err := newServiceHealthEventsForNode(tx, changes.Index, node)
if err != nil {
return nil, err
}
events = append(events, es...)
}
for tuple, srvChange := range serviceChanges {
// change may be nil if there was a change that _affected_ the service
// like a change to checks but it didn't actually change the service
// record itself.
if srvChange.changeType == changeDelete {
sn := srvChange.change.Before.(*structs.ServiceNode)
e := newServiceHealthEventDeregister(changes.Index, sn)
events = append(events, e)
continue
}
// Check if this was a service mutation that changed its name, which
// requires special handling even if node changed and new events were
// already published.
if srvChange.changeType == changeUpdate {
before := srvChange.change.Before.(*structs.ServiceNode)
after := srvChange.change.After.(*structs.ServiceNode)
if before.ServiceName != after.ServiceName {
// Service was renamed, the code below will ensure the new registrations
// go out to subscribers to the new service name topic key, but we need
// to fix up subscribers that were watching the old name by sending
// deregistrations.
e := newServiceHealthEventDeregister(changes.Index, before)
events = append(events, e)
}
if e, ok := isConnectProxyDestinationServiceChange(changes.Index, before, after); ok {
events = append(events, e)
}
}
if _, ok := nodeChanges[tuple.Node]; ok {
// We already rebuilt events for everything on this node, no need to send
// a duplicate.
continue
}
// Build service event and append it
e, err := newServiceHealthEventForService(tx, changes.Index, tuple)
if err != nil {
return nil, err
}
events = append(events, e)
}
// Duplicate any events that affected connect-enabled instances (proxies or
// native apps) to the relevant Connect topic.
events = append(events, serviceHealthToConnectEvents(events...)...)
return events, nil
}
// isConnectProxyDestinationServiceChange handles the case where a Connect proxy changed
// the service it is proxying. We need to issue a de-registration for the old
// service on the Connect topic. We don't actually need to deregister this sidecar
// service though as it still exists and didn't change its name.
func isConnectProxyDestinationServiceChange(idx uint64, before, after *structs.ServiceNode) (stream.Event, bool) {
if before.ServiceKind != structs.ServiceKindConnectProxy ||
before.ServiceProxy.DestinationServiceName == after.ServiceProxy.DestinationServiceName {
return stream.Event{}, false
}
e := newServiceHealthEventDeregister(idx, before)
e.Topic = TopicServiceHealthConnect
e.Key = getPayloadCheckServiceNode(e.Payload).Service.Proxy.DestinationServiceName
return e, true
}
type changeType uint8
const (
// changeIndirect indicates some other object changed which has implications
// for the target object.
changeIndirect changeType = iota
changeDelete
changeCreate
changeUpdate
)
func changeTypeFromChange(change memdb.Change) changeType {
switch {
case change.Deleted():
return changeDelete
case change.Created():
return changeCreate
default:
return changeUpdate
}
}
// serviceHealthToConnectEvents converts already formatted service health
// registration events into the ones needed to publish to the Connect topic.
// This essentially means filtering out any instances that are not Connect
// enabled and so of no interest to those subscribers but also involves
// switching connection details to be the proxy instead of the actual instance
// in case of a sidecar.
func serviceHealthToConnectEvents(events ...stream.Event) []stream.Event {
var result []stream.Event
for _, event := range events {
if event.Topic != TopicServiceHealth {
// Skip non-health or any events already emitted to Connect topic
continue
}
node := getPayloadCheckServiceNode(event.Payload)
if node.Service == nil {
continue
}
connectEvent := event
connectEvent.Topic = TopicServiceHealthConnect
switch {
case node.Service.Connect.Native:
result = append(result, connectEvent)
case node.Service.Kind == structs.ServiceKindConnectProxy:
connectEvent.Key = node.Service.Proxy.DestinationServiceName
result = append(result, connectEvent)
default:
// ServiceKindTerminatingGateway changes are handled separately.
// All other cases are not relevant to the connect topic
}
}
return result
}
func getPayloadCheckServiceNode(payload interface{}) *structs.CheckServiceNode {
ep, ok := payload.(eventPayload)
if !ok {
return nil
}
csn, ok := ep.Obj.(*structs.CheckServiceNode)
if !ok {
return nil
}
return csn
}
// newServiceHealthEventsForNode returns health events for all services on the
// given node. This mirrors some of the logic in the oddly-named
// parseCheckServiceNodes but is more efficient since we know they are all on
// the same node.
func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string) ([]stream.Event, error) {
// TODO(namespace-streaming): figure out the right EntMeta and mystery arg.
services, err := catalogServiceListByNode(tx, node, nil, false)
if err != nil {
return nil, err
}
n, checksFunc, err := getNodeAndChecks(tx, node)
if err != nil {
return nil, err
}
var events []stream.Event
for service := services.Next(); service != nil; service = services.Next() {
sn := service.(*structs.ServiceNode)
event := newServiceHealthEventRegister(idx, n, sn, checksFunc(sn.ServiceID))
events = append(events, event)
}
return events, nil
}
// getNodeAndChecks returns the node structure and a function that returns
// the full list of checks for a specific service on that node.
func getNodeAndChecks(tx ReadTxn, node string) (*structs.Node, serviceChecksFunc, error) {
// Fetch the node
nodeRaw, err := tx.First("nodes", "id", node)
if err != nil {
return nil, nil, err
}
if nodeRaw == nil {
return nil, nil, ErrMissingNode
}
n := nodeRaw.(*structs.Node)
// TODO(namespace-streaming): work out what EntMeta is needed here, wildcard?
iter, err := catalogListChecksByNode(tx, node, nil)
if err != nil {
return nil, nil, err
}
var nodeChecks structs.HealthChecks
var svcChecks map[string]structs.HealthChecks
for check := iter.Next(); check != nil; check = iter.Next() {
check := check.(*structs.HealthCheck)
if check.ServiceID == "" {
nodeChecks = append(nodeChecks, check)
} else {
if svcChecks == nil {
svcChecks = make(map[string]structs.HealthChecks)
}
svcChecks[check.ServiceID] = append(svcChecks[check.ServiceID], check)
}
}
serviceChecks := func(serviceID string) structs.HealthChecks {
// Create a new slice so that append does not modify the array backing nodeChecks.
result := make(structs.HealthChecks, 0, len(nodeChecks))
result = append(result, nodeChecks...)
for _, check := range svcChecks[serviceID] {
result = append(result, check)
}
return result
}
return n, serviceChecks, nil
}
type serviceChecksFunc func(serviceID string) structs.HealthChecks
func newServiceHealthEventForService(tx ReadTxn, idx uint64, tuple nodeServiceTuple) (stream.Event, error) {
n, checksFunc, err := getNodeAndChecks(tx, tuple.Node)
if err != nil {
return stream.Event{}, err
}
svc, err := getCompoundWithTxn(tx, "services", "id", &tuple.EntMeta, tuple.Node, tuple.ServiceID)
if err != nil {
return stream.Event{}, err
}
raw := svc.Next()
if raw == nil {
return stream.Event{}, ErrMissingService
}
sn := raw.(*structs.ServiceNode)
return newServiceHealthEventRegister(idx, n, sn, checksFunc(sn.ServiceID)), nil
}
func newServiceHealthEventRegister(
idx uint64,
node *structs.Node,
sn *structs.ServiceNode,
checks structs.HealthChecks,
) stream.Event {
csn := &structs.CheckServiceNode{
Node: node,
Service: sn.ToNodeService(),
Checks: checks,
}
return stream.Event{
Topic: TopicServiceHealth,
Key: sn.ServiceName,
Index: idx,
Payload: eventPayload{
Op: OpCreate,
Obj: csn,
},
}
}
func newServiceHealthEventDeregister(idx uint64, sn *structs.ServiceNode) stream.Event {
// We actually only need the node name populated in the node part as it's only
// used as a key to know which service was deregistered so don't bother looking
// up the node in the DB. Note that while the ServiceNode does have NodeID
// etc. fields, they are never populated in memdb per the comment on that
// struct and only filled in when we return copies of the result to users.
// This is also important because if the service was deleted as part of a
// whole node deregistering then the node record won't actually exist now
// anyway and we'd have to plumb it through from the changeset above.
csn := &structs.CheckServiceNode{
Node: &structs.Node{
Node: sn.Node,
},
Service: sn.ToNodeService(),
}
return stream.Event{
Topic: TopicServiceHealth,
Key: sn.ServiceName,
Index: idx,
Payload: eventPayload{
Op: OpDelete,
Obj: csn,
},
}
}

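On the consuming side, a subscriber would unwrap the payload with the helper defined above. A hypothetical sketch:

```go
// Sketch: react to a service-health event produced by this file.
func applyHealthEvent(e stream.Event) {
	csn := getPayloadCheckServiceNode(e.Payload)
	if csn == nil || csn.Service == nil {
		return // not an eventPayload carrying a CheckServiceNode
	}
	if e.Payload.(eventPayload).Op == OpDelete {
		// Deregistration: only the node name and service identity are
		// guaranteed to be populated, per newServiceHealthEventDeregister.
		return
	}
	// OpCreate/OpUpdate carry a full CheckServiceNode snapshot at e.Index.
	_ = csn.Checks
}
```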
File diff suppressed because it is too large

View File

@ -467,7 +467,7 @@ func validateProposedConfigEntryInServiceGraph(
}
overrides := map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: kind, Name: name}: next,
structs.NewConfigEntryKindName(kind, name, entMeta): next,
}
var (
@ -909,9 +909,8 @@ func configEntryWithOverridesTxn(
entMeta *structs.EnterpriseMeta,
) (uint64, structs.ConfigEntry, error) {
if len(overrides) > 0 {
entry, ok := overrides[structs.ConfigEntryKindName{
Kind: kind, Name: name,
}]
kn := structs.NewConfigEntryKindName(kind, name, entMeta)
entry, ok := overrides[kn]
if ok {
return 0, entry, nil // a nil entry implies it should act like it is erased
}

View File

@ -880,10 +880,10 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceDefaults, Name: "main"}: nil,
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil): nil,
},
expectAfter: []structs.ConfigEntryKindName{
// nothing
@ -899,17 +899,17 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceDefaults, Name: "main"}: &structs.ServiceConfigEntry{
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil): &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "main",
Protocol: "grpc",
},
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
},
checkAfter: func(t *testing.T, entrySet *structs.DiscoveryChainConfigEntries) {
defaults := entrySet.GetService(structs.NewServiceID("main", nil))
@ -932,14 +932,14 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceRouter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceRouter, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceRouter, Name: "main"}: nil,
structs.NewConfigEntryKindName(structs.ServiceRouter, "main", nil): nil,
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
},
},
{
@ -977,12 +977,12 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceResolver, Name: "main"},
{Kind: structs.ServiceRouter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceRouter, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceRouter, Name: "main"}: &structs.ServiceRouterConfigEntry{
structs.NewConfigEntryKindName(structs.ServiceRouter, "main", nil): &structs.ServiceRouterConfigEntry{
Kind: structs.ServiceRouter,
Name: "main",
Routes: []structs.ServiceRoute{
@ -1000,9 +1000,9 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceResolver, Name: "main"},
{Kind: structs.ServiceRouter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceRouter, "main", nil),
},
checkAfter: func(t *testing.T, entrySet *structs.DiscoveryChainConfigEntries) {
router := entrySet.GetRouter(structs.NewServiceID("main", nil))
@ -1040,14 +1040,14 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceSplitter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceSplitter, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceSplitter, Name: "main"}: nil,
structs.NewConfigEntryKindName(structs.ServiceSplitter, "main", nil): nil,
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
},
},
{
@ -1067,11 +1067,11 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceSplitter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceSplitter, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceSplitter, Name: "main"}: &structs.ServiceSplitterConfigEntry{
structs.NewConfigEntryKindName(structs.ServiceSplitter, "main", nil): &structs.ServiceSplitterConfigEntry{
Kind: structs.ServiceSplitter,
Name: "main",
Splits: []structs.ServiceSplit{
@ -1081,8 +1081,8 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceDefaults, Name: "main"},
{Kind: structs.ServiceSplitter, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceDefaults, "main", nil),
structs.NewConfigEntryKindName(structs.ServiceSplitter, "main", nil),
},
checkAfter: func(t *testing.T, entrySet *structs.DiscoveryChainConfigEntries) {
splitter := entrySet.GetSplitter(structs.NewServiceID("main", nil))
@ -1106,10 +1106,10 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceResolver, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceResolver, Name: "main"}: nil,
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil): nil,
},
expectAfter: []structs.ConfigEntryKindName{
// nothing
@ -1124,17 +1124,17 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
},
},
expectBefore: []structs.ConfigEntryKindName{
{Kind: structs.ServiceResolver, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil),
},
overrides: map[structs.ConfigEntryKindName]structs.ConfigEntry{
{Kind: structs.ServiceResolver, Name: "main"}: &structs.ServiceResolverConfigEntry{
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil): &structs.ServiceResolverConfigEntry{
Kind: structs.ServiceResolver,
Name: "main",
ConnectTimeout: 33 * time.Second,
},
},
expectAfter: []structs.ConfigEntryKindName{
{Kind: structs.ServiceResolver, Name: "main"},
structs.NewConfigEntryKindName(structs.ServiceResolver, "main", nil),
},
checkAfter: func(t *testing.T, entrySet *structs.DiscoveryChainConfigEntries) {
resolver := entrySet.GetResolver(structs.NewServiceID("main", nil))
@ -1181,28 +1181,32 @@ func TestStore_ReadDiscoveryChainConfigEntries_Overrides(t *testing.T) {
func entrySetToKindNames(entrySet *structs.DiscoveryChainConfigEntries) []structs.ConfigEntryKindName {
var out []structs.ConfigEntryKindName
for _, entry := range entrySet.Routers {
out = append(out, structs.ConfigEntryKindName{
Kind: entry.Kind,
Name: entry.Name,
})
out = append(out, structs.NewConfigEntryKindName(
entry.Kind,
entry.Name,
&entry.EnterpriseMeta,
))
}
for _, entry := range entrySet.Splitters {
out = append(out, structs.ConfigEntryKindName{
Kind: entry.Kind,
Name: entry.Name,
})
out = append(out, structs.NewConfigEntryKindName(
entry.Kind,
entry.Name,
&entry.EnterpriseMeta,
))
}
for _, entry := range entrySet.Resolvers {
out = append(out, structs.ConfigEntryKindName{
Kind: entry.Kind,
Name: entry.Name,
})
out = append(out, structs.NewConfigEntryKindName(
entry.Kind,
entry.Name,
&entry.EnterpriseMeta,
))
}
for _, entry := range entrySet.Services {
out = append(out, structs.ConfigEntryKindName{
Kind: entry.Kind,
Name: entry.Name,
})
out = append(out, structs.NewConfigEntryKindName(
entry.Kind,
entry.Name,
&entry.EnterpriseMeta,
))
}
return out
}

View File

@ -15,6 +15,13 @@ type ReadTxn interface {
Abort()
}
// WriteTxn is implemented by memdb.Txn to perform write operations.
type WriteTxn interface {
ReadTxn
Insert(table string, obj interface{}) error
Commit() error
}
// Changes wraps a memdb.Changes to include the index at which these changes
// were made.
type Changes struct {
@ -24,8 +31,9 @@ type Changes struct {
}
// changeTrackerDB is a thin wrapper around memdb.DB which enables TrackChanges on
// all write transactions. When the transaction is committed the changes are
// sent to the eventPublisher which will create and emit change events.
// all write transactions. When the transaction is committed the changes are:
// 1. Used to update our internal usage tracking
// 2. Sent to the eventPublisher which will create and emit change events
type changeTrackerDB struct {
db *memdb.MemDB
publisher eventPublisher
@ -77,11 +85,8 @@ func (c *changeTrackerDB) WriteTxn(idx uint64) *txn {
return t
}
func (c *changeTrackerDB) publish(changes Changes) error {
readOnlyTx := c.db.Txn(false)
defer readOnlyTx.Abort()
events, err := c.processChanges(readOnlyTx, changes)
func (c *changeTrackerDB) publish(tx ReadTxn, changes Changes) error {
events, err := c.processChanges(tx, changes)
if err != nil {
return fmt.Errorf("failed generating events from changes: %v", err)
}
@ -89,17 +94,21 @@ func (c *changeTrackerDB) publish(changes Changes) error {
return nil
}
// WriteTxnRestore returns a wrapped RW transaction that does NOT have change
// tracking enabled. This should only be used in Restore where we need to
// replace the entire contents of the Store without a need to track the changes.
// WriteTxnRestore uses a zero index since the whole restore doesn't really occur
// at one index - the effect is to write many values that were previously
// written across many indexes.
// WriteTxnRestore returns a wrapped RW transaction that should only be used in
// Restore where we need to replace the entire contents of the Store.
// WriteTxnRestore uses a zero index since the whole restore doesn't really
// occur at one index - the effect is to write many values that were previously
// written across many indexes. WriteTxnRestore also does not publish any
// change events to subscribers.
func (c *changeTrackerDB) WriteTxnRestore() *txn {
return &txn{
t := &txn{
Txn: c.db.Txn(true),
Index: 0,
}
// We enable change tracking so that usage data is correctly populated.
t.Txn.TrackChanges()
return t
}
// txn wraps a memdb.Txn to capture changes and send them to the EventPublisher.
@ -115,7 +124,7 @@ type txn struct {
// Index is stored so that it may be passed along to any subscribers as part
// of a change event.
Index uint64
publish func(changes Changes) error
publish func(tx ReadTxn, changes Changes) error
}
// Commit first pushes changes to EventPublisher, then calls Commit on the
@ -125,15 +134,22 @@ type txn struct {
// by the caller. A non-nil error indicates that a commit failed and was not
// applied.
func (tx *txn) Commit() error {
changes := Changes{
Index: tx.Index,
Changes: tx.Txn.Changes(),
}
if len(changes.Changes) > 0 {
if err := updateUsage(tx, changes); err != nil {
return err
}
}
// publish may be nil if this is a read-only or WriteTxnRestore transaction.
// Nothing is published in those cases, though a WriteTxnRestore still tracks
// changes so that usage data is updated above.
if tx.publish != nil {
changes := Changes{
Index: tx.Index,
Changes: tx.Txn.Changes(),
}
if err := tx.publish(changes); err != nil {
if err := tx.publish(tx.Txn, changes); err != nil {
return err
}
}
@ -149,11 +165,33 @@ func (t topic) String() string {
return string(t)
}
var (
// TopicServiceHealth contains events for all registered service instances.
TopicServiceHealth topic = "topic-service-health"
// TopicServiceHealthConnect contains events for connect-enabled service instances.
TopicServiceHealthConnect topic = "topic-service-health-connect"
)
func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
// TODO: add other table handlers here.
return aclChangeUnsubscribeEvent(tx, changes)
var events []stream.Event
fns := []func(tx ReadTxn, changes Changes) ([]stream.Event, error){
aclChangeUnsubscribeEvent,
ServiceHealthEventsFromChanges,
// TODO: add other table handlers here.
}
for _, fn := range fns {
e, err := fn(tx, changes)
if err != nil {
return nil, err
}
events = append(events, e...)
}
return events, nil
}
func newSnapshotHandlers() stream.SnapshotHandlers {
return stream.SnapshotHandlers{}
func newSnapshotHandlers(s *Store) stream.SnapshotHandlers {
return stream.SnapshotHandlers{
TopicServiceHealth: serviceHealthSnapshot(s, TopicServiceHealth),
TopicServiceHealthConnect: serviceHealthSnapshot(s, TopicServiceHealthConnect),
}
}

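Putting the pieces together, every write now funnels usage tracking and event publishing through Commit. Roughly, given a *Store s and an index idx (a sketch, not code from this change):

```go
tx := s.db.WriteTxn(idx) // WriteTxn enables change tracking
defer tx.Abort()
if err := tx.Insert("nodes", node); err != nil {
	return err
}
// Commit applies updateUsage to the tracked changes, then hands the same
// transaction and changes to publish -> processDBChanges.
return tx.Commit()
```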
View File

@ -7,30 +7,30 @@ import (
"github.com/hashicorp/go-memdb"
)
func firstWithTxn(tx *txn,
func firstWithTxn(tx ReadTxn,
table, index, idxVal string, entMeta *structs.EnterpriseMeta) (interface{}, error) {
return tx.First(table, index, idxVal)
}
func firstWatchWithTxn(tx *txn,
func firstWatchWithTxn(tx ReadTxn,
table, index, idxVal string, entMeta *structs.EnterpriseMeta) (<-chan struct{}, interface{}, error) {
return tx.FirstWatch(table, index, idxVal)
}
func firstWatchCompoundWithTxn(tx *txn,
func firstWatchCompoundWithTxn(tx ReadTxn,
table, index string, _ *structs.EnterpriseMeta, idxVals ...interface{}) (<-chan struct{}, interface{}, error) {
return tx.FirstWatch(table, index, idxVals...)
}
func getWithTxn(tx *txn,
func getWithTxn(tx ReadTxn,
table, index, idxVal string, entMeta *structs.EnterpriseMeta) (memdb.ResultIterator, error) {
return tx.Get(table, index, idxVal)
}
func getCompoundWithTxn(tx *txn, table, index string,
func getCompoundWithTxn(tx ReadTxn, table, index string,
_ *structs.EnterpriseMeta, idxVals ...interface{}) (memdb.ResultIterator, error) {
return tx.Get(table, index, idxVals...)

View File

@ -162,17 +162,17 @@ func NewStateStore(gc *TombstoneGC) (*Store, error) {
ctx, cancel := context.WithCancel(context.TODO())
s := &Store{
schema: schema,
abandonCh: make(chan struct{}),
kvsGraveyard: NewGraveyard(gc),
lockDelay: NewDelay(),
db: &changeTrackerDB{
db: db,
publisher: stream.NewEventPublisher(ctx, newSnapshotHandlers(), 10*time.Second),
processChanges: processDBChanges,
},
schema: schema,
abandonCh: make(chan struct{}),
kvsGraveyard: NewGraveyard(gc),
lockDelay: NewDelay(),
stopEventPublisher: cancel,
}
s.db = &changeTrackerDB{
db: db,
publisher: stream.NewEventPublisher(ctx, newSnapshotHandlers(s), 10*time.Second),
processChanges: processDBChanges,
}
return s, nil
}

View File

@ -376,7 +376,7 @@ var topicService stream.Topic = topic("test-topic-service")
func newTestSnapshotHandlers(s *Store) stream.SnapshotHandlers {
return stream.SnapshotHandlers{
topicService: func(req *stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) {
topicService: func(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) {
idx, nodes, err := s.ServiceNodes(nil, req.Key, nil)
if err != nil {
return idx, err

258
agent/consul/state/usage.go Normal file
View File

@ -0,0 +1,258 @@
package state
import (
"fmt"
"github.com/hashicorp/consul/agent/structs"
memdb "github.com/hashicorp/go-memdb"
)
const (
serviceNamesUsageTable = "service-names"
)
// usageTableSchema returns a new table schema used for tracking usage
// counts within the state store, along with their last-seen Raft index.
func usageTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: "usage",
Indexes: map[string]*memdb.IndexSchema{
"id": {
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
Lowercase: true,
},
},
},
}
}
func init() {
registerSchema(usageTableSchema)
}
// UsageEntry represents a count of some arbitrary identifier within the
// state store, along with the last seen index.
type UsageEntry struct {
ID string
Index uint64
Count int
}
// ServiceUsage contains all of the usage data related to services
type ServiceUsage struct {
Services int
ServiceInstances int
EnterpriseServiceUsage
}
type uniqueServiceState int
const (
NoChange uniqueServiceState = 0
Deleted uniqueServiceState = 1
Created uniqueServiceState = 2
)
// updateUsage takes a set of memdb changes and computes a delta for specific
// usage metrics that we track.
func updateUsage(tx WriteTxn, changes Changes) error {
usageDeltas := make(map[string]int)
for _, change := range changes.Changes {
var delta int
if change.Created() {
delta = 1
} else if change.Deleted() {
delta = -1
}
switch change.Table {
case "nodes":
usageDeltas[change.Table] += delta
case "services":
svc := changeObject(change).(*structs.ServiceNode)
usageDeltas[change.Table] += delta
serviceIter, err := getWithTxn(tx, servicesTableName, "service", svc.ServiceName, &svc.EnterpriseMeta)
if err != nil {
return err
}
var serviceState uniqueServiceState
if serviceIter.Next() == nil {
// If no services exist, we know we deleted the last service
// instance.
serviceState = Deleted
usageDeltas[serviceNamesUsageTable] -= 1
} else if serviceIter.Next() == nil {
// If a second call to Next() returns nil, we know only a single
// instance exists. If, in addition, a new service name has been
// registered, either via creating a new service instance or via
// renaming an existing service, then we update our service count.
//
// We only care about two cases here:
// 1. A new service instance has been created with a unique name
// 2. An existing service instance has been updated with a new unique name
//
// These are the only ways a new unique service can be created. The other
// valid cases (an update that does not change the service name, and a
// deletion) do not impact the count of unique service names in the system.
if change.Created() {
// Given a single existing service instance of the service: If a
// service has just been created, then we know this is a new unique
// service.
serviceState = Created
usageDeltas[serviceNamesUsageTable] += 1
} else if serviceNameChanged(change) {
// Given a single existing service instance of the service: If a
// service has been updated with a new service name, then we know
// this is a new unique service.
serviceState = Created
usageDeltas[serviceNamesUsageTable] += 1
// Check whether the previous name was deleted in this rename; this
// is a special case of renaming a service which does not result in
// changing the count of unique service names.
before := change.Before.(*structs.ServiceNode)
beforeSvc, err := firstWithTxn(tx, servicesTableName, "service", before.ServiceName, &before.EnterpriseMeta)
if err != nil {
return err
}
if beforeSvc == nil {
usageDeltas[serviceNamesUsageTable] -= 1
// set serviceState to NoChange since we have both gained and lost a
// service, cancelling each other out
serviceState = NoChange
}
}
}
addEnterpriseServiceUsage(usageDeltas, change, serviceState)
}
}
idx := changes.Index
// This will happen when restoring from a snapshot; just take the max index
// of the tables we are tracking.
if idx == 0 {
idx = maxIndexTxn(tx, "nodes", servicesTableName)
}
return writeUsageDeltas(tx, idx, usageDeltas)
}
// serviceNameChanged returns a boolean that indicates whether the
// provided change resulted in an update to the service's service name.
func serviceNameChanged(change memdb.Change) bool {
if change.Updated() {
before := change.Before.(*structs.ServiceNode)
after := change.After.(*structs.ServiceNode)
return before.ServiceName != after.ServiceName
}
return false
}
// writeUsageDeltas will take in a map of IDs to deltas and update each
// entry accordingly, checking for integer underflow. The index that is
// passed in will be recorded on the entry as well.
func writeUsageDeltas(tx WriteTxn, idx uint64, usageDeltas map[string]int) error {
for id, delta := range usageDeltas {
u, err := tx.First("usage", "id", id)
if err != nil {
return fmt.Errorf("failed to retrieve existing usage entry: %s", err)
}
if u == nil {
if delta < 0 {
return fmt.Errorf("failed to insert usage entry for %q: delta will cause a negative count", id)
}
err := tx.Insert("usage", &UsageEntry{
ID: id,
Count: delta,
Index: idx,
})
if err != nil {
return fmt.Errorf("failed to update usage entry: %s", err)
}
} else if cur, ok := u.(*UsageEntry); ok {
if cur.Count+delta < 0 {
return fmt.Errorf("failed to insert usage entry for %q: delta will cause a negative count", id)
}
err := tx.Insert("usage", &UsageEntry{
ID: id,
Count: cur.Count + delta,
Index: idx,
})
if err != nil {
return fmt.Errorf("failed to update usage entry: %s", err)
}
}
}
return nil
}
// NodeCount returns the latest seen Raft index, a count of the number of nodes
// registered, and any errors.
func (s *Store) NodeCount() (uint64, int, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
nodeUsage, err := firstUsageEntry(tx, "nodes")
if err != nil {
return 0, 0, fmt.Errorf("failed nodes lookup: %s", err)
}
return nodeUsage.Index, nodeUsage.Count, nil
}
// ServiceUsage returns the latest seen Raft index, a compiled set of service
// usage data, and any errors.
func (s *Store) ServiceUsage() (uint64, ServiceUsage, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
serviceInstances, err := firstUsageEntry(tx, servicesTableName)
if err != nil {
return 0, ServiceUsage{}, fmt.Errorf("failed services lookup: %s", err)
}
services, err := firstUsageEntry(tx, serviceNamesUsageTable)
if err != nil {
return 0, ServiceUsage{}, fmt.Errorf("failed services lookup: %s", err)
}
usage := ServiceUsage{
ServiceInstances: serviceInstances.Count,
Services: services.Count,
}
results, err := compileEnterpriseUsage(tx, usage)
if err != nil {
return 0, ServiceUsage{}, fmt.Errorf("failed services lookup: %s", err)
}
return serviceInstances.Index, results, nil
}
func firstUsageEntry(tx ReadTxn, id string) (*UsageEntry, error) {
usage, err := tx.First("usage", "id", id)
if err != nil {
return nil, err
}
// If no elements have been inserted, the usage entry will not exist. We
// return a valid value so that we can be certain the return value is not nil
// when no error has occurred.
if usage == nil {
return &UsageEntry{ID: id, Count: 0}, nil
}
realUsage, ok := usage.(*UsageEntry)
if !ok {
return nil, fmt.Errorf("failed usage lookup: type %T is not *UsageEntry", usage)
}
return realUsage, nil
}

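Read access goes through the two accessors above; a usage sketch, assuming "fmt" is imported:

```go
idx, nodes, err := s.NodeCount()
if err != nil {
	return err
}
fmt.Printf("index=%d nodes=%d\n", idx, nodes)

idx, usage, err := s.ServiceUsage()
if err != nil {
	return err
}
fmt.Printf("index=%d services=%d instances=%d\n",
	idx, usage.Services, usage.ServiceInstances)
```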
View File

@ -0,0 +1,15 @@
// +build !consulent
package state
import (
memdb "github.com/hashicorp/go-memdb"
)
type EnterpriseServiceUsage struct{}
func addEnterpriseServiceUsage(map[string]int, memdb.Change, uniqueServiceState) {}
func compileEnterpriseUsage(tx ReadTxn, usage ServiceUsage) (ServiceUsage, error) {
return usage, nil
}

View File

@ -0,0 +1,25 @@
// +build !consulent
package state
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestStateStore_Usage_ServiceUsage(t *testing.T) {
s := testStateStore(t)
testRegisterNode(t, s, 0, "node1")
testRegisterNode(t, s, 1, "node2")
testRegisterService(t, s, 8, "node1", "service1")
testRegisterService(t, s, 9, "node2", "service1")
testRegisterService(t, s, 10, "node2", "service2")
idx, usage, err := s.ServiceUsage()
require.NoError(t, err)
require.Equal(t, idx, uint64(10))
require.Equal(t, 2, usage.Services)
require.Equal(t, 3, usage.ServiceInstances)
}

View File

@ -0,0 +1,194 @@
package state
import (
"testing"
"github.com/hashicorp/consul/agent/structs"
memdb "github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/require"
)
func TestStateStore_Usage_NodeCount(t *testing.T) {
s := testStateStore(t)
// No nodes have been registered, and thus no usage entry exists
idx, count, err := s.NodeCount()
require.NoError(t, err)
require.Equal(t, idx, uint64(0))
require.Equal(t, count, 0)
testRegisterNode(t, s, 0, "node1")
testRegisterNode(t, s, 1, "node2")
idx, count, err = s.NodeCount()
require.NoError(t, err)
require.Equal(t, idx, uint64(1))
require.Equal(t, count, 2)
}
func TestStateStore_Usage_NodeCount_Delete(t *testing.T) {
s := testStateStore(t)
testRegisterNode(t, s, 0, "node1")
testRegisterNode(t, s, 1, "node2")
idx, count, err := s.NodeCount()
require.NoError(t, err)
require.Equal(t, idx, uint64(1))
require.Equal(t, count, 2)
require.NoError(t, s.DeleteNode(2, "node2"))
idx, count, err = s.NodeCount()
require.NoError(t, err)
require.Equal(t, idx, uint64(2))
require.Equal(t, count, 1)
}
func TestStateStore_Usage_ServiceUsageEmpty(t *testing.T) {
s := testStateStore(t)
// No services have been registered, and thus no usage entry exists
idx, usage, err := s.ServiceUsage()
require.NoError(t, err)
require.Equal(t, idx, uint64(0))
require.Equal(t, usage.Services, 0)
require.Equal(t, usage.ServiceInstances, 0)
}
func TestStateStore_Usage_Restore(t *testing.T) {
s := testStateStore(t)
restore := s.Restore()
restore.Registration(9, &structs.RegisterRequest{
Node: "test-node",
Service: &structs.NodeService{
ID: "mysql",
Service: "mysql",
Port: 8080,
Address: "198.18.0.2",
},
})
require.NoError(t, restore.Commit())
idx, count, err := s.NodeCount()
require.NoError(t, err)
require.Equal(t, idx, uint64(9))
require.Equal(t, count, 1)
}
func TestStateStore_Usage_updateUsage_Underflow(t *testing.T) {
s := testStateStore(t)
txn := s.db.WriteTxn(1)
// A single delete change will cause a negative count
changes := Changes{
Index: 1,
Changes: memdb.Changes{
{
Table: "nodes",
Before: &structs.Node{},
After: nil,
},
},
}
err := updateUsage(txn, changes)
require.Error(t, err)
require.Contains(t, err.Error(), "negative count")
// Insert a change to create a usage entry
changes = Changes{
Index: 1,
Changes: memdb.Changes{
{
Table: "nodes",
Before: nil,
After: &structs.Node{},
},
},
}
err = updateUsage(txn, changes)
require.NoError(t, err)
// Two deletes will cause a negative count now
changes = Changes{
Index: 1,
Changes: memdb.Changes{
{
Table: "nodes",
Before: &structs.Node{},
After: nil,
},
{
Table: "nodes",
Before: &structs.Node{},
After: nil,
},
},
}
err = updateUsage(txn, changes)
require.Error(t, err)
require.Contains(t, err.Error(), "negative count")
}
func TestStateStore_Usage_ServiceUsage_updatingServiceName(t *testing.T) {
s := testStateStore(t)
testRegisterNode(t, s, 1, "node1")
testRegisterService(t, s, 1, "node1", "service1")
t.Run("rename service with a single instance", func(t *testing.T) {
svc := &structs.NodeService{
ID: "service1",
Service: "after",
Address: "1.1.1.1",
Port: 1111,
}
require.NoError(t, s.EnsureService(2, "node1", svc))
// We renamed a service with a single instance, so we maintain 1 service.
idx, usage, err := s.ServiceUsage()
require.NoError(t, err)
require.Equal(t, idx, uint64(2))
require.Equal(t, usage.Services, 1)
require.Equal(t, usage.ServiceInstances, 1)
})
t.Run("rename service with a multiple instances", func(t *testing.T) {
svc2 := &structs.NodeService{
ID: "service2",
Service: "before",
Address: "1.1.1.2",
Port: 1111,
}
require.NoError(t, s.EnsureService(3, "node1", svc2))
svc3 := &structs.NodeService{
ID: "service3",
Service: "before",
Address: "1.1.1.3",
Port: 1111,
}
require.NoError(t, s.EnsureService(4, "node1", svc3))
idx, usage, err := s.ServiceUsage()
require.NoError(t, err)
require.Equal(t, idx, uint64(4))
require.Equal(t, usage.Services, 2)
require.Equal(t, usage.ServiceInstances, 3)
update := &structs.NodeService{
ID: "service2",
Service: "another-name",
Address: "1.1.1.2",
Port: 1111,
}
require.NoError(t, s.EnsureService(5, "node1", update))
idx, usage, err = s.ServiceUsage()
require.NoError(t, err)
require.Equal(t, idx, uint64(5))
require.Equal(t, usage.Services, 3)
require.Equal(t, usage.ServiceInstances, 3)
})
}

View File

@ -61,7 +61,11 @@ type changeEvents struct {
// SnapshotHandlers is a mapping of Topic to a function which produces a snapshot
// of events for the SubscribeRequest. Events are appended to the snapshot using SnapshotAppender.
// The nil Topic is reserved and should not be used.
type SnapshotHandlers map[Topic]func(*SubscribeRequest, SnapshotAppender) (index uint64, err error)
type SnapshotHandlers map[Topic]SnapshotFunc
// SnapshotFunc builds a snapshot for the subscription request, and appends the
// events to the Snapshot using SnapshotAppender.
type SnapshotFunc func(SubscribeRequest, SnapshotAppender) (index uint64, err error)
// SnapshotAppender appends groups of events to create a Snapshot of state.
type SnapshotAppender interface {

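A SnapshotFunc replays current state as events and returns the index it observed; serviceHealthSnapshot earlier in this diff is the real example. A minimal hypothetical one (loadCurrentEvents is an assumed helper):

```go
func exampleSnapshot(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
	events, idx, err := loadCurrentEvents(req.Topic, req.Key) // assumed helper
	if err != nil {
		return 0, err
	}
	for _, e := range events {
		// Append one event per item so each can be serialized separately.
		buf.Append([]Event{e})
	}
	return idx, nil
}
```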
View File

@ -58,7 +58,7 @@ func TestEventPublisher_PublishChangesAndSubscribe_WithSnapshot(t *testing.T) {
func newTestSnapshotHandlers() SnapshotHandlers {
return SnapshotHandlers{
testTopic: func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error) {
testTopic: func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
if req.Topic != testTopic {
return 0, fmt.Errorf("unexpected topic: %v", req.Topic)
}
@ -117,7 +117,7 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) {
t.Cleanup(cancel)
handlers := newTestSnapshotHandlers()
fn := func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error) {
fn := func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
return 0, nil
}
handlers[intTopic(22)] = fn

View File

@ -18,8 +18,6 @@ type eventSnapshot struct {
snapBuffer *eventBuffer
}
type snapFunc func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error)
// newEventSnapshot creates a snapshot buffer based on the subscription request.
// The current buffer head for the topic requested is passed so that once the
// snapshot is complete and has been delivered into the buffer, any events
@ -27,7 +25,7 @@ type snapFunc func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error)
// missed. Once the snapshot is delivered the topic buffer is spliced onto the
// snapshot buffer so that subscribers will naturally follow from the snapshot
// to wait for any subsequent updates.
func newEventSnapshot(req *SubscribeRequest, topicBufferHead *bufferItem, fn snapFunc) *eventSnapshot {
func newEventSnapshot(req *SubscribeRequest, topicBufferHead *bufferItem, fn SnapshotFunc) *eventSnapshot {
buf := newEventBuffer()
s := &eventSnapshot{
Head: buf.Head(),
@ -35,7 +33,7 @@ func newEventSnapshot(req *SubscribeRequest, topicBufferHead *bufferItem, fn sna
}
go func() {
idx, err := fn(req, s.snapBuffer)
idx, err := fn(*req, s.snapBuffer)
if err != nil {
s.snapBuffer.AppendItem(&bufferItem{Err: err})
return

View File

@ -161,8 +161,8 @@ func genSequentialIDs(start, end int) []string {
return ids
}
func testHealthConsecutiveSnapshotFn(size int, index uint64) snapFunc {
return func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error) {
func testHealthConsecutiveSnapshotFn(size int, index uint64) SnapshotFunc {
return func(req SubscribeRequest, buf SnapshotAppender) (uint64, error) {
for i := 0; i < size; i++ {
// Event content is arbitrary we are just using Health because it's the
// first type defined. We just want a set of things with consecutive

View File

@ -0,0 +1,135 @@
package usagemetrics
import (
"context"
"errors"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/logging"
"github.com/hashicorp/go-hclog"
)
// Config holds the settings for various parameters for the
// UsageMetricsReporter
type Config struct {
logger hclog.Logger
metricLabels []metrics.Label
stateProvider StateProvider
tickerInterval time.Duration
}
// WithDatacenter adds the datacenter as a label to all metrics emitted by the
// UsageMetricsReporter
func (c *Config) WithDatacenter(dc string) *Config {
c.metricLabels = append(c.metricLabels, metrics.Label{Name: "datacenter", Value: dc})
return c
}
// WithLogger takes a logger and creates a new, named sub-logger to use when
// running
func (c *Config) WithLogger(logger hclog.Logger) *Config {
c.logger = logger.Named(logging.UsageMetrics)
return c
}
// WithReportingInterval specifies the interval on which UsageMetricsReporter
// should emit metrics
func (c *Config) WithReportingInterval(dur time.Duration) *Config {
c.tickerInterval = dur
return c
}
func (c *Config) WithStateProvider(sp StateProvider) *Config {
c.stateProvider = sp
return c
}
// StateProvider defines an interface for retrieving a state.Store handle. In
// non-test code, this is satisfied by the fsm.FSM struct.
type StateProvider interface {
State() *state.Store
}
// UsageMetricsReporter provides functionality for emitting usage metrics into
// the metrics stream. This makes it essentially a translation layer
// between the state store and metrics stream.
type UsageMetricsReporter struct {
logger hclog.Logger
metricLabels []metrics.Label
stateProvider StateProvider
tickerInterval time.Duration
}
func NewUsageMetricsReporter(cfg *Config) (*UsageMetricsReporter, error) {
if cfg.stateProvider == nil {
return nil, errors.New("must provide a StateProvider to usage reporter")
}
if cfg.logger == nil {
cfg.logger = hclog.NewNullLogger()
}
if cfg.tickerInterval == 0 {
// Metrics are aggregated every 10 seconds, so we default to that.
cfg.tickerInterval = 10 * time.Second
}
u := &UsageMetricsReporter{
logger: cfg.logger,
stateProvider: cfg.stateProvider,
metricLabels: cfg.metricLabels,
tickerInterval: cfg.tickerInterval,
}
return u, nil
}
// Run must be run in a goroutine, and can be stopped by cancelling the
// context passed to it.
func (u *UsageMetricsReporter) Run(ctx context.Context) {
ticker := time.NewTicker(u.tickerInterval)
for {
select {
case <-ctx.Done():
u.logger.Debug("usage metrics reporter shutting down")
ticker.Stop()
return
case <-ticker.C:
u.runOnce()
}
}
}
func (u *UsageMetricsReporter) runOnce() {
state := u.stateProvider.State()
_, nodes, err := state.NodeCount()
if err != nil {
u.logger.Warn("failed to retrieve nodes from state store", "error", err)
}
metrics.SetGaugeWithLabels(
[]string{"consul", "state", "nodes"},
float32(nodes),
u.metricLabels,
)
_, serviceUsage, err := state.ServiceUsage()
if err != nil {
u.logger.Warn("failed to retrieve services from state store", "error", err)
}
metrics.SetGaugeWithLabels(
[]string{"consul", "state", "services"},
float32(serviceUsage.Services),
u.metricLabels,
)
metrics.SetGaugeWithLabels(
[]string{"consul", "state", "service_instances"},
float32(serviceUsage.ServiceInstances),
u.metricLabels,
)
u.emitEnterpriseUsage(serviceUsage)
}

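Wiring the reporter up mirrors what server.go does earlier in this diff; roughly:

```go
// fsm satisfies StateProvider by exposing State() *state.Store.
reporter, err := usagemetrics.NewUsageMetricsReporter(
	new(usagemetrics.Config).
		WithStateProvider(fsm).
		WithLogger(logger).
		WithDatacenter("dc1").
		WithReportingInterval(10 * time.Second),
)
if err != nil {
	return err
}
go reporter.Run(ctx) // returns once ctx is cancelled
```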
View File

@ -0,0 +1,7 @@
// +build !consulent
package usagemetrics
import "github.com/hashicorp/consul/agent/consul/state"
func (u *UsageMetricsReporter) emitEnterpriseUsage(state.ServiceUsage) {}

View File

@ -0,0 +1,9 @@
// +build !consulent
package usagemetrics
import "github.com/hashicorp/consul/agent/consul/state"
func newStateStore() (*state.Store, error) {
return state.NewStateStore(nil)
}

View File

@ -0,0 +1,128 @@
package usagemetrics
import (
"testing"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockStateProvider struct {
mock.Mock
}
func (m *mockStateProvider) State() *state.Store {
retValues := m.Called()
return retValues.Get(0).(*state.Store)
}
func TestUsageReporter_Run(t *testing.T) {
type testCase struct {
modifyStateStore func(t *testing.T, s *state.Store)
expectedGauges map[string]metrics.GaugeValue
}
cases := map[string]testCase{
"empty-state": {
expectedGauges: map[string]metrics.GaugeValue{
"consul.usage.test.consul.state.nodes;datacenter=dc1": {
Name: "consul.usage.test.consul.state.nodes",
Value: 0,
Labels: []metrics.Label{{Name: "datacenter", Value: "dc1"}},
},
"consul.usage.test.consul.state.services;datacenter=dc1": {
Name: "consul.usage.test.consul.state.services",
Value: 0,
Labels: []metrics.Label{
{Name: "datacenter", Value: "dc1"},
},
},
"consul.usage.test.consul.state.service_instances;datacenter=dc1": {
Name: "consul.usage.test.consul.state.service_instances",
Value: 0,
Labels: []metrics.Label{
{Name: "datacenter", Value: "dc1"},
},
},
},
},
"nodes-and-services": {
modifyStateStore: func(t *testing.T, s *state.Store) {
require.Nil(t, s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
require.Nil(t, s.EnsureNode(2, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
require.Nil(t, s.EnsureNode(3, &structs.Node{Node: "baz", Address: "127.0.0.2"}))
// Typical services and some consul services spread across two nodes
require.Nil(t, s.EnsureService(4, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
require.Nil(t, s.EnsureService(5, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
require.Nil(t, s.EnsureService(6, "foo", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil}))
require.Nil(t, s.EnsureService(7, "bar", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil}))
},
expectedGauges: map[string]metrics.GaugeValue{
"consul.usage.test.consul.state.nodes;datacenter=dc1": {
Name: "consul.usage.test.consul.state.nodes",
Value: 3,
Labels: []metrics.Label{{Name: "datacenter", Value: "dc1"}},
},
"consul.usage.test.consul.state.services;datacenter=dc1": {
Name: "consul.usage.test.consul.state.services",
Value: 3,
Labels: []metrics.Label{
{Name: "datacenter", Value: "dc1"},
},
},
"consul.usage.test.consul.state.service_instances;datacenter=dc1": {
Name: "consul.usage.test.consul.state.service_instances",
Value: 4,
Labels: []metrics.Label{
{Name: "datacenter", Value: "dc1"},
},
},
},
},
}
for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
// Only have a single interval for the test
sink := metrics.NewInmemSink(1*time.Minute, 1*time.Minute)
cfg := metrics.DefaultConfig("consul.usage.test")
cfg.EnableHostname = false
metrics.NewGlobal(cfg, sink)
mockStateProvider := &mockStateProvider{}
s, err := newStateStore()
require.NoError(t, err)
if tcase.modifyStateStore != nil {
tcase.modifyStateStore(t, s)
}
mockStateProvider.On("State").Return(s)
reporter, err := NewUsageMetricsReporter(
new(Config).
WithStateProvider(mockStateProvider).
WithLogger(testutil.Logger(t)).
WithDatacenter("dc1"),
)
require.NoError(t, err)
reporter.runOnce()
intervals := sink.Data()
require.Len(t, intervals, 1)
intv := intervals[0]
// Range over the expected values instead of just doing an Equal
// comparison on the maps because of different metrics emitted between
// OSS and Ent. The enterprise tests have a full equality comparison on
// the maps.
for key, expected := range tcase.expectedGauges {
require.Equal(t, expected, intv.Gauges[key])
}
})
}
}

View File

@ -156,7 +156,7 @@ func (c *Client) CheckServers(datacenter string, fn func(*metadata.Server) bool)
return
}
c.routers.CheckServers(fn)
c.router.CheckServers(datacenter, fn)
}
type serversACLMode struct {

109
agent/grpc/handler.go Normal file
View File

@@ -0,0 +1,109 @@
/*
Package grpc provides a Handler and client for agent gRPC connections.
*/
package grpc
import (
"fmt"
"net"
"google.golang.org/grpc"
)
// NewHandler returns a gRPC server that accepts connections from Handle(conn).
func NewHandler(addr net.Addr) *Handler {
// We don't need to pass tls.Config to the server since it's multiplexed
// behind the RPC listener, which already has TLS configured.
srv := grpc.NewServer(
grpc.StatsHandler(newStatsHandler()),
grpc.StreamInterceptor((&activeStreamCounter{}).Intercept),
)
// TODO(streaming): add gRPC services to srv here
return &Handler{
srv: srv,
listener: &chanListener{addr: addr, conns: make(chan net.Conn)},
}
}
// Handler implements a handler for the rpc server listener, and the
// agent.Component interface for managing the lifecycle of the grpc.Server.
type Handler struct {
srv *grpc.Server
listener *chanListener
}
// Handle the connection by sending it to a channel for the grpc.Server to receive.
func (h *Handler) Handle(conn net.Conn) {
h.listener.conns <- conn
}
func (h *Handler) Run() error {
return h.srv.Serve(h.listener)
}
func (h *Handler) Shutdown() error {
h.srv.Stop()
return nil
}
// chanListener implements net.Listener for grpc.Server.
type chanListener struct {
conns chan net.Conn
addr net.Addr
}
// Accept blocks until a connection is received from Handle, and then returns the
// connection. Accept implements part of the net.Listener interface for grpc.Server.
func (l *chanListener) Accept() (net.Conn, error) {
return <-l.conns, nil
}
func (l *chanListener) Addr() net.Addr {
return l.addr
}
// Close does nothing. The connections are managed by the caller.
func (l *chanListener) Close() error {
return nil
}
// NoOpHandler implements the same methods as Handler, but performs no handling.
// It may be used in place of Handler to disable the grpc server.
type NoOpHandler struct {
Logger Logger
}
type Logger interface {
Error(string, ...interface{})
}
func (h NoOpHandler) Handle(conn net.Conn) {
h.Logger.Error("gRPC conn opened but gRPC RPC is disabled, closing",
"conn", logConn(conn))
_ = conn.Close()
}
func (h NoOpHandler) Run() error {
return nil
}
func (h NoOpHandler) Shutdown() error {
return nil
}
// logConn is a local copy of github.com/hashicorp/memberlist.LogConn, to avoid
// a large dependency for a minor formatting function.
// logConn is used to keep log formatting consistent.
func logConn(conn net.Conn) string {
if conn == nil {
return "from=<unknown address>"
}
addr := conn.RemoteAddr()
if addr == nil {
return "from=<unknown address>"
}
return fmt.Sprintf("from=%s", addr.String())
}
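
To make the lifecycle concrete, here is a minimal sketch (the `serveGRPC` helper and the `conns` channel are hypothetical, not part of this diff) of how an agent could drive the handler: run the server once, feed it multiplexed connections, then shut down.

```go
package grpc

import "net"

// serveGRPC is a hypothetical illustration of the Handler lifecycle:
// Run blocks in grpc.Server.Serve on the chanListener, and each conn
// passed to Handle is returned from the listener's Accept.
func serveGRPC(addr net.Addr, conns <-chan net.Conn) error {
	h := NewHandler(addr)

	errCh := make(chan error, 1)
	go func() { errCh <- h.Run() }()

	for conn := range conns {
		h.Handle(conn) // hand the mux'd conn to the grpc.Server
	}

	_ = h.Shutdown() // Stop unblocks Run
	return <-errCh
}
```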

View File

@@ -0,0 +1,28 @@
// Code generated by protoc-gen-go-binary. DO NOT EDIT.
// source: agent/grpc/internal/testservice/simple.proto
package testservice
import (
"github.com/golang/protobuf/proto"
)
// MarshalBinary implements encoding.BinaryMarshaler
func (msg *Req) MarshalBinary() ([]byte, error) {
return proto.Marshal(msg)
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (msg *Req) UnmarshalBinary(b []byte) error {
return proto.Unmarshal(b, msg)
}
// MarshalBinary implements encoding.BinaryMarshaler
func (msg *Resp) MarshalBinary() ([]byte, error) {
return proto.Marshal(msg)
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (msg *Resp) UnmarshalBinary(b []byte) error {
return proto.Unmarshal(b, msg)
}
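
These methods satisfy the standard library's `encoding.BinaryMarshaler`/`encoding.BinaryUnmarshaler` interfaces. A minimal sketch (hypothetical `main`) of the round trip they enable:

```go
package main

import (
	"fmt"

	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
)

func main() {
	in := &testservice.Req{Datacenter: "dc1"}

	b, err := in.MarshalBinary() // proto-encodes the message
	if err != nil {
		panic(err)
	}

	out := &testservice.Req{}
	if err := out.UnmarshalBinary(b); err != nil {
		panic(err)
	}
	fmt.Println(out.Datacenter) // dc1
}
```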

View File

@@ -0,0 +1,742 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: agent/grpc/internal/testservice/simple.proto
package testservice
import (
context "context"
fmt "fmt"
proto "github.com/golang/protobuf/proto"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Req struct {
Datacenter string `protobuf:"bytes,1,opt,name=Datacenter,proto3" json:"Datacenter,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Req) Reset() { *m = Req{} }
func (m *Req) String() string { return proto.CompactTextString(m) }
func (*Req) ProtoMessage() {}
func (*Req) Descriptor() ([]byte, []int) {
return fileDescriptor_3009a77c573f826d, []int{0}
}
func (m *Req) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Req) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Req.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Req) XXX_Merge(src proto.Message) {
xxx_messageInfo_Req.Merge(m, src)
}
func (m *Req) XXX_Size() int {
return m.Size()
}
func (m *Req) XXX_DiscardUnknown() {
xxx_messageInfo_Req.DiscardUnknown(m)
}
var xxx_messageInfo_Req proto.InternalMessageInfo
func (m *Req) GetDatacenter() string {
if m != nil {
return m.Datacenter
}
return ""
}
type Resp struct {
ServerName string `protobuf:"bytes,1,opt,name=ServerName,proto3" json:"ServerName,omitempty"`
Datacenter string `protobuf:"bytes,2,opt,name=Datacenter,proto3" json:"Datacenter,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Resp) Reset() { *m = Resp{} }
func (m *Resp) String() string { return proto.CompactTextString(m) }
func (*Resp) ProtoMessage() {}
func (*Resp) Descriptor() ([]byte, []int) {
return fileDescriptor_3009a77c573f826d, []int{1}
}
func (m *Resp) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Resp) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Resp.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Resp) XXX_Merge(src proto.Message) {
xxx_messageInfo_Resp.Merge(m, src)
}
func (m *Resp) XXX_Size() int {
return m.Size()
}
func (m *Resp) XXX_DiscardUnknown() {
xxx_messageInfo_Resp.DiscardUnknown(m)
}
var xxx_messageInfo_Resp proto.InternalMessageInfo
func (m *Resp) GetServerName() string {
if m != nil {
return m.ServerName
}
return ""
}
func (m *Resp) GetDatacenter() string {
if m != nil {
return m.Datacenter
}
return ""
}
func init() {
proto.RegisterType((*Req)(nil), "testservice.Req")
proto.RegisterType((*Resp)(nil), "testservice.Resp")
}
func init() {
proto.RegisterFile("agent/grpc/internal/testservice/simple.proto", fileDescriptor_3009a77c573f826d)
}
var fileDescriptor_3009a77c573f826d = []byte{
// 206 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x49, 0x4c, 0x4f, 0xcd,
0x2b, 0xd1, 0x4f, 0x2f, 0x2a, 0x48, 0xd6, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1,
0x2f, 0x49, 0x2d, 0x2e, 0x29, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0xd5, 0x2f, 0xce, 0xcc, 0x2d,
0xc8, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x46, 0x92, 0x51, 0x52, 0xe5, 0x62,
0x0e, 0x4a, 0x2d, 0x14, 0x92, 0xe3, 0xe2, 0x72, 0x49, 0x2c, 0x49, 0x4c, 0x4e, 0x05, 0xe9, 0x96,
0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x42, 0x12, 0x51, 0x72, 0xe3, 0x62, 0x09, 0x4a, 0x2d, 0x2e,
0x00, 0xa9, 0x0b, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xf2, 0x4b, 0xcc, 0x4d, 0x85, 0xa9, 0x43, 0x88,
0xa0, 0x99, 0xc3, 0x84, 0x6e, 0x8e, 0x51, 0x2e, 0x17, 0x5b, 0x30, 0xd8, 0x2d, 0x42, 0x46, 0x5c,
0x9c, 0xc1, 0xf9, 0xb9, 0xa9, 0x25, 0x19, 0x99, 0x79, 0xe9, 0x42, 0x02, 0x7a, 0x48, 0x6e, 0xd2,
0x0b, 0x4a, 0x2d, 0x94, 0x12, 0x44, 0x13, 0x29, 0x2e, 0x50, 0x62, 0x10, 0xd2, 0xe7, 0x62, 0x71,
0xcb, 0xc9, 0x2f, 0x27, 0x52, 0xb9, 0x01, 0xa3, 0x93, 0xc0, 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e,
0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x38, 0xe3, 0xb1, 0x1c, 0x43, 0x12, 0x1b, 0x38, 0x0c, 0x8c,
0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0xe7, 0x4b, 0x16, 0x40, 0x33, 0x01, 0x00, 0x00,
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// SimpleClient is the client API for Simple service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type SimpleClient interface {
Something(ctx context.Context, in *Req, opts ...grpc.CallOption) (*Resp, error)
Flow(ctx context.Context, in *Req, opts ...grpc.CallOption) (Simple_FlowClient, error)
}
type simpleClient struct {
cc *grpc.ClientConn
}
func NewSimpleClient(cc *grpc.ClientConn) SimpleClient {
return &simpleClient{cc}
}
func (c *simpleClient) Something(ctx context.Context, in *Req, opts ...grpc.CallOption) (*Resp, error) {
out := new(Resp)
err := c.cc.Invoke(ctx, "/testservice.Simple/Something", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *simpleClient) Flow(ctx context.Context, in *Req, opts ...grpc.CallOption) (Simple_FlowClient, error) {
stream, err := c.cc.NewStream(ctx, &_Simple_serviceDesc.Streams[0], "/testservice.Simple/Flow", opts...)
if err != nil {
return nil, err
}
x := &simpleFlowClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type Simple_FlowClient interface {
Recv() (*Resp, error)
grpc.ClientStream
}
type simpleFlowClient struct {
grpc.ClientStream
}
func (x *simpleFlowClient) Recv() (*Resp, error) {
m := new(Resp)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SimpleServer is the server API for Simple service.
type SimpleServer interface {
Something(context.Context, *Req) (*Resp, error)
Flow(*Req, Simple_FlowServer) error
}
// UnimplementedSimpleServer can be embedded to have forward compatible implementations.
type UnimplementedSimpleServer struct {
}
func (*UnimplementedSimpleServer) Something(ctx context.Context, req *Req) (*Resp, error) {
return nil, status.Errorf(codes.Unimplemented, "method Something not implemented")
}
func (*UnimplementedSimpleServer) Flow(req *Req, srv Simple_FlowServer) error {
return status.Errorf(codes.Unimplemented, "method Flow not implemented")
}
func RegisterSimpleServer(s *grpc.Server, srv SimpleServer) {
s.RegisterService(&_Simple_serviceDesc, srv)
}
func _Simple_Something_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Req)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SimpleServer).Something(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/testservice.Simple/Something",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SimpleServer).Something(ctx, req.(*Req))
}
return interceptor(ctx, in, info, handler)
}
func _Simple_Flow_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(Req)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(SimpleServer).Flow(m, &simpleFlowServer{stream})
}
type Simple_FlowServer interface {
Send(*Resp) error
grpc.ServerStream
}
type simpleFlowServer struct {
grpc.ServerStream
}
func (x *simpleFlowServer) Send(m *Resp) error {
return x.ServerStream.SendMsg(m)
}
var _Simple_serviceDesc = grpc.ServiceDesc{
ServiceName: "testservice.Simple",
HandlerType: (*SimpleServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Something",
Handler: _Simple_Something_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "Flow",
Handler: _Simple_Flow_Handler,
ServerStreams: true,
},
},
Metadata: "agent/grpc/internal/testservice/simple.proto",
}
func (m *Req) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Req) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Req) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.XXX_unrecognized != nil {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.Datacenter) > 0 {
i -= len(m.Datacenter)
copy(dAtA[i:], m.Datacenter)
i = encodeVarintSimple(dAtA, i, uint64(len(m.Datacenter)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *Resp) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Resp) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Resp) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.XXX_unrecognized != nil {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.Datacenter) > 0 {
i -= len(m.Datacenter)
copy(dAtA[i:], m.Datacenter)
i = encodeVarintSimple(dAtA, i, uint64(len(m.Datacenter)))
i--
dAtA[i] = 0x12
}
if len(m.ServerName) > 0 {
i -= len(m.ServerName)
copy(dAtA[i:], m.ServerName)
i = encodeVarintSimple(dAtA, i, uint64(len(m.ServerName)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func encodeVarintSimple(dAtA []byte, offset int, v uint64) int {
offset -= sovSimple(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *Req) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Datacenter)
if l > 0 {
n += 1 + l + sovSimple(uint64(l))
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
return n
}
func (m *Resp) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.ServerName)
if l > 0 {
n += 1 + l + sovSimple(uint64(l))
}
l = len(m.Datacenter)
if l > 0 {
n += 1 + l + sovSimple(uint64(l))
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
return n
}
func sovSimple(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozSimple(x uint64) (n int) {
return sovSimple(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Req) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowSimple
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Req: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Req: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Datacenter", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowSimple
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthSimple
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthSimple
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Datacenter = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipSimple(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthSimple
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthSimple
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...)
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *Resp) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowSimple
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Resp: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Resp: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field ServerName", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowSimple
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthSimple
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthSimple
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.ServerName = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Datacenter", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowSimple
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthSimple
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthSimple
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Datacenter = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipSimple(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthSimple
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthSimple
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...)
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipSimple(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowSimple
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowSimple
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowSimple
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthSimple
}
iNdEx += length
if iNdEx < 0 {
return 0, ErrInvalidLengthSimple
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowSimple
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipSimple(dAtA[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
if iNdEx < 0 {
return 0, ErrInvalidLengthSimple
}
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthSimple = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowSimple = fmt.Errorf("proto: integer overflow")
)
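
The `encodeVarintSimple`/`sovSimple` helpers above implement protobuf's base-128 varint encoding: seven payload bits per byte, low bits first, with the high bit set while more bytes follow. A standalone sketch (not generated code) checking the arithmetic for one value:

```go
package main

import (
	"fmt"
	"math/bits"
)

// sov mirrors sovSimple: the number of 7-bit groups needed for x.
func sov(x uint64) int { return (bits.Len64(x|1) + 6) / 7 }

func main() {
	v := uint64(300) // 0b1_0010_1100
	var buf []byte
	for v >= 1<<7 {
		buf = append(buf, byte(v&0x7f|0x80)) // low 7 bits, continuation bit set
		v >>= 7
	}
	buf = append(buf, byte(v))
	fmt.Printf("% x  len=%d  sov=%d\n", buf, len(buf), sov(300))
	// Output: ac 02  len=2  sov=2
}
```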

View File

@@ -0,0 +1,18 @@
syntax = "proto3";
package testservice;
// Simple service is used to test gRPC plumbing.
service Simple {
rpc Something(Req) returns (Resp) {}
rpc Flow(Req) returns (stream Resp) {}
}
message Req {
string Datacenter = 1;
}
message Resp {
string ServerName = 1;
string Datacenter = 2;
}
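
The generated `SimpleClient` exposes both RPCs: `Something` is a plain unary call, while `Flow` returns a `Simple_FlowClient` whose `Recv` yields streamed `Resp` messages. A minimal sketch of calling both (the dial target and the three-message read count are hypothetical):

```go
package main

import (
	"context"
	"fmt"
	"io"

	"google.golang.org/grpc"

	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
)

func main() {
	conn, err := grpc.Dial("127.0.0.1:8502", grpc.WithInsecure())
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	client := testservice.NewSimpleClient(conn)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Unary: one request, one response.
	resp, err := client.Something(ctx, &testservice.Req{Datacenter: "dc1"})
	if err != nil {
		panic(err)
	}
	fmt.Println("something:", resp.ServerName, resp.Datacenter)

	// Server streaming: read until done (here, just three messages).
	stream, err := client.Flow(ctx, &testservice.Req{Datacenter: "dc1"})
	if err != nil {
		panic(err)
	}
	for i := 0; i < 3; i++ {
		msg, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			panic(err)
		}
		fmt.Println("flow:", msg.ServerName, msg.Datacenter)
	}
}
```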

28
agent/grpc/server_test.go Normal file
View File

@@ -0,0 +1,28 @@
package grpc
import (
"context"
"time"
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
)
type simple struct {
name string
dc string
}
func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error {
for flow.Context().Err() == nil {
resp := &testservice.Resp{ServerName: s.name, Datacenter: s.dc}
if err := flow.Send(resp); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) {
return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil
}
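
A minimal sketch of exercising this implementation end to end (a hypothetical test; it serves on a plain `grpc.Server` rather than through the Handler plumbing above, and the server name is illustrative):

```go
package grpc

import (
	"context"
	"net"
	"testing"

	"google.golang.org/grpc"

	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
)

func TestSimpleSomething(t *testing.T) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	srv := grpc.NewServer()
	testservice.RegisterSimpleServer(srv, &simple{name: "server-1", dc: "dc1"})
	go srv.Serve(lis)
	defer srv.Stop()

	conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	resp, err := testservice.NewSimpleClient(conn).Something(context.Background(), &testservice.Req{})
	if err != nil {
		t.Fatal(err)
	}
	if resp.ServerName != "server-1" {
		t.Fatalf("got ServerName %q, want %q", resp.ServerName, "server-1")
	}
}
```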

Some files were not shown because too many files have changed in this diff