mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
303 lines
7.5 KiB
303 lines
7.5 KiB
package connect |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"crypto/tls" |
|
"crypto/x509" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"reflect" |
|
"sort" |
|
"strings" |
|
"testing" |
|
"time" |
|
|
|
"github.com/stretchr/testify/assert" |
|
"github.com/stretchr/testify/require" |
|
|
|
"github.com/hashicorp/consul/agent" |
|
"github.com/hashicorp/consul/agent/connect" |
|
"github.com/hashicorp/consul/api" |
|
"github.com/hashicorp/consul/sdk/testutil/retry" |
|
"github.com/hashicorp/consul/testrpc" |
|
) |
|
|
|
// Assert io.Closer implementation |
|
var _ io.Closer = new(Service) |
|
|
|
func TestService_Name(t *testing.T) { |
|
ca := connect.TestCA(t, nil) |
|
s := TestService(t, "web", ca) |
|
assert.Equal(t, "web", s.Name()) |
|
} |
|
|
|
func TestService_Dial(t *testing.T) { |
|
if testing.Short() { |
|
t.Skip("too slow for testing.Short") |
|
} |
|
|
|
ca := connect.TestCA(t, nil) |
|
|
|
tests := []struct { |
|
name string |
|
accept bool |
|
handshake bool |
|
presentService string |
|
wantErr string |
|
}{ |
|
{ |
|
name: "working", |
|
accept: true, |
|
handshake: true, |
|
presentService: "db", |
|
wantErr: "", |
|
}, |
|
{ |
|
name: "tcp connect fail", |
|
accept: false, |
|
handshake: false, |
|
presentService: "db", |
|
wantErr: "connection refused", |
|
}, |
|
{ |
|
name: "handshake timeout", |
|
accept: true, |
|
handshake: false, |
|
presentService: "db", |
|
wantErr: "i/o timeout", |
|
}, |
|
{ |
|
name: "bad cert", |
|
accept: true, |
|
handshake: true, |
|
presentService: "web", |
|
wantErr: "peer certificate mismatch", |
|
}, |
|
} |
|
for _, tt := range tests { |
|
t.Run(tt.name, func(t *testing.T) { |
|
|
|
s := TestService(t, "web", ca) |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), |
|
100*time.Millisecond) |
|
defer cancel() |
|
|
|
testSvr := NewTestServer(t, tt.presentService, ca) |
|
testSvr.TimeoutHandshake = !tt.handshake |
|
|
|
if tt.accept { |
|
go func() { |
|
err := testSvr.Serve() |
|
require.NoError(t, err) |
|
}() |
|
<-testSvr.Listening |
|
defer testSvr.Close() |
|
} |
|
|
|
// Always expect to be connecting to a "DB" |
|
resolver := &StaticResolver{ |
|
Addr: testSvr.Addr, |
|
CertURI: connect.TestSpiffeIDService(t, "db"), |
|
} |
|
|
|
// All test runs should complete in under 500ms due to the timeout about. |
|
// Don't wait for whole test run to get stuck. |
|
testTimeout := 500 * time.Millisecond |
|
testTimer := time.AfterFunc(testTimeout, func() { |
|
panic(fmt.Sprintf("test timed out after %s", testTimeout)) |
|
}) |
|
|
|
conn, err := s.Dial(ctx, resolver) |
|
testTimer.Stop() |
|
|
|
if tt.wantErr == "" { |
|
require.NoError(t, err) |
|
require.IsType(t, &tls.Conn{}, conn) |
|
} else { |
|
require.Error(t, err) |
|
require.Contains(t, err.Error(), tt.wantErr) |
|
} |
|
|
|
if err == nil { |
|
conn.Close() |
|
} |
|
}) |
|
} |
|
} |
|
|
|
func TestService_ServerTLSConfig(t *testing.T) { |
|
if testing.Short() { |
|
t.Skip("too slow for testing.Short") |
|
} |
|
|
|
a := agent.StartTestAgent(t, agent.TestAgent{Name: "007", Overrides: ` |
|
connect { |
|
test_ca_leaf_root_change_spread = "1ns" |
|
} |
|
`}) |
|
defer a.Shutdown() |
|
testrpc.WaitForTestAgent(t, a.RPC, "dc1") |
|
client := a.Client() |
|
agent := client.Agent() |
|
|
|
// NewTestAgent setup a CA already by default |
|
|
|
// Register a local agent service |
|
reg := &api.AgentServiceRegistration{ |
|
Name: "web", |
|
Port: 8080, |
|
} |
|
err := agent.ServiceRegister(reg) |
|
require.NoError(t, err) |
|
|
|
// Now we should be able to create a service that will eventually get it's TLS |
|
// all by itself! |
|
service, err := NewService("web", client) |
|
require.NoError(t, err) |
|
|
|
// Wait for it to be ready |
|
select { |
|
case <-service.ReadyWait(): |
|
// continue with test case below |
|
case <-time.After(1 * time.Second): |
|
t.Fatalf("timeout waiting for Service.ReadyWait after 1s") |
|
} |
|
|
|
tlsCfg := service.ServerTLSConfig() |
|
|
|
// Sanity check it has a leaf with the right ServiceID and that validates with |
|
// the given roots. |
|
require.NotNil(t, tlsCfg.GetCertificate) |
|
leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{}) |
|
require.NoError(t, err) |
|
cert, err := x509.ParseCertificate(leaf.Certificate[0]) |
|
require.NoError(t, err) |
|
require.Len(t, cert.URIs, 1) |
|
require.True(t, strings.HasSuffix(cert.URIs[0].String(), "/svc/web")) |
|
|
|
// Verify it as a client would |
|
err = clientSideVerifier(tlsCfg, leaf.Certificate) |
|
require.NoError(t, err) |
|
|
|
// Now test that rotating the root updates |
|
{ |
|
// Setup a new generated CA |
|
connect.TestCAConfigSet(t, a, nil) |
|
} |
|
|
|
// After some time, both root and leaves should be different but both should |
|
// still be correct. |
|
oldRootSubjects := getSubjects(tlsCfg.RootCAs) |
|
oldLeafSerial := cert.SerialNumber |
|
oldLeafKeyID := cert.SubjectKeyId |
|
retry.Run(t, func(r *retry.R) { |
|
updatedCfg := service.ServerTLSConfig() |
|
|
|
// Wait until roots are different |
|
rootSubjects := getSubjects(updatedCfg.RootCAs) |
|
if oldRootSubjects == rootSubjects { |
|
r.Fatalf("root certificates should have changed, got %s", |
|
rootSubjects) |
|
} |
|
|
|
leaf, err := updatedCfg.GetCertificate(&tls.ClientHelloInfo{}) |
|
r.Check(err) |
|
cert, err := x509.ParseCertificate(leaf.Certificate[0]) |
|
r.Check(err) |
|
|
|
if oldLeafSerial.Cmp(cert.SerialNumber) == 0 { |
|
r.Fatalf("leaf certificate should have changed, got serial %s", |
|
connect.EncodeSerialNumber(oldLeafSerial)) |
|
} |
|
if bytes.Equal(oldLeafKeyID, cert.SubjectKeyId) { |
|
r.Fatalf("leaf should have a different key, got matching SubjectKeyID = %s", |
|
connect.HexString(oldLeafKeyID)) |
|
} |
|
}) |
|
} |
|
|
|
func TestService_HTTPClient(t *testing.T) { |
|
ca := connect.TestCA(t, nil) |
|
|
|
s := TestService(t, "web", ca) |
|
|
|
// Run a test HTTP server |
|
testSvr := NewTestServer(t, "backend", ca) |
|
defer testSvr.Close() |
|
go func() { |
|
err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
w.Write([]byte("Hello, I am Backend")) |
|
})) |
|
require.NoError(t, err) |
|
}() |
|
<-testSvr.Listening |
|
|
|
// Still get connection refused some times so retry on those |
|
retry.Run(t, func(r *retry.R) { |
|
// Hook the service resolver to avoid needing full agent setup. |
|
s.httpResolverFromAddr = func(addr string) (Resolver, error) { |
|
// Require in this goroutine seems to block causing a timeout on the Get. |
|
//require.Equal(t,"https://backend.service.consul:443", addr) |
|
return &StaticResolver{ |
|
Addr: testSvr.Addr, |
|
CertURI: connect.TestSpiffeIDService(t, "backend"), |
|
}, nil |
|
} |
|
|
|
client := s.HTTPClient() |
|
client.Timeout = 1 * time.Second |
|
|
|
resp, err := client.Get("https://backend.service.consul/foo") |
|
r.Check(err) |
|
defer resp.Body.Close() |
|
|
|
bodyBytes, err := io.ReadAll(resp.Body) |
|
r.Check(err) |
|
|
|
got := string(bodyBytes) |
|
want := "Hello, I am Backend" |
|
if got != want { |
|
r.Fatalf("got %s, want %s", got, want) |
|
} |
|
}) |
|
} |
|
|
|
func TestService_HasDefaultHTTPResolverFromAddr(t *testing.T) { |
|
|
|
client, err := api.NewClient(api.DefaultConfig()) |
|
require.NoError(t, err) |
|
|
|
s, err := NewService("foo", client) |
|
require.NoError(t, err) |
|
|
|
// Sanity check this is actually set in constructor since we always override |
|
// it in tests. Full tests of the resolver func are in resolver_test.go |
|
require.NotNil(t, s.httpResolverFromAddr) |
|
|
|
fn := s.httpResolverFromAddr |
|
|
|
expected := &ConsulResolver{ |
|
Client: client, |
|
Namespace: "default", |
|
Name: "foo", |
|
Type: ConsulResolverTypeService, |
|
} |
|
got, err := fn("foo.service.consul") |
|
require.NoError(t, err) |
|
require.Equal(t, expected, got) |
|
} |
|
|
|
// getSubjects returns the keys of cp's internal byName map as one sorted,
// comma-separated string, giving tests a cheap value to compare when checking
// whether a pool's contents changed. A nil pool yields "".
//
// The map is unexported, so it is read via reflection; this depends on the
// standard library's internal CertPool layout and panics if the field is
// ever removed or renamed.
func getSubjects(cp *x509.CertPool) string {
	if cp == nil {
		// Elem on a nil pointer yields the zero reflect.Value, and
		// FieldByName on that would panic.
		return ""
	}

	byName := reflect.ValueOf(cp).Elem().FieldByName("byName")
	subjects := make([]string, 0, byName.Len())
	for iter := byName.MapRange(); iter.Next(); {
		subjects = append(subjects, iter.Key().String())
	}
	// Map iteration order is random; sort so equal pools give equal strings.
	sort.Strings(subjects)
	return strings.Join(subjects, ",")
}
|
|
|