mirror of https://github.com/hashicorp/consul
307 lines
7.5 KiB
Go
307 lines
7.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package connect
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/hashicorp/consul/agent"
|
|
"github.com/hashicorp/consul/agent/connect"
|
|
"github.com/hashicorp/consul/api"
|
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
|
"github.com/hashicorp/consul/testrpc"
|
|
)
|
|
|
|
// Assert io.Closer implementation
|
|
var _ io.Closer = new(Service)
|
|
|
|
func TestService_Name(t *testing.T) {
|
|
ca := connect.TestCA(t, nil)
|
|
s := TestService(t, "web", ca)
|
|
assert.Equal(t, "web", s.Name())
|
|
}
|
|
|
|
func TestService_Dial(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("too slow for testing.Short")
|
|
}
|
|
|
|
ca := connect.TestCA(t, nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
accept bool
|
|
handshake bool
|
|
presentService string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "working",
|
|
accept: true,
|
|
handshake: true,
|
|
presentService: "db",
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "tcp connect fail",
|
|
accept: false,
|
|
handshake: false,
|
|
presentService: "db",
|
|
wantErr: "connection refused",
|
|
},
|
|
{
|
|
name: "handshake timeout",
|
|
accept: true,
|
|
handshake: false,
|
|
presentService: "db",
|
|
wantErr: "i/o timeout",
|
|
},
|
|
{
|
|
name: "bad cert",
|
|
accept: true,
|
|
handshake: true,
|
|
presentService: "web",
|
|
wantErr: "peer certificate mismatch",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
s := TestService(t, "web", ca)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(),
|
|
100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
testSvr := NewTestServer(t, tt.presentService, ca)
|
|
testSvr.TimeoutHandshake = !tt.handshake
|
|
|
|
if tt.accept {
|
|
go func() {
|
|
err := testSvr.Serve()
|
|
require.NoError(t, err)
|
|
}()
|
|
<-testSvr.Listening
|
|
defer testSvr.Close()
|
|
}
|
|
|
|
// Always expect to be connecting to a "DB"
|
|
resolver := &StaticResolver{
|
|
Addr: testSvr.Addr,
|
|
CertURI: connect.TestSpiffeIDService(t, "db"),
|
|
}
|
|
|
|
// All test runs should complete in under 500ms due to the timeout about.
|
|
// Don't wait for whole test run to get stuck.
|
|
testTimeout := 500 * time.Millisecond
|
|
testTimer := time.AfterFunc(testTimeout, func() {
|
|
panic(fmt.Sprintf("test timed out after %s", testTimeout))
|
|
})
|
|
|
|
conn, err := s.Dial(ctx, resolver)
|
|
testTimer.Stop()
|
|
|
|
if tt.wantErr == "" {
|
|
require.NoError(t, err)
|
|
require.IsType(t, &tls.Conn{}, conn)
|
|
} else {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.wantErr)
|
|
}
|
|
|
|
if err == nil {
|
|
conn.Close()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestService_ServerTLSConfig(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("too slow for testing.Short")
|
|
}
|
|
|
|
a := agent.StartTestAgent(t, agent.TestAgent{Name: "007", Overrides: `
|
|
connect {
|
|
test_ca_leaf_root_change_spread = "1ns"
|
|
}
|
|
`})
|
|
defer a.Shutdown()
|
|
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
|
client := a.Client()
|
|
agent := client.Agent()
|
|
|
|
// NewTestAgent setup a CA already by default
|
|
|
|
// Register a local agent service
|
|
reg := &api.AgentServiceRegistration{
|
|
Name: "web",
|
|
Port: 8080,
|
|
}
|
|
err := agent.ServiceRegister(reg)
|
|
require.NoError(t, err)
|
|
|
|
// Now we should be able to create a service that will eventually get it's TLS
|
|
// all by itself!
|
|
service, err := NewService("web", client)
|
|
require.NoError(t, err)
|
|
|
|
// Wait for it to be ready
|
|
select {
|
|
case <-service.ReadyWait():
|
|
// continue with test case below
|
|
case <-time.After(1 * time.Second):
|
|
t.Fatalf("timeout waiting for Service.ReadyWait after 1s")
|
|
}
|
|
|
|
tlsCfg := service.ServerTLSConfig()
|
|
|
|
// Sanity check it has a leaf with the right ServiceID and that validates with
|
|
// the given roots.
|
|
require.NotNil(t, tlsCfg.GetCertificate)
|
|
leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{})
|
|
require.NoError(t, err)
|
|
cert, err := x509.ParseCertificate(leaf.Certificate[0])
|
|
require.NoError(t, err)
|
|
require.Len(t, cert.URIs, 1)
|
|
require.True(t, strings.HasSuffix(cert.URIs[0].String(), "/svc/web"))
|
|
|
|
// Verify it as a client would
|
|
err = clientSideVerifier(tlsCfg, leaf.Certificate)
|
|
require.NoError(t, err)
|
|
|
|
// Now test that rotating the root updates
|
|
{
|
|
// Setup a new generated CA
|
|
connect.TestCAConfigSet(t, a, nil)
|
|
}
|
|
|
|
// After some time, both root and leaves should be different but both should
|
|
// still be correct.
|
|
oldRootSubjects := getSubjects(tlsCfg.RootCAs)
|
|
oldLeafSerial := cert.SerialNumber
|
|
oldLeafKeyID := cert.SubjectKeyId
|
|
retry.Run(t, func(r *retry.R) {
|
|
updatedCfg := service.ServerTLSConfig()
|
|
|
|
// Wait until roots are different
|
|
rootSubjects := getSubjects(updatedCfg.RootCAs)
|
|
if oldRootSubjects == rootSubjects {
|
|
r.Fatalf("root certificates should have changed, got %s",
|
|
rootSubjects)
|
|
}
|
|
|
|
leaf, err := updatedCfg.GetCertificate(&tls.ClientHelloInfo{})
|
|
r.Check(err)
|
|
cert, err := x509.ParseCertificate(leaf.Certificate[0])
|
|
r.Check(err)
|
|
|
|
if oldLeafSerial.Cmp(cert.SerialNumber) == 0 {
|
|
r.Fatalf("leaf certificate should have changed, got serial %s",
|
|
connect.EncodeSerialNumber(oldLeafSerial))
|
|
}
|
|
if bytes.Equal(oldLeafKeyID, cert.SubjectKeyId) {
|
|
r.Fatalf("leaf should have a different key, got matching SubjectKeyID = %s",
|
|
connect.HexString(oldLeafKeyID))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestService_HTTPClient(t *testing.T) {
|
|
ca := connect.TestCA(t, nil)
|
|
|
|
s := TestService(t, "web", ca)
|
|
|
|
// Run a test HTTP server
|
|
testSvr := NewTestServer(t, "backend", ca)
|
|
defer testSvr.Close()
|
|
go func() {
|
|
err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("Hello, I am Backend"))
|
|
}))
|
|
require.NoError(t, err)
|
|
}()
|
|
<-testSvr.Listening
|
|
|
|
// Still get connection refused some times so retry on those
|
|
retry.Run(t, func(r *retry.R) {
|
|
// Hook the service resolver to avoid needing full agent setup.
|
|
s.httpResolverFromAddr = func(addr string) (Resolver, error) {
|
|
// Require in this goroutine seems to block causing a timeout on the Get.
|
|
//require.Equal(t,"https://backend.service.consul:443", addr)
|
|
return &StaticResolver{
|
|
Addr: testSvr.Addr,
|
|
CertURI: connect.TestSpiffeIDService(r, "backend"),
|
|
}, nil
|
|
}
|
|
|
|
client := s.HTTPClient()
|
|
client.Timeout = 1 * time.Second
|
|
|
|
resp, err := client.Get("https://backend.service.consul/foo")
|
|
r.Check(err)
|
|
defer resp.Body.Close()
|
|
|
|
bodyBytes, err := io.ReadAll(resp.Body)
|
|
r.Check(err)
|
|
|
|
got := string(bodyBytes)
|
|
want := "Hello, I am Backend"
|
|
if got != want {
|
|
r.Fatalf("got %s, want %s", got, want)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestService_HasDefaultHTTPResolverFromAddr(t *testing.T) {
|
|
|
|
client, err := api.NewClient(api.DefaultConfig())
|
|
require.NoError(t, err)
|
|
|
|
s, err := NewService("foo", client)
|
|
require.NoError(t, err)
|
|
|
|
// Sanity check this is actually set in constructor since we always override
|
|
// it in tests. Full tests of the resolver func are in resolver_test.go
|
|
require.NotNil(t, s.httpResolverFromAddr)
|
|
|
|
fn := s.httpResolverFromAddr
|
|
|
|
expected := &ConsulResolver{
|
|
Client: client,
|
|
Namespace: "default",
|
|
Name: "foo",
|
|
Type: ConsulResolverTypeService,
|
|
}
|
|
got, err := fn("foo.service.consul")
|
|
require.NoError(t, err)
|
|
require.Equal(t, expected, got)
|
|
}
|
|
|
|
func getSubjects(cp *x509.CertPool) string {
|
|
subjectsIter := reflect.ValueOf(cp).Elem().FieldByName("byName").MapRange()
|
|
subjects := []string{}
|
|
for subjectsIter.Next() {
|
|
k := subjectsIter.Key()
|
|
subjects = append(subjects, k.String())
|
|
}
|
|
sort.Strings(subjects)
|
|
subjectList := strings.Join(subjects, ",")
|
|
return subjectList
|
|
}
|