Merge pull request #1355 from ocadotechnology/jwt_fix

Throwing exception on all other JWT types than SignedJWT
pull/1357/head
Justin Richer 2018-02-07 08:52:59 -05:00 committed by GitHub
commit 7f956a5854
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 461 additions and 44 deletions

View File

@ -91,57 +91,60 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
JWT jwt = jwtAuth.getJwt();
JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet();
if (!(jwt instanceof SignedJWT)) {
throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName());
}
// check the signature with nimbus
if (jwt instanceof SignedJWT) {
SignedJWT jws = (SignedJWT)jwt;
SignedJWT jws = (SignedJWT) jwt;
JWSAlgorithm alg = jws.getHeader().getAlgorithm();
JWSAlgorithm alg = jws.getHeader().getAlgorithm();
if (client.getTokenEndpointAuthSigningAlg() != null &&
!client.getTokenEndpointAuthSigningAlg().equals(alg)) {
throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
if (client.getTokenEndpointAuthSigningAlg() != null &&
!client.getTokenEndpointAuthSigningAlg().equals(alg)) {
throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg()
+ ") does not match token's actual algorithm (" + alg.getName() + ")");
}
if (client.getTokenEndpointAuthMethod() == null ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {
// this client doesn't support this type of authentication
throw new AuthenticationServiceException("Client does not support this authentication method.");
} else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
(alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512)
|| alg.equals(JWSAlgorithm.ES256)
|| alg.equals(JWSAlgorithm.ES384)
|| alg.equals(JWSAlgorithm.ES512)
|| alg.equals(JWSAlgorithm.PS256)
|| alg.equals(JWSAlgorithm.PS384)
|| alg.equals(JWSAlgorithm.PS512)))
|| (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
(alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512)))) {
// double-check the method is asymmetrical if we're in HEART mode
if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) {
throw new AuthenticationServiceException("[HEART mode] Invalid authentication method");
}
if (client.getTokenEndpointAuthMethod() == null ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) ||
client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
// this client doesn't support this type of authentication
throw new AuthenticationServiceException("Client does not support this authentication method.");
} else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
(alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512)
|| alg.equals(JWSAlgorithm.ES256)
|| alg.equals(JWSAlgorithm.ES384)
|| alg.equals(JWSAlgorithm.ES512)
|| alg.equals(JWSAlgorithm.PS256)
|| alg.equals(JWSAlgorithm.PS384)
|| alg.equals(JWSAlgorithm.PS512)))
|| (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
(alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512)))) {
// double-check the method is asymmetrical if we're in HEART mode
if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) {
throw new AuthenticationServiceException("[HEART mode] Invalid authentication method");
}
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
if (validator == null) {
throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg);
}
if (!validator.validateSignature(jws)) {
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
}
} else {
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
if (validator == null) {
throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg);
}
if (!validator.validateSignature(jws)) {
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
}
} else {
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
}
// check the issuer

View File

@ -0,0 +1,414 @@
package org.mitre.openid.connect.assertion;
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
@RunWith(MockitoJUnitRunner.class)
public class TestJWTBearerAuthenticationProvider {
private static final String CLIENT_ID = "client";
private static final String SUBJECT = "subject";
@Mock
private ClientKeyCacheService validators;
@Mock
private ClientDetailsEntityService clientService;
@Mock
private ConfigurationPropertiesBean config;
@InjectMocks
private JWTBearerAuthenticationProvider jwtBearerAuthenticationProvider;
@Mock
private JWTBearerAssertionAuthenticationToken token;
@Mock
private ClientDetailsEntity client;
@Mock
private JWTSigningAndValidationService validator;
private GrantedAuthority authority1 = new SimpleGrantedAuthority("1");
private GrantedAuthority authority2 = new SimpleGrantedAuthority("2");
private GrantedAuthority authority3 = new SimpleGrantedAuthority("3");
@Before
public void setup() {
when(clientService.loadClientByClientId(CLIENT_ID)).thenReturn(client);
when(token.getName()).thenReturn(CLIENT_ID);
when(client.getClientId()).thenReturn(CLIENT_ID);
when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.NONE);
when(client.getAuthorities()).thenReturn(ImmutableSet.of(authority1, authority2, authority3));
when(validators.getValidator(client, JWSAlgorithm.RS256)).thenReturn(validator);
when(validator.validateSignature(any(SignedJWT.class))).thenReturn(true);
when(config.getIssuer()).thenReturn("http://issuer.com/");
}
@Test
public void should_not_support_UsernamePasswordAuthenticationToken() {
assertThat(jwtBearerAuthenticationProvider.supports(UsernamePasswordAuthenticationToken.class), is(false));
}
@Test
public void should_support_JWTBearerAssertionAuthenticationToken() {
assertThat(jwtBearerAuthenticationProvider.supports(JWTBearerAssertionAuthenticationToken.class), is(true));
}
@Test
public void should_throw_UsernameNotFoundException_when_clientService_throws_InvalidClientException() {
when(clientService.loadClientByClientId(CLIENT_ID)).thenThrow(new InvalidClientException("invalid client"));
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(UsernameNotFoundException.class));
assertThat(thrown.getMessage(), is("Could not find client: " + CLIENT_ID));
}
@Test
public void should_throw_AuthenticationServiceException_for_PlainJWT() {
mockPlainJWTAuthAttempt();
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + PlainJWT.class.getName()));
}
@Test
public void should_throw_AuthenticationServiceException_for_EncryptedJWT() {
mockEncryptedJWTAuthAttempt();
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName()));
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_signing_algorithms_do_not_match() {
when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.RS256);
SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.ES384);
when(token.getJwt()).thenReturn(signedJWT);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Client's registered token endpoint signing algorithm (RS256) does not match token's actual algorithm (ES384)"));
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_unsupported_authentication_method_for_SignedJWT() {
List<AuthMethod> unsupportedAuthMethods =
Arrays.asList(null, AuthMethod.NONE, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_POST);
for (AuthMethod unsupportedAuthMethod : unsupportedAuthMethods) {
SignedJWT signedJWT = createSignedJWT();
when(token.getJwt()).thenReturn(signedJWT);
when(client.getTokenEndpointAuthMethod()).thenReturn(unsupportedAuthMethod);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Client does not support this authentication method."));
}
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_algorithm_for_PRIVATE_KEY_auth_method() {
List<JWSAlgorithm> invalidAlgorithms = Arrays.asList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512);
for (JWSAlgorithm algorithm : invalidAlgorithms) {
SignedJWT signedJWT = createSignedJWT(algorithm);
when(token.getJwt()).thenReturn(signedJWT);
when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.PRIVATE_KEY);
when(client.getTokenEndpointAuthSigningAlg()).thenReturn(algorithm);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for method"));
}
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_algorithm_for_SECRET_JWT_auth_method() {
List<JWSAlgorithm> invalidAlgorithms = Arrays.asList(
JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512,
JWSAlgorithm.ES256, JWSAlgorithm.ES384, JWSAlgorithm.ES512,
JWSAlgorithm.PS256, JWSAlgorithm.PS384, JWSAlgorithm.PS512);
for (JWSAlgorithm algorithm : invalidAlgorithms) {
SignedJWT signedJWT = createSignedJWT(algorithm);
when(token.getJwt()).thenReturn(signedJWT);
when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.SECRET_JWT);
when(client.getTokenEndpointAuthSigningAlg()).thenReturn(algorithm);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for method"));
}
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_in_heart_mode_and_auth_method_is_not_PRIVATE_KEY() {
SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.HS256);
when(token.getJwt()).thenReturn(signedJWT);
when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.HS256);
when(config.isHeartMode()).thenReturn(true);
when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.SECRET_JWT);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("[HEART mode] Invalid authentication method"));
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_null_validator() {
mockSignedJWTAuthAttempt();
when(validators.getValidator(any(ClientDetailsEntity.class), any(JWSAlgorithm.class))).thenReturn(null);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for client"));
}
@Test
public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_signature() {
SignedJWT signedJWT = mockSignedJWTAuthAttempt();
when(validator.validateSignature(signedJWT)).thenReturn(false);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication."));
}
@Test
public void should_throw_AuthenticationServiceException_when_null_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Assertion Token Issuer is null"));
}
@Test
public void should_throw_AuthenticationServiceException_when_not_matching_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Issuers do not match"));
}
@Test
public void should_throw_AuthenticationServiceException_when_null_expiration_time() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Assertion Token does not have required expiration claim"));
}
@Test
public void should_throw_AuthenticationServiceException_when_expired_jwt() {
Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Assertion Token is expired"));
}
@Test
public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Assertion Token not valid until"));
}
@Test
public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Assertion Token was issued in the future"));
}
@Test
public void should_throw_AuthenticationServiceException_when_unmatching_audience() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build();
mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), startsWith("Audience does not match"));
}
@Test
public void should_return_valid_token_when_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token", "invalid"))
.build();
JWT jwt = mockSignedJWTAuthAttempt(jwtClaimsSet);
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
@Test
public void should_return_valid_token_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token"))
.build();
JWT jwt = mockSignedJWTAuthAttempt(jwtClaimsSet);
when(config.getIssuer()).thenReturn("http://issuer.com/");
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
private void mockPlainJWTAuthAttempt() {
PlainJWT plainJWT = new PlainJWT(createJwtClaimsSet());
when(token.getJwt()).thenReturn(plainJWT);
}
private void mockEncryptedJWTAuthAttempt() {
JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build();
EncryptedJWT encryptedJWT = new EncryptedJWT(jweHeader, createJwtClaimsSet());
when(token.getJwt()).thenReturn(encryptedJWT);
}
private SignedJWT mockSignedJWTAuthAttempt() {
return mockSignedJWTAuthAttempt(createJwtClaimsSet());
}
private SignedJWT mockSignedJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) {
SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.RS256, jwtClaimsSet);
when(token.getJwt()).thenReturn(signedJWT);
when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.PRIVATE_KEY);
when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.RS256);
return signedJWT;
}
private Throwable authenticateAndReturnThrownException() {
try {
jwtBearerAuthenticationProvider.authenticate(token);
} catch (Throwable throwable) {
return throwable;
}
throw new AssertionError("No exception thrown when expected");
}
private SignedJWT createSignedJWT() {
return createSignedJWT(JWSAlgorithm.RS256);
}
private SignedJWT createSignedJWT(JWSAlgorithm jwsAlgorithm) {
JWSHeader jwsHeader = new JWSHeader.Builder(jwsAlgorithm).build();
JWTClaimsSet claims = createJwtClaimsSet();
return new SignedJWT(jwsHeader, claims);
}
private SignedJWT createSignedJWT(JWSAlgorithm jwsAlgorithm, JWTClaimsSet jwtClaimsSet) {
JWSHeader jwsHeader = new JWSHeader.Builder(jwsAlgorithm).build();
return new SignedJWT(jwsHeader, jwtClaimsSet);
}
private JWTClaimsSet createJwtClaimsSet() {
return new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.expirationTime(new Date())
.audience("http://issuer.com/")
.build();
}
}