Throwing exception on all other JWT types than SignedJWT

pull/1355/head
Tomasz Borowiec 2018-02-07 10:45:28 +01:00
parent c38b9d7a42
commit 37fba622b9
2 changed files with 77 additions and 141 deletions

View File

@ -46,7 +46,6 @@ import org.springframework.security.oauth2.common.exceptions.InvalidClientExcept
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
/**
@ -92,64 +91,60 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
JWT jwt = jwtAuth.getJwt();
JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet();
if (jwt instanceof PlainJWT) {
if (!AuthMethod.NONE.equals(client.getTokenEndpointAuthMethod())) {
throw new AuthenticationServiceException("Client does not support this authentication method.");
}
} else if (jwt instanceof SignedJWT) {
// check the signature with nimbus
SignedJWT jws = (SignedJWT)jwt;
if (!(jwt instanceof SignedJWT)) {
throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName());
}
JWSAlgorithm alg = jws.getHeader().getAlgorithm();
// check the signature with nimbus
SignedJWT jws = (SignedJWT) jwt;
if (client.getTokenEndpointAuthSigningAlg() != null &&
!client.getTokenEndpointAuthSigningAlg().equals(alg)) {
throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg()
+ ") does not match token's actual algorithm (" + alg.getName() + ")");
JWSAlgorithm alg = jws.getHeader().getAlgorithm();
if (client.getTokenEndpointAuthSigningAlg() != null &&
!client.getTokenEndpointAuthSigningAlg().equals(alg)) {
throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg()
+ ") does not match token's actual algorithm (" + alg.getName() + ")");
}
if (client.getTokenEndpointAuthMethod() == null ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {
// this client doesn't support this type of authentication
throw new AuthenticationServiceException("Client does not support this authentication method.");
} else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
(alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512)
|| alg.equals(JWSAlgorithm.ES256)
|| alg.equals(JWSAlgorithm.ES384)
|| alg.equals(JWSAlgorithm.ES512)
|| alg.equals(JWSAlgorithm.PS256)
|| alg.equals(JWSAlgorithm.PS384)
|| alg.equals(JWSAlgorithm.PS512)))
|| (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
(alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512)))) {
// double-check the method is asymmetrical if we're in HEART mode
if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) {
throw new AuthenticationServiceException("[HEART mode] Invalid authentication method");
}
if (client.getTokenEndpointAuthMethod() == null ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
// this client doesn't support this type of authentication
throw new AuthenticationServiceException("Client does not support this authentication method.");
if (validator == null) {
throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg);
}
} else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
(alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512)
|| alg.equals(JWSAlgorithm.ES256)
|| alg.equals(JWSAlgorithm.ES384)
|| alg.equals(JWSAlgorithm.ES512)
|| alg.equals(JWSAlgorithm.PS256)
|| alg.equals(JWSAlgorithm.PS384)
|| alg.equals(JWSAlgorithm.PS512)))
|| (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
(alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512)))) {
// double-check the method is asymmetrical if we're in HEART mode
if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) {
throw new AuthenticationServiceException("[HEART mode] Invalid authentication method");
}
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
if (validator == null) {
throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg);
}
if (!validator.validateSignature(jws)) {
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
}
} else {
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
if (!validator.validateSignature(jws)) {
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
}
} else {
throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName());
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
}
// check the issuer

View File

@ -2,8 +2,8 @@ package org.mitre.openid.connect.assertion;
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;
@ -110,20 +110,23 @@ public class TestJWTBearerAuthenticationProvider {
}
@Test
public void should_throw_AuthenticationServiceException_for_PlainJWT_when_AuthMethod_is_different_than_NONE() {
public void should_throw_AuthenticationServiceException_for_PlainJWT() {
mockPlainJWTAuthAttempt();
List<AuthMethod> unsupportedAuthMethods = Arrays.asList(
null, AuthMethod.PRIVATE_KEY, AuthMethod.PRIVATE_KEY, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_JWT, AuthMethod.SECRET_POST
);
for (AuthMethod authMethod : unsupportedAuthMethods) {
when(client.getTokenEndpointAuthMethod()).thenReturn(authMethod);
Throwable thrown = authenticateAndReturnThrownException();
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + PlainJWT.class.getName()));
}
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Client does not support this authentication method."));
}
@Test
public void should_throw_AuthenticationServiceException_for_EncryptedJWT() {
mockEncryptedJWTAuthAttempt();
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName()));
}
@Test
@ -228,21 +231,10 @@ public class TestJWTBearerAuthenticationProvider {
assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication."));
}
@Test
public void should_throw_AuthenticationServiceException_for_EncryptedJWT() {
EncryptedJWT encryptedJWT = createEncryptedJWT();
when(token.getJwt()).thenReturn(encryptedJWT);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName()));
}
@Test
public void should_throw_AuthenticationServiceException_when_null_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -253,7 +245,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test
public void should_throw_AuthenticationServiceException_when_not_matching_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -264,7 +256,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test
public void should_throw_AuthenticationServiceException_when_null_expiration_time() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -276,7 +268,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_expired_jwt() {
Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -288,7 +280,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -300,7 +292,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -311,7 +303,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test
public void should_throw_AuthenticationServiceException_when_unmatching_audience() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build();
mockPlainJWTAuthAttempt(jwtClaimsSet);
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
@ -320,28 +312,7 @@ public class TestJWTBearerAuthenticationProvider {
}
@Test
public void should_return_valid_token_for_PlainJWT_when_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token", "invalid"))
.build();
PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet);
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
@Test
public void should_return_valid_token_for_SignedJWT_when_audience_contains_token_endpoint() {
public void should_return_valid_token_when_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
@ -362,29 +333,7 @@ public class TestJWTBearerAuthenticationProvider {
}
@Test
public void should_return_valid_token_for_PlainJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token"))
.build();
PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet);
when(config.getIssuer()).thenReturn("http://issuer.com/");
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
@Test
public void should_return_valid_token_for_SignedJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
public void should_return_valid_token_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
@ -405,14 +354,15 @@ public class TestJWTBearerAuthenticationProvider {
assertThat(token.getAuthorities().size(), is(4));
}
private PlainJWT mockPlainJWTAuthAttempt() {
return mockPlainJWTAuthAttempt(createJwtClaimsSet());
}
private PlainJWT mockPlainJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) {
PlainJWT plainJWT = createPlainJWT(jwtClaimsSet);
private void mockPlainJWTAuthAttempt() {
PlainJWT plainJWT = new PlainJWT(createJwtClaimsSet());
when(token.getJwt()).thenReturn(plainJWT);
return plainJWT;
}
private void mockEncryptedJWTAuthAttempt() {
JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build();
EncryptedJWT encryptedJWT = new EncryptedJWT(jweHeader, createJwtClaimsSet());
when(token.getJwt()).thenReturn(encryptedJWT);
}
private SignedJWT mockSignedJWTAuthAttempt() {
@ -436,10 +386,6 @@ public class TestJWTBearerAuthenticationProvider {
throw new AssertionError("No exception thrown when expected");
}
private PlainJWT createPlainJWT(JWTClaimsSet jwtClaimsSet) {
return new PlainJWT(jwtClaimsSet);
}
private SignedJWT createSignedJWT() {
return createSignedJWT(JWSAlgorithm.RS256);
}
@ -457,11 +403,6 @@ public class TestJWTBearerAuthenticationProvider {
return new SignedJWT(jwsHeader, jwtClaimsSet);
}
private EncryptedJWT createEncryptedJWT() {
JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build();
return new EncryptedJWT(jweHeader, createJwtClaimsSet());
}
private JWTClaimsSet createJwtClaimsSet() {
return new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)