mirror of https://github.com/k3s-io/k3s
Allow joining clusters when the server CA is trusted by the OS CA bundle (#2743)
* Add tests to clientaccess/token * Fix issues in clientaccess/token identified by tests * Update tests to close coverage gaps * Remove redundant check turned up by code coverage reports * Add warnings if CA hash will not be validated Signed-off-by: Brad Davidson <brad.davidson@rancher.com>pull/2915/head
parent
6c472b5942
commit
ad5e504cf0
|
@ -13,10 +13,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultClientTimeout = 20 * time.Second
|
||||
defaultClientTimeout = 10 * time.Second
|
||||
|
||||
defaultClient = &http.Client{
|
||||
Timeout: defaultClientTimeout,
|
||||
|
@ -32,8 +33,9 @@ var (
|
|||
)
|
||||
|
||||
const (
|
||||
tokenPrefix = "K10"
|
||||
tokenFormat = "%s%s::%s:%s"
|
||||
tokenPrefix = "K10"
|
||||
tokenFormat = "%s%s::%s:%s"
|
||||
caHashLength = sha256.Size * 2
|
||||
)
|
||||
|
||||
type OverrideURLCallback func(config []byte) (*url.URL, error)
|
||||
|
@ -59,16 +61,10 @@ func ParseAndValidateToken(server string, token string) (*Info, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := info.setServer(server); err != nil {
|
||||
if err := info.setAndValidateServer(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info.caHash != "" {
|
||||
if err := info.validateCAHash(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
|
@ -82,26 +78,24 @@ func ParseAndValidateTokenForUser(server string, token string, username string)
|
|||
|
||||
info.Username = username
|
||||
|
||||
if err := info.setServer(server); err != nil {
|
||||
if err := info.setAndValidateServer(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info.caHash != "" {
|
||||
if err := info.validateCAHash(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// setAndValidateServer updates the remote server's cert info, and validates it against the provided hash
|
||||
func (info *Info) setAndValidateServer(server string) error {
|
||||
if err := info.setServer(server); err != nil {
|
||||
return err
|
||||
}
|
||||
return info.validateCAHash()
|
||||
}
|
||||
|
||||
// validateCACerts returns a boolean indicating whether or not a CA bundle matches the provided hash,
|
||||
// and a string containing the hash of the CA bundle.
|
||||
func validateCACerts(cacerts []byte, hash string) (bool, string) {
|
||||
if len(cacerts) == 0 && hash == "" {
|
||||
return true, ""
|
||||
}
|
||||
|
||||
newHash := hashCA(cacerts)
|
||||
return hash == newHash, newHash
|
||||
}
|
||||
|
@ -126,6 +120,10 @@ func ParseUsernamePassword(token string) (string, string, bool) {
|
|||
func parseToken(token string) (*Info, error) {
|
||||
var info = &Info{}
|
||||
|
||||
if len(token) == 0 {
|
||||
return nil, errors.New("token must not be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(token, tokenPrefix) {
|
||||
token = fmt.Sprintf(tokenFormat, tokenPrefix, "", "", token)
|
||||
}
|
||||
|
@ -136,13 +134,17 @@ func parseToken(token string) (*Info, error) {
|
|||
parts := strings.SplitN(token, "::", 2)
|
||||
token = parts[0]
|
||||
if len(parts) > 1 {
|
||||
hashLen := len(parts[0])
|
||||
if hashLen > 0 && hashLen != caHashLength {
|
||||
return nil, errors.New("invalid token CA hash length")
|
||||
}
|
||||
info.caHash = parts[0]
|
||||
token = parts[1]
|
||||
}
|
||||
|
||||
parts = strings.SplitN(token, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
if len(parts) != 2 || len(parts[1]) == 0 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
|
||||
info.Username = parts[0]
|
||||
|
@ -212,10 +214,20 @@ func (info *Info) setServer(server string) error {
|
|||
|
||||
// ValidateCAHash validates that info's caHash matches the CACerts hash.
|
||||
func (info *Info) validateCAHash() error {
|
||||
if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok {
|
||||
return fmt.Errorf("token CA hash does not match the server CA hash: %s != %s", info.caHash, serverHash)
|
||||
if len(info.caHash) > 0 && len(info.CACerts) == 0 {
|
||||
// Warn if the user provided a CA hash but we're not going to validate because it's already trusted
|
||||
logrus.Warn("Cluster CA certificate is trusted by the host CA bundle. " +
|
||||
"Token CA hash will not be validated.")
|
||||
} else if len(info.caHash) == 0 && len(info.CACerts) > 0 {
|
||||
// Warn if the CA is self-signed but the user didn't provide a hash to validate it against
|
||||
logrus.Warn("Cluster CA certificate is not trusted by the host CA bundle, but the token does not include a CA hash. " +
|
||||
"Use the full token from the server's node-token file to enable Cluster CA validation.")
|
||||
} else if len(info.CACerts) > 0 && len(info.caHash) > 0 {
|
||||
// only verify CA hash if the server cert is not trusted by the OS CA bundle
|
||||
if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok {
|
||||
return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", info.caHash, serverHash)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,401 @@
|
|||
package clientaccess
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rancher/dynamiclistener/cert"
|
||||
"github.com/rancher/dynamiclistener/factory"
|
||||
"github.com/rancher/k3s/pkg/bootstrap"
|
||||
"github.com/rancher/k3s/pkg/daemons/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultUsername = "server"
|
||||
defaultPassword = "token"
|
||||
)
|
||||
|
||||
// TestTrustedCA confirms that tokens are validated when the server uses a cert (self-signed or otherwise)
|
||||
// that is trusted by the OS CA bundle. This test must be run first, since it mucks with the system root certs.
|
||||
func TestTrustedCA(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, false)
|
||||
defer server.Close()
|
||||
|
||||
testInfo := &Info{
|
||||
CACerts: getServerCA(server),
|
||||
BaseURL: server.URL,
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
caHash: hashCA(getServerCA(server)),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{defaultPassword, ""},
|
||||
{testInfo.String(), testInfo.Username},
|
||||
}
|
||||
|
||||
// Point OS CA bundle at this test's CA cert to simulate a trusted CA cert.
|
||||
// Note that this only works if the OS CA bundle has not yet been loaded in this process,
|
||||
// as it is cached for the duration of the process lifetime.
|
||||
// Ref: https://github.com/golang/go/issues/41888
|
||||
path := t.TempDir() + "/ca.crt"
|
||||
writeServerCA(server, path)
|
||||
os.Setenv("SSL_CERT_FILE", path)
|
||||
|
||||
for _, testCase := range testCases {
|
||||
info, err := ParseAndValidateToken(server.URL, testCase.token)
|
||||
if assert.NoError(err, testCase) {
|
||||
assert.Nil(info.CACerts, testCase)
|
||||
assert.Equal(testCase.expected, info.Username, testCase.token)
|
||||
}
|
||||
|
||||
info, err = ParseAndValidateTokenForUser(server.URL, testCase.token, "agent")
|
||||
if assert.NoError(err, testCase) {
|
||||
assert.Nil(info.CACerts, testCase)
|
||||
assert.Equal("agent", info.Username, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
// Confirm that the cert is actually trusted by the OS CA bundle by making a request
|
||||
// with empty cert pool
|
||||
testInfo.CACerts = nil
|
||||
res, err := Get("/v1-k3s/server-bootstrap", testInfo)
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(res)
|
||||
}
|
||||
|
||||
// TestUntrustedCA confirms that tokens are validated when the server uses a self-signed cert
|
||||
// that is NOT trusted by the OS CA bundle.
|
||||
func TestUntrustedCA(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, false)
|
||||
defer server.Close()
|
||||
|
||||
testInfo := &Info{
|
||||
CACerts: getServerCA(server),
|
||||
BaseURL: server.URL,
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
caHash: hashCA(getServerCA(server)),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{defaultPassword, ""},
|
||||
{testInfo.String(), testInfo.Username},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
info, err := ParseAndValidateToken(server.URL, testCase.token)
|
||||
if assert.NoError(err, testCase) {
|
||||
assert.Equal(testInfo.CACerts, info.CACerts, testCase)
|
||||
assert.Equal(testCase.expected, info.Username, testCase)
|
||||
}
|
||||
|
||||
info, err = ParseAndValidateTokenForUser(server.URL, testCase.token, "agent")
|
||||
if assert.NoError(err, testCase) {
|
||||
assert.Equal(testInfo.CACerts, info.CACerts, testCase)
|
||||
assert.Equal("agent", info.Username, testCase)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidServers tests that invalid server URLs are properly rejected
|
||||
func TestInvalidServers(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testCases := []struct {
|
||||
server string
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{" https://localhost:6443", "token", "Invalid server url, failed to parse: https://localhost:6443: parse \" https://localhost:6443\": first path segment in URL cannot contain colon"},
|
||||
{"http://localhost:6443", "token", "only https:// URLs are supported, invalid scheme: http://localhost:6443"},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
_, err := ParseAndValidateToken(testCase.server, testCase.token)
|
||||
assert.EqualError(err, testCase.expected, testCase)
|
||||
|
||||
_, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername)
|
||||
assert.EqualError(err, testCase.expected, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidTokens tests that tokens which are empty, invalid, or incorrect are properly rejected
|
||||
func TestInvalidTokens(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, false)
|
||||
defer server.Close()
|
||||
|
||||
testCases := []struct {
|
||||
server string
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{server.URL, "", "token must not be empty"},
|
||||
{server.URL, "K10::", "invalid token format"},
|
||||
{server.URL, "K10::x", "invalid token format"},
|
||||
{server.URL, "K10::x:", "invalid token format"},
|
||||
{server.URL, "K10XX::x:y", "invalid token CA hash length"},
|
||||
{server.URL,
|
||||
"K10XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX::x:y",
|
||||
"token CA hash does not match the Cluster CA certificate hash: XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX != " + hashCA(getServerCA(server))},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
info, err := ParseAndValidateToken(testCase.server, testCase.token)
|
||||
assert.EqualError(err, testCase.expected, testCase)
|
||||
assert.Nil(info, testCase)
|
||||
|
||||
info, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername)
|
||||
assert.EqualError(err, testCase.expected, testCase)
|
||||
assert.Nil(info, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidCredentials tests that tokens which don't have valid credentials are rejected
|
||||
func TestInvalidCredentials(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, false)
|
||||
defer server.Close()
|
||||
|
||||
testInfo := &Info{
|
||||
CACerts: getServerCA(server),
|
||||
BaseURL: server.URL,
|
||||
Username: "nobody",
|
||||
Password: "invalid",
|
||||
caHash: hashCA(getServerCA(server)),
|
||||
}
|
||||
|
||||
testCases := []string{
|
||||
testInfo.Password,
|
||||
testInfo.String(),
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
info, err := ParseAndValidateToken(server.URL, testCase)
|
||||
assert.NoError(err, testCase)
|
||||
if assert.NotNil(info) {
|
||||
res, err := Get("/v1-k3s/server-bootstrap", info)
|
||||
assert.Error(err, testCase)
|
||||
assert.Empty(res, testCase)
|
||||
}
|
||||
|
||||
info, err = ParseAndValidateTokenForUser(server.URL, testCase, defaultUsername)
|
||||
assert.NoError(err, testCase)
|
||||
if assert.NotNil(info) {
|
||||
res, err := Get("/v1-k3s/server-bootstrap", info)
|
||||
assert.Error(err, testCase)
|
||||
assert.Empty(res, testCase)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrongCert tests that errors are returned when the server's cert isn't issued by its CA
|
||||
func TestWrongCert(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, true)
|
||||
defer server.Close()
|
||||
|
||||
info, err := ParseAndValidateToken(server.URL, defaultPassword)
|
||||
assert.Error(err)
|
||||
assert.Nil(info)
|
||||
|
||||
info, err = ParseAndValidateTokenForUser(server.URL, defaultPassword, defaultUsername)
|
||||
assert.Error(err)
|
||||
assert.Nil(info)
|
||||
}
|
||||
|
||||
// TestConnectionFailures tests that connections are timed out properly
|
||||
func TestConnectionFailures(t *testing.T) {
|
||||
testDuration := (defaultClientTimeout * 2) + time.Second
|
||||
assert := assert.New(t)
|
||||
testCases := []struct {
|
||||
server string
|
||||
token string
|
||||
}{
|
||||
{"https://192.0.2.1:6443", "token"}, // RFC 5735 TEST-NET-1 for use in documentation and example code
|
||||
{"https://localhost:1", "token"},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
startTime := time.Now()
|
||||
info, err := ParseAndValidateToken(testCase.server, testCase.token)
|
||||
assert.Error(err, testCase)
|
||||
assert.Nil(info, testCase)
|
||||
assert.WithinDuration(time.Now(), startTime, testDuration, testCase)
|
||||
|
||||
startTime = time.Now()
|
||||
info, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername)
|
||||
assert.Error(err, testCase)
|
||||
assert.Nil(info, testCase)
|
||||
assert.WithinDuration(startTime, time.Now(), testDuration, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserPass tests that usernames and passwords are parsed or not parsed from token strings
|
||||
func TestUserPass(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testCases := []struct {
|
||||
token string
|
||||
username string
|
||||
password string
|
||||
expect bool
|
||||
}{
|
||||
{"K10XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX::username:password", "username", "password", true},
|
||||
{"password", "", "password", true},
|
||||
{"K10X::x", "", "", false},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
username, password, ok := ParseUsernamePassword(testCase.token)
|
||||
assert.Equal(testCase.expect, ok, testCase)
|
||||
if ok {
|
||||
assert.Equal(testCase.username, username, testCase)
|
||||
assert.Equal(testCase.password, password, testCase)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseAndGet tests URL handling along some hard-to-reach code paths
|
||||
func TestParseAndGet(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
server := newTLSServer(t, defaultUsername, defaultPassword, false)
|
||||
defer server.Close()
|
||||
|
||||
testCases := []struct {
|
||||
extraBasePre string
|
||||
extraBasePost string
|
||||
path string
|
||||
parseFail bool
|
||||
getFail bool
|
||||
}{
|
||||
{"/", "", "/cacerts", false, false},
|
||||
{"/%2", "", "/cacerts", true, false},
|
||||
{"", "", "/%2", false, true},
|
||||
{"", "/%2", "/cacerts", false, true},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
info, err := ParseAndValidateTokenForUser(server.URL+testCase.extraBasePre, defaultPassword, defaultUsername)
|
||||
// Check for expected error when parsing server + token
|
||||
if testCase.parseFail {
|
||||
assert.Error(err, testCase)
|
||||
} else if assert.NoError(err, testCase) {
|
||||
info.BaseURL = server.URL + testCase.extraBasePost
|
||||
_, err := Get(testCase.path, info)
|
||||
// Check for expected error when making Get request
|
||||
if testCase.getFail {
|
||||
assert.Error(err, testCase)
|
||||
} else {
|
||||
assert.NoError(err, testCase)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTLSServer returns a HTTPS server that mocks the basic functionality required to validate K3s join tokens.
|
||||
// Each call to this function will generate new CA and server certificates unique to the returned server.
|
||||
func newTLSServer(t *testing.T, username, password string, sendWrongCA bool) *httptest.Server {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v1-k3s/server-bootstrap" {
|
||||
if authUsername, authPassword, ok := r.BasicAuth(); ok != true || authPassword != password || authUsername != username {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
bootstrapData := &config.ControlRuntimeBootstrap{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := bootstrap.Write(w, bootstrapData); err != nil {
|
||||
t.Errorf("failed to write bootstrap: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/cacerts" {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
if _, err := w.Write(getServerCA(server)); err != nil {
|
||||
t.Errorf("Failed to write cacerts: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
// Create new CA cert and key
|
||||
caCert, caKey, err := factory.GenCA()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate new server cert; reuse the key from the CA
|
||||
cfg := cert.Config{
|
||||
CommonName: "localhost",
|
||||
Organization: []string{"testing"},
|
||||
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
AltNames: cert.AltNames{
|
||||
DNSNames: []string{"localhost"},
|
||||
IPs: []net.IP{net.IPv4(127, 0, 0, 1)},
|
||||
},
|
||||
}
|
||||
serverCert, err := cert.NewSignedCert(cfg, caKey, caCert, caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Bind server and CA certs into chain for TLS listener configuration
|
||||
server.TLS = &tls.Config{}
|
||||
server.TLS.Certificates = []tls.Certificate{
|
||||
{Certificate: [][]byte{serverCert.Raw}, Leaf: serverCert, PrivateKey: caKey},
|
||||
{Certificate: [][]byte{caCert.Raw}, Leaf: caCert},
|
||||
}
|
||||
|
||||
if sendWrongCA {
|
||||
// Create new CA cert and key and use that as the CA cert instead of the one that actually signed the server cert
|
||||
badCert, _, err := factory.GenCA()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.TLS.Certificates[1].Certificate[0] = badCert.Raw
|
||||
server.TLS.Certificates[1].Leaf = badCert
|
||||
}
|
||||
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
// getServerCA returns a byte slice containing the PEM encoding of the server's CA certificate
|
||||
func getServerCA(server *httptest.Server) []byte {
|
||||
certLen := len(server.TLS.Certificates)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[certLen-1].Certificate[0]})
|
||||
}
|
||||
|
||||
// writeServerCA writes the PEM-encoded server certificate to a given path
|
||||
func writeServerCA(server *httptest.Server, path string) error {
|
||||
certOut, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer certOut.Close()
|
||||
|
||||
if _, err := certOut.Write(getServerCA(server)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue