diff --git a/agent/cache-types/connect_ca.go b/agent/cache-types/connect_ca.go index 5b72a47a72..22549ed498 100644 --- a/agent/cache-types/connect_ca.go +++ b/agent/cache-types/connect_ca.go @@ -2,15 +2,27 @@ package cachetype import ( "fmt" + "sync" + "sync/atomic" + "time" "github.com/hashicorp/consul/agent/cache" + "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/structs" + + // NOTE(mitcehllh): This is temporary while certs are stubbed out. + "github.com/mitchellh/go-testing-interface" ) -// Recommended name for registration for ConnectCARoot -const ConnectCARootName = "connect-ca" +// Recommended name for registration. +const ( + ConnectCARootName = "connect-ca-root" + ConnectCALeafName = "connect-ca-leaf" +) -// ConnectCARoot supports fetching the Connect CA roots. +// ConnectCARoot supports fetching the Connect CA roots. This is a +// straightforward cache type since it only has to block on the given +// index and return the data. type ConnectCARoot struct { RPC RPC } @@ -39,3 +51,167 @@ func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache result.Index = reply.QueryMeta.Index return result, nil } + +// ConnectCALeaf supports fetching and generating Connect leaf +// certificates. +type ConnectCALeaf struct { + caIndex uint64 // Current index for CA roots + + issuedCertsLock sync.RWMutex + issuedCerts map[string]*structs.IssuedCert + + RPC RPC // RPC client for remote requests + Cache *cache.Cache // Cache that has CA root certs via ConnectCARoot +} + +func (c *ConnectCALeaf) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) { + var result cache.FetchResult + + // Get the correct type + reqReal, ok := req.(*ConnectCALeafRequest) + if !ok { + return result, fmt.Errorf( + "Internal cache failure: request wrong type: %T", req) + } + + // This channel watches our overall timeout. The other goroutines + // launched in this function should end all around the same time so + // they clean themselves up. + timeoutCh := time.After(opts.Timeout) + + // Kick off the goroutine that waits for new CA roots. The channel buffer + // is so that the goroutine doesn't block forever if we return for other + // reasons. + newRootCACh := make(chan error, 1) + go c.waitNewRootCA(newRootCACh, opts.Timeout) + + // Get our prior cert (if we had one) and use that to determine our + // expiration time. If no cert exists, we expire immediately since we + // need to generate. + c.issuedCertsLock.RLock() + lastCert := c.issuedCerts[reqReal.Service] + c.issuedCertsLock.RUnlock() + + var leafExpiryCh <-chan time.Time + if lastCert != nil { + // Determine how long we wait until triggering. If we've already + // expired, we trigger immediately. + if expiryDur := lastCert.ValidBefore.Sub(time.Now()); expiryDur > 0 { + leafExpiryCh = time.After(expiryDur - 1*time.Hour) + // TODO(mitchellh): 1 hour buffer is hardcoded above + } + } + + if leafExpiryCh == nil { + // If the channel is still nil then it means we need to generate + // a cert no matter what: we either don't have an existing one or + // it is expired. + leafExpiryCh = time.After(0) + } + + // Block on the events that wake us up. + select { + case <-timeoutCh: + // TODO: what is the right error for a timeout? + return result, fmt.Errorf("timeout") + + case err := <-newRootCACh: + // A new root CA triggers us to refresh the leaf certificate. + // If there was an error while getting the root CA then we return. + // Otherwise, we leave the select statement and move to generation. + if err != nil { + return result, err + } + + case <-leafExpiryCh: + // The existing leaf certificate is expiring soon, so we generate a + // new cert with a healthy overlapping validity period (determined + // by the above channel). + } + + // Create a CSR. + // TODO(mitchellh): This is obviously not production ready! + csr, pk := connect.TestCSR(&testing.RuntimeT{}, &connect.SpiffeIDService{ + Host: "1234.consul", + Namespace: "default", + Datacenter: reqReal.Datacenter, + Service: reqReal.Service, + }) + + // Request signing + var reply structs.IssuedCert + args := structs.CASignRequest{CSR: csr} + if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil { + return result, err + } + reply.PrivateKeyPEM = pk + + // Lock the issued certs map so we can insert it. We only insert if + // we didn't happen to get a newer one. This should never happen since + // the Cache should ensure only one Fetch per service, but we sanity + // check just in case. + c.issuedCertsLock.Lock() + defer c.issuedCertsLock.Unlock() + lastCert = c.issuedCerts[reqReal.Service] + if lastCert == nil || lastCert.ModifyIndex < reply.ModifyIndex { + if c.issuedCerts == nil { + c.issuedCerts = make(map[string]*structs.IssuedCert) + } + c.issuedCerts[reqReal.Service] = &reply + lastCert = &reply + } + + result.Value = lastCert + result.Index = lastCert.ModifyIndex + return result, nil +} + +// waitNewRootCA blocks until a new root CA is available or the timeout is +// reached (on timeout ErrTimeout is returned on the channel). +func (c *ConnectCALeaf) waitNewRootCA(ch chan<- error, timeout time.Duration) { + // Fetch some new roots. This will block until our MinQueryIndex is + // matched or the timeout is reached. + rawRoots, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{ + Datacenter: "", + QueryOptions: structs.QueryOptions{ + MinQueryIndex: atomic.LoadUint64(&c.caIndex), + MaxQueryTime: timeout, + }, + }) + if err != nil { + ch <- err + return + } + + roots, ok := rawRoots.(*structs.IndexedCARoots) + if !ok { + // This should never happen but we don't want to even risk a panic + ch <- fmt.Errorf( + "internal error: CA root cache returned bad type: %T", rawRoots) + return + } + + // Set the new index + atomic.StoreUint64(&c.caIndex, roots.QueryMeta.Index) + + // Trigger the channel since we updated. + ch <- nil +} + +// ConnectCALeafRequest is the cache.Request implementation for the +// COnnectCALeaf cache type. This is implemented here and not in structs +// since this is only used for cache-related requests and not forwarded +// directly to any Consul servers. +type ConnectCALeafRequest struct { + Datacenter string + Service string // Service name, not ID + MinQueryIndex uint64 +} + +func (r *ConnectCALeafRequest) CacheInfo() cache.RequestInfo { + return cache.RequestInfo{ + Key: r.Service, + Datacenter: r.Datacenter, + MinIndex: r.MinQueryIndex, + } +} diff --git a/agent/cache-types/connect_ca_test.go b/agent/cache-types/connect_ca_test.go index 24c37f3139..43953e7f8b 100644 --- a/agent/cache-types/connect_ca_test.go +++ b/agent/cache-types/connect_ca_test.go @@ -1,6 +1,8 @@ package cachetype import ( + "fmt" + "sync/atomic" "testing" "time" @@ -55,3 +57,196 @@ func TestConnectCARoot_badReqType(t *testing.T) { require.Contains(err.Error(), "wrong type") } + +// Test that after an initial signing, new CA roots (new ID) will +// trigger a blocking query to execute. +func TestConnectCALeaf_changingRoots(t *testing.T) { + t.Parallel() + + require := require.New(t) + rpc := TestRPC(t) + defer rpc.AssertExpectations(t) + + typ, rootsCh := testCALeafType(t, rpc) + defer close(rootsCh) + rootsCh <- structs.IndexedCARoots{ + ActiveRootID: "1", + QueryMeta: structs.QueryMeta{Index: 1}, + } + + // Instrument ConnectCA.Sign to + var resp *structs.IssuedCert + var idx uint64 + rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil). + Run(func(args mock.Arguments) { + reply := args.Get(2).(*structs.IssuedCert) + reply.ValidBefore = time.Now().Add(12 * time.Hour) + reply.CreateIndex = atomic.AddUint64(&idx, 1) + reply.ModifyIndex = reply.CreateIndex + resp = reply + }) + + // We'll reuse the fetch options and request + opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second} + req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"} + + // First fetch should return immediately + fetchCh := TestFetchCh(t, typ, opts, req) + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("shouldn't block waiting for fetch") + case result := <-fetchCh: + require.Equal(cache.FetchResult{ + Value: resp, + Index: 1, + }, result) + } + + // Second fetch should block with set index + fetchCh = TestFetchCh(t, typ, opts, req) + select { + case result := <-fetchCh: + t.Fatalf("should not return: %#v", result) + case <-time.After(100 * time.Millisecond): + } + + // Let's send in new roots, which should trigger the sign req + rootsCh <- structs.IndexedCARoots{ + ActiveRootID: "2", + QueryMeta: structs.QueryMeta{Index: 2}, + } + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("shouldn't block waiting for fetch") + case result := <-fetchCh: + require.Equal(cache.FetchResult{ + Value: resp, + Index: 2, + }, result) + } + + // Third fetch should block + fetchCh = TestFetchCh(t, typ, opts, req) + select { + case result := <-fetchCh: + t.Fatalf("should not return: %#v", result) + case <-time.After(100 * time.Millisecond): + } +} + +// Test that after an initial signing, an expiringLeaf will trigger a +// blocking query to resign. +func TestConnectCALeaf_expiringLeaf(t *testing.T) { + t.Parallel() + + require := require.New(t) + rpc := TestRPC(t) + defer rpc.AssertExpectations(t) + + typ, rootsCh := testCALeafType(t, rpc) + defer close(rootsCh) + rootsCh <- structs.IndexedCARoots{ + ActiveRootID: "1", + QueryMeta: structs.QueryMeta{Index: 1}, + } + + // Instrument ConnectCA.Sign to + var resp *structs.IssuedCert + var idx uint64 + rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil). + Run(func(args mock.Arguments) { + reply := args.Get(2).(*structs.IssuedCert) + reply.CreateIndex = atomic.AddUint64(&idx, 1) + reply.ModifyIndex = reply.CreateIndex + + // This sets the validity to 0 on the first call, and + // 12 hours+ on subsequent calls. This means that our first + // cert expires immediately. + reply.ValidBefore = time.Now().Add((12 * time.Hour) * + time.Duration(reply.CreateIndex-1)) + + resp = reply + }) + + // We'll reuse the fetch options and request + opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second} + req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"} + + // First fetch should return immediately + fetchCh := TestFetchCh(t, typ, opts, req) + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("shouldn't block waiting for fetch") + case result := <-fetchCh: + require.Equal(cache.FetchResult{ + Value: resp, + Index: 1, + }, result) + } + + // Second fetch should return immediately despite there being + // no updated CA roots, because we issued an expired cert. + fetchCh = TestFetchCh(t, typ, opts, req) + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("shouldn't block waiting for fetch") + case result := <-fetchCh: + require.Equal(cache.FetchResult{ + Value: resp, + Index: 2, + }, result) + } + + // Third fetch should block since the cert is not expiring and + // we also didn't update CA certs. + fetchCh = TestFetchCh(t, typ, opts, req) + select { + case result := <-fetchCh: + t.Fatalf("should not return: %#v", result) + case <-time.After(100 * time.Millisecond): + } +} + +// testCALeafType returns a *ConnectCALeaf that is pre-configured to +// use the given RPC implementation for "ConnectCA.Sign" operations. +func testCALeafType(t *testing.T, rpc RPC) (*ConnectCALeaf, chan structs.IndexedCARoots) { + // This creates an RPC implementation that will block until the + // value is sent on the channel. This lets us control when the + // next values show up. + rootsCh := make(chan structs.IndexedCARoots, 10) + rootsRPC := &testGatedRootsRPC{ValueCh: rootsCh} + + // Create a cache + c := cache.TestCache(t) + c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rootsRPC}, &cache.RegisterOptions{ + // Disable refresh so that the gated channel controls the + // request directly. Otherwise, we get background refreshes and + // it screws up the ordering of the channel reads of the + // testGatedRootsRPC implementation. + Refresh: false, + }) + + // Create the leaf type + return &ConnectCALeaf{RPC: rpc, Cache: c}, rootsCh +} + +// testGatedRootsRPC will send each subsequent value on the channel as the +// RPC response, blocking if it is waiting for a value on the channel. This +// can be used to control when background fetches are returned and what they +// return. +// +// This should be used with Refresh = false for the registration options so +// automatic refreshes don't mess up the channel read ordering. +type testGatedRootsRPC struct { + ValueCh chan structs.IndexedCARoots +} + +func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error { + if method != "ConnectCA.Roots" { + return fmt.Errorf("invalid RPC method: %s", method) + } + + replyReal := reply.(*structs.IndexedCARoots) + *replyReal = <-r.ValueCh + return nil +} diff --git a/agent/cache-types/testing.go b/agent/cache-types/testing.go index bf68ec4787..fcffe45a95 100644 --- a/agent/cache-types/testing.go +++ b/agent/cache-types/testing.go @@ -1,6 +1,10 @@ package cachetype import ( + "reflect" + "time" + + "github.com/hashicorp/consul/agent/cache" "github.com/mitchellh/go-testing-interface" ) @@ -10,3 +14,47 @@ func TestRPC(t testing.T) *MockRPC { // perform some initialization later. return &MockRPC{} } + +// TestFetchCh returns a channel that returns the result of the Fetch call. +// This is useful for testing timing and concurrency with Fetch calls. +// Errors will show up as an error type on the resulting channel so a +// type switch should be used. +func TestFetchCh( + t testing.T, + typ cache.Type, + opts cache.FetchOptions, + req cache.Request) <-chan interface{} { + resultCh := make(chan interface{}) + go func() { + result, err := typ.Fetch(opts, req) + if err != nil { + resultCh <- err + return + } + + resultCh <- result + }() + + return resultCh +} + +// TestFetchChResult tests that the result from TestFetchCh matches +// within a reasonable period of time (it expects it to be "immediate" but +// waits some milliseconds). +func TestFetchChResult(t testing.T, ch <-chan interface{}, expected interface{}) { + t.Helper() + + select { + case result := <-ch: + if err, ok := result.(error); ok { + t.Fatalf("Result was error: %s", err) + return + } + + if !reflect.DeepEqual(result, expected) { + t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected) + } + + case <-time.After(50 * time.Millisecond): + } +}