From c38b9d7a42c03b76cc3331d58a9298b47efeb547 Mon Sep 17 00:00:00 2001 From: Tomasz Borowiec Date: Mon, 5 Feb 2018 14:19:36 +0100 Subject: [PATCH 1/2] added PlainJWT and EncryptedJWT support + tests --- .../JWTBearerAuthenticationProvider.java | 14 +- .../TestJWTBearerAuthenticationProvider.java | 473 ++++++++++++++++++ 2 files changed, 484 insertions(+), 3 deletions(-) create mode 100644 openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java index a5d7d3f34..928e12076 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java @@ -46,6 +46,7 @@ import org.springframework.security.oauth2.common.exceptions.InvalidClientExcept import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.SignedJWT; /** @@ -91,15 +92,20 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { JWT jwt = jwtAuth.getJwt(); JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet(); - // check the signature with nimbus - if (jwt instanceof SignedJWT) { + if (jwt instanceof PlainJWT) { + if (!AuthMethod.NONE.equals(client.getTokenEndpointAuthMethod())) { + throw new AuthenticationServiceException("Client does not support this authentication method."); + } + } else if (jwt instanceof SignedJWT) { + // check the signature with nimbus SignedJWT jws = (SignedJWT)jwt; JWSAlgorithm alg = jws.getHeader().getAlgorithm(); if (client.getTokenEndpointAuthSigningAlg() != null && !client.getTokenEndpointAuthSigningAlg().equals(alg)) { - throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")"); + throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg() + + ") does not match token's actual algorithm (" + alg.getName() + ")"); } if (client.getTokenEndpointAuthMethod() == null || @@ -142,6 +148,8 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { } else { throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg); } + } else { + throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName()); } // check the issuer diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java new file mode 100644 index 000000000..1a2eb0845 --- /dev/null +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java @@ -0,0 +1,473 @@ +package org.mitre.openid.connect.assertion; + +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mitre.jwt.signer.service.JWTSigningAndValidationService; +import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; +import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod; +import org.mitre.oauth2.service.ClientDetailsEntityService; +import org.mitre.openid.connect.config.ConfigurationPropertiesBean; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.oauth2.common.exceptions.InvalidClientException; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.nimbusds.jose.EncryptionMethod; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWEHeader; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jwt.EncryptedJWT; +import com.nimbusds.jwt.JWT; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.PlainJWT; +import com.nimbusds.jwt.SignedJWT; + +@RunWith(MockitoJUnitRunner.class) +public class TestJWTBearerAuthenticationProvider { + + private static final String CLIENT_ID = "client"; + private static final String SUBJECT = "subject"; + + @Mock + private ClientKeyCacheService validators; + @Mock + private ClientDetailsEntityService clientService; + @Mock + private ConfigurationPropertiesBean config; + + @InjectMocks + private JWTBearerAuthenticationProvider jwtBearerAuthenticationProvider; + + @Mock + private JWTBearerAssertionAuthenticationToken token; + @Mock + private ClientDetailsEntity client; + @Mock + private JWTSigningAndValidationService validator; + + private GrantedAuthority authority1 = new SimpleGrantedAuthority("1"); + private GrantedAuthority authority2 = new SimpleGrantedAuthority("2"); + private GrantedAuthority authority3 = new SimpleGrantedAuthority("3"); + + @Before + public void setup() { + when(clientService.loadClientByClientId(CLIENT_ID)).thenReturn(client); + + when(token.getName()).thenReturn(CLIENT_ID); + + when(client.getClientId()).thenReturn(CLIENT_ID); + when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.NONE); + when(client.getAuthorities()).thenReturn(ImmutableSet.of(authority1, authority2, authority3)); + + when(validators.getValidator(client, JWSAlgorithm.RS256)).thenReturn(validator); + when(validator.validateSignature(any(SignedJWT.class))).thenReturn(true); + + when(config.getIssuer()).thenReturn("http://issuer.com/"); + } + + @Test + public void should_not_support_UsernamePasswordAuthenticationToken() { + assertThat(jwtBearerAuthenticationProvider.supports(UsernamePasswordAuthenticationToken.class), is(false)); + } + + @Test + public void should_support_JWTBearerAssertionAuthenticationToken() { + assertThat(jwtBearerAuthenticationProvider.supports(JWTBearerAssertionAuthenticationToken.class), is(true)); + } + + @Test + public void should_throw_UsernameNotFoundException_when_clientService_throws_InvalidClientException() { + when(clientService.loadClientByClientId(CLIENT_ID)).thenThrow(new InvalidClientException("invalid client")); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(UsernameNotFoundException.class)); + assertThat(thrown.getMessage(), is("Could not find client: " + CLIENT_ID)); + } + + @Test + public void should_throw_AuthenticationServiceException_for_PlainJWT_when_AuthMethod_is_different_than_NONE() { + mockPlainJWTAuthAttempt(); + List unsupportedAuthMethods = Arrays.asList( + null, AuthMethod.PRIVATE_KEY, AuthMethod.PRIVATE_KEY, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_JWT, AuthMethod.SECRET_POST + ); + + for (AuthMethod authMethod : unsupportedAuthMethods) { + when(client.getTokenEndpointAuthMethod()).thenReturn(authMethod); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Client does not support this authentication method.")); + } + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_signing_algorithms_do_not_match() { + when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.RS256); + SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.ES384); + when(token.getJwt()).thenReturn(signedJWT); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Client's registered token endpoint signing algorithm (RS256) does not match token's actual algorithm (ES384)")); + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_unsupported_authentication_method_for_SignedJWT() { + List unsupportedAuthMethods = + Arrays.asList(null, AuthMethod.NONE, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_POST); + + for (AuthMethod unsupportedAuthMethod : unsupportedAuthMethods) { + SignedJWT signedJWT = createSignedJWT(); + when(token.getJwt()).thenReturn(signedJWT); + when(client.getTokenEndpointAuthMethod()).thenReturn(unsupportedAuthMethod); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Client does not support this authentication method.")); + } + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_algorithm_for_PRIVATE_KEY_auth_method() { + List invalidAlgorithms = Arrays.asList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512); + + for (JWSAlgorithm algorithm : invalidAlgorithms) { + SignedJWT signedJWT = createSignedJWT(algorithm); + when(token.getJwt()).thenReturn(signedJWT); + when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.PRIVATE_KEY); + when(client.getTokenEndpointAuthSigningAlg()).thenReturn(algorithm); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for method")); + } + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_algorithm_for_SECRET_JWT_auth_method() { + List invalidAlgorithms = Arrays.asList( + JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512, + JWSAlgorithm.ES256, JWSAlgorithm.ES384, JWSAlgorithm.ES512, + JWSAlgorithm.PS256, JWSAlgorithm.PS384, JWSAlgorithm.PS512); + + for (JWSAlgorithm algorithm : invalidAlgorithms) { + SignedJWT signedJWT = createSignedJWT(algorithm); + when(token.getJwt()).thenReturn(signedJWT); + when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.SECRET_JWT); + when(client.getTokenEndpointAuthSigningAlg()).thenReturn(algorithm); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for method")); + } + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_in_heart_mode_and_auth_method_is_not_PRIVATE_KEY() { + SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.HS256); + when(token.getJwt()).thenReturn(signedJWT); + when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.HS256); + when(config.isHeartMode()).thenReturn(true); + when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.SECRET_JWT); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("[HEART mode] Invalid authentication method")); + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_null_validator() { + mockSignedJWTAuthAttempt(); + when(validators.getValidator(any(ClientDetailsEntity.class), any(JWSAlgorithm.class))).thenReturn(null); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Unable to create signature validator for client")); + } + + @Test + public void should_throw_AuthenticationServiceException_for_SignedJWT_when_invalid_signature() { + SignedJWT signedJWT = mockSignedJWTAuthAttempt(); + when(validator.validateSignature(signedJWT)).thenReturn(false); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication.")); + } + + @Test + public void should_throw_AuthenticationServiceException_for_EncryptedJWT() { + EncryptedJWT encryptedJWT = createEncryptedJWT(); + when(token.getJwt()).thenReturn(encryptedJWT); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName())); + } + + @Test + public void should_throw_AuthenticationServiceException_when_null_issuer() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Assertion Token Issuer is null")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_not_matching_issuer() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Issuers do not match")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_null_expiration_time() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Assertion Token does not have required expiration claim")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_expired_jwt() { + Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500)); + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Assertion Token is expired")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() { + Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Assertion Token not valid until")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() { + Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Assertion Token was issued in the future")); + } + + @Test + public void should_throw_AuthenticationServiceException_when_unmatching_audience() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build(); + mockPlainJWTAuthAttempt(jwtClaimsSet); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), startsWith("Audience does not match")); + } + + @Test + public void should_return_valid_token_for_PlainJWT_when_audience_contains_token_endpoint() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() + .issuer(CLIENT_ID) + .subject(SUBJECT) + .expirationTime(new Date()) + .audience(ImmutableList.of("http://issuer.com/token", "invalid")) + .build(); + PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet); + + Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); + + assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); + + JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; + assertThat(token.getName(), is(SUBJECT)); + assertThat(token.getJwt(), is(jwt)); + assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); + assertThat(token.getAuthorities().size(), is(4)); + } + + @Test + public void should_return_valid_token_for_SignedJWT_when_audience_contains_token_endpoint() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() + .issuer(CLIENT_ID) + .subject(SUBJECT) + .expirationTime(new Date()) + .audience(ImmutableList.of("http://issuer.com/token", "invalid")) + .build(); + JWT jwt = mockSignedJWTAuthAttempt(jwtClaimsSet); + + Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); + + assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); + + JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; + assertThat(token.getName(), is(SUBJECT)); + assertThat(token.getJwt(), is(jwt)); + assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); + assertThat(token.getAuthorities().size(), is(4)); + } + + @Test + public void should_return_valid_token_for_PlainJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() + .issuer(CLIENT_ID) + .subject(SUBJECT) + .expirationTime(new Date()) + .audience(ImmutableList.of("http://issuer.com/token")) + .build(); + PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet); + when(config.getIssuer()).thenReturn("http://issuer.com/"); + + Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); + + assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); + + JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; + assertThat(token.getName(), is(SUBJECT)); + assertThat(token.getJwt(), is(jwt)); + assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); + assertThat(token.getAuthorities().size(), is(4)); + } + + @Test + public void should_return_valid_token_for_SignedJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { + JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() + .issuer(CLIENT_ID) + .subject(SUBJECT) + .expirationTime(new Date()) + .audience(ImmutableList.of("http://issuer.com/token")) + .build(); + JWT jwt = mockSignedJWTAuthAttempt(jwtClaimsSet); + when(config.getIssuer()).thenReturn("http://issuer.com/"); + + Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); + + assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); + + JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; + assertThat(token.getName(), is(SUBJECT)); + assertThat(token.getJwt(), is(jwt)); + assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); + assertThat(token.getAuthorities().size(), is(4)); + } + + private PlainJWT mockPlainJWTAuthAttempt() { + return mockPlainJWTAuthAttempt(createJwtClaimsSet()); + } + + private PlainJWT mockPlainJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) { + PlainJWT plainJWT = createPlainJWT(jwtClaimsSet); + when(token.getJwt()).thenReturn(plainJWT); + return plainJWT; + } + + private SignedJWT mockSignedJWTAuthAttempt() { + return mockSignedJWTAuthAttempt(createJwtClaimsSet()); + } + + private SignedJWT mockSignedJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) { + SignedJWT signedJWT = createSignedJWT(JWSAlgorithm.RS256, jwtClaimsSet); + when(token.getJwt()).thenReturn(signedJWT); + when(client.getTokenEndpointAuthMethod()).thenReturn(AuthMethod.PRIVATE_KEY); + when(client.getTokenEndpointAuthSigningAlg()).thenReturn(JWSAlgorithm.RS256); + return signedJWT; + } + + private Throwable authenticateAndReturnThrownException() { + try { + jwtBearerAuthenticationProvider.authenticate(token); + } catch (Throwable throwable) { + return throwable; + } + throw new AssertionError("No exception thrown when expected"); + } + + private PlainJWT createPlainJWT(JWTClaimsSet jwtClaimsSet) { + return new PlainJWT(jwtClaimsSet); + } + + private SignedJWT createSignedJWT() { + return createSignedJWT(JWSAlgorithm.RS256); + } + + private SignedJWT createSignedJWT(JWSAlgorithm jwsAlgorithm) { + JWSHeader jwsHeader = new JWSHeader.Builder(jwsAlgorithm).build(); + JWTClaimsSet claims = createJwtClaimsSet(); + + return new SignedJWT(jwsHeader, claims); + } + + private SignedJWT createSignedJWT(JWSAlgorithm jwsAlgorithm, JWTClaimsSet jwtClaimsSet) { + JWSHeader jwsHeader = new JWSHeader.Builder(jwsAlgorithm).build(); + + return new SignedJWT(jwsHeader, jwtClaimsSet); + } + + private EncryptedJWT createEncryptedJWT() { + JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build(); + return new EncryptedJWT(jweHeader, createJwtClaimsSet()); + } + + private JWTClaimsSet createJwtClaimsSet() { + return new JWTClaimsSet.Builder() + .issuer(CLIENT_ID) + .expirationTime(new Date()) + .audience("http://issuer.com/") + .build(); + } + +} From 37fba622b9c9f621bc563f6da193748ad420d8ca Mon Sep 17 00:00:00 2001 From: Tomasz Borowiec Date: Wed, 7 Feb 2018 10:45:28 +0100 Subject: [PATCH 2/2] Throwing exception on all other JWT types than SignedJWT --- .../JWTBearerAuthenticationProvider.java | 97 +++++++------- .../TestJWTBearerAuthenticationProvider.java | 121 +++++------------- 2 files changed, 77 insertions(+), 141 deletions(-) diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java index 928e12076..aa110ea2c 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java @@ -46,7 +46,6 @@ import org.springframework.security.oauth2.common.exceptions.InvalidClientExcept import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; -import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.SignedJWT; /** @@ -92,64 +91,60 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { JWT jwt = jwtAuth.getJwt(); JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet(); - if (jwt instanceof PlainJWT) { - if (!AuthMethod.NONE.equals(client.getTokenEndpointAuthMethod())) { - throw new AuthenticationServiceException("Client does not support this authentication method."); - } - } else if (jwt instanceof SignedJWT) { - // check the signature with nimbus - SignedJWT jws = (SignedJWT)jwt; + if (!(jwt instanceof SignedJWT)) { + throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName()); + } - JWSAlgorithm alg = jws.getHeader().getAlgorithm(); + // check the signature with nimbus + SignedJWT jws = (SignedJWT) jwt; - if (client.getTokenEndpointAuthSigningAlg() != null && - !client.getTokenEndpointAuthSigningAlg().equals(alg)) { - throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg() - + ") does not match token's actual algorithm (" + alg.getName() + ")"); + JWSAlgorithm alg = jws.getHeader().getAlgorithm(); + + if (client.getTokenEndpointAuthSigningAlg() != null && + !client.getTokenEndpointAuthSigningAlg().equals(alg)) { + throw new AuthenticationServiceException("Client's registered token endpoint signing algorithm (" + client.getTokenEndpointAuthSigningAlg() + + ") does not match token's actual algorithm (" + alg.getName() + ")"); + } + + if (client.getTokenEndpointAuthMethod() == null || + client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) || + client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) || + client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) { + + // this client doesn't support this type of authentication + throw new AuthenticationServiceException("Client does not support this authentication method."); + + } else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) && + (alg.equals(JWSAlgorithm.RS256) + || alg.equals(JWSAlgorithm.RS384) + || alg.equals(JWSAlgorithm.RS512) + || alg.equals(JWSAlgorithm.ES256) + || alg.equals(JWSAlgorithm.ES384) + || alg.equals(JWSAlgorithm.ES512) + || alg.equals(JWSAlgorithm.PS256) + || alg.equals(JWSAlgorithm.PS384) + || alg.equals(JWSAlgorithm.PS512))) + || (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) && + (alg.equals(JWSAlgorithm.HS256) + || alg.equals(JWSAlgorithm.HS384) + || alg.equals(JWSAlgorithm.HS512)))) { + + // double-check the method is asymmetrical if we're in HEART mode + if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) { + throw new AuthenticationServiceException("[HEART mode] Invalid authentication method"); } - if (client.getTokenEndpointAuthMethod() == null || - client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE) || - client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) || - client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) { + JWTSigningAndValidationService validator = validators.getValidator(client, alg); - // this client doesn't support this type of authentication - throw new AuthenticationServiceException("Client does not support this authentication method."); + if (validator == null) { + throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg); + } - } else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) && - (alg.equals(JWSAlgorithm.RS256) - || alg.equals(JWSAlgorithm.RS384) - || alg.equals(JWSAlgorithm.RS512) - || alg.equals(JWSAlgorithm.ES256) - || alg.equals(JWSAlgorithm.ES384) - || alg.equals(JWSAlgorithm.ES512) - || alg.equals(JWSAlgorithm.PS256) - || alg.equals(JWSAlgorithm.PS384) - || alg.equals(JWSAlgorithm.PS512))) - || (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) && - (alg.equals(JWSAlgorithm.HS256) - || alg.equals(JWSAlgorithm.HS384) - || alg.equals(JWSAlgorithm.HS512)))) { - - // double-check the method is asymmetrical if we're in HEART mode - if (config.isHeartMode() && !client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) { - throw new AuthenticationServiceException("[HEART mode] Invalid authentication method"); - } - - JWTSigningAndValidationService validator = validators.getValidator(client, alg); - - if (validator == null) { - throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg); - } - - if (!validator.validateSignature(jws)) { - throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication."); - } - } else { - throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg); + if (!validator.validateSignature(jws)) { + throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication."); } } else { - throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName()); + throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg); } // check the issuer diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java index 1a2eb0845..fde99f499 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/assertion/TestJWTBearerAuthenticationProvider.java @@ -2,8 +2,8 @@ package org.mitre.openid.connect.assertion; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.startsWith; -import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Mockito.when; @@ -110,20 +110,23 @@ public class TestJWTBearerAuthenticationProvider { } @Test - public void should_throw_AuthenticationServiceException_for_PlainJWT_when_AuthMethod_is_different_than_NONE() { + public void should_throw_AuthenticationServiceException_for_PlainJWT() { mockPlainJWTAuthAttempt(); - List unsupportedAuthMethods = Arrays.asList( - null, AuthMethod.PRIVATE_KEY, AuthMethod.PRIVATE_KEY, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_JWT, AuthMethod.SECRET_POST - ); - for (AuthMethod authMethod : unsupportedAuthMethods) { - when(client.getTokenEndpointAuthMethod()).thenReturn(authMethod); + Throwable thrown = authenticateAndReturnThrownException(); - Throwable thrown = authenticateAndReturnThrownException(); + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Unsupported JWT type: " + PlainJWT.class.getName())); + } - assertThat(thrown, instanceOf(AuthenticationServiceException.class)); - assertThat(thrown.getMessage(), is("Client does not support this authentication method.")); - } + @Test + public void should_throw_AuthenticationServiceException_for_EncryptedJWT() { + mockEncryptedJWTAuthAttempt(); + + Throwable thrown = authenticateAndReturnThrownException(); + + assertThat(thrown, instanceOf(AuthenticationServiceException.class)); + assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName())); } @Test @@ -228,21 +231,10 @@ public class TestJWTBearerAuthenticationProvider { assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication.")); } - @Test - public void should_throw_AuthenticationServiceException_for_EncryptedJWT() { - EncryptedJWT encryptedJWT = createEncryptedJWT(); - when(token.getJwt()).thenReturn(encryptedJWT); - - Throwable thrown = authenticateAndReturnThrownException(); - - assertThat(thrown, instanceOf(AuthenticationServiceException.class)); - assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName())); - } - @Test public void should_throw_AuthenticationServiceException_when_null_issuer() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -253,7 +245,7 @@ public class TestJWTBearerAuthenticationProvider { @Test public void should_throw_AuthenticationServiceException_when_not_matching_issuer() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -264,7 +256,7 @@ public class TestJWTBearerAuthenticationProvider { @Test public void should_throw_AuthenticationServiceException_when_null_expiration_time() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -276,7 +268,7 @@ public class TestJWTBearerAuthenticationProvider { public void should_throw_AuthenticationServiceException_when_expired_jwt() { Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500)); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -288,7 +280,7 @@ public class TestJWTBearerAuthenticationProvider { public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() { Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -300,7 +292,7 @@ public class TestJWTBearerAuthenticationProvider { public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() { Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -311,7 +303,7 @@ public class TestJWTBearerAuthenticationProvider { @Test public void should_throw_AuthenticationServiceException_when_unmatching_audience() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build(); - mockPlainJWTAuthAttempt(jwtClaimsSet); + mockSignedJWTAuthAttempt(jwtClaimsSet); Throwable thrown = authenticateAndReturnThrownException(); @@ -320,28 +312,7 @@ public class TestJWTBearerAuthenticationProvider { } @Test - public void should_return_valid_token_for_PlainJWT_when_audience_contains_token_endpoint() { - JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() - .issuer(CLIENT_ID) - .subject(SUBJECT) - .expirationTime(new Date()) - .audience(ImmutableList.of("http://issuer.com/token", "invalid")) - .build(); - PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet); - - Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); - - assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); - - JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; - assertThat(token.getName(), is(SUBJECT)); - assertThat(token.getJwt(), is(jwt)); - assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); - assertThat(token.getAuthorities().size(), is(4)); - } - - @Test - public void should_return_valid_token_for_SignedJWT_when_audience_contains_token_endpoint() { + public void should_return_valid_token_when_audience_contains_token_endpoint() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() .issuer(CLIENT_ID) .subject(SUBJECT) @@ -362,29 +333,7 @@ public class TestJWTBearerAuthenticationProvider { } @Test - public void should_return_valid_token_for_PlainJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { - JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() - .issuer(CLIENT_ID) - .subject(SUBJECT) - .expirationTime(new Date()) - .audience(ImmutableList.of("http://issuer.com/token")) - .build(); - PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet); - when(config.getIssuer()).thenReturn("http://issuer.com/"); - - Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token); - - assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class)); - - JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication; - assertThat(token.getName(), is(SUBJECT)); - assertThat(token.getJwt(), is(jwt)); - assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3)); - assertThat(token.getAuthorities().size(), is(4)); - } - - @Test - public void should_return_valid_token_for_SignedJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { + public void should_return_valid_token_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() .issuer(CLIENT_ID) .subject(SUBJECT) @@ -405,14 +354,15 @@ public class TestJWTBearerAuthenticationProvider { assertThat(token.getAuthorities().size(), is(4)); } - private PlainJWT mockPlainJWTAuthAttempt() { - return mockPlainJWTAuthAttempt(createJwtClaimsSet()); - } - - private PlainJWT mockPlainJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) { - PlainJWT plainJWT = createPlainJWT(jwtClaimsSet); + private void mockPlainJWTAuthAttempt() { + PlainJWT plainJWT = new PlainJWT(createJwtClaimsSet()); when(token.getJwt()).thenReturn(plainJWT); - return plainJWT; + } + + private void mockEncryptedJWTAuthAttempt() { + JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build(); + EncryptedJWT encryptedJWT = new EncryptedJWT(jweHeader, createJwtClaimsSet()); + when(token.getJwt()).thenReturn(encryptedJWT); } private SignedJWT mockSignedJWTAuthAttempt() { @@ -436,10 +386,6 @@ public class TestJWTBearerAuthenticationProvider { throw new AssertionError("No exception thrown when expected"); } - private PlainJWT createPlainJWT(JWTClaimsSet jwtClaimsSet) { - return new PlainJWT(jwtClaimsSet); - } - private SignedJWT createSignedJWT() { return createSignedJWT(JWSAlgorithm.RS256); } @@ -457,11 +403,6 @@ public class TestJWTBearerAuthenticationProvider { return new SignedJWT(jwsHeader, jwtClaimsSet); } - private EncryptedJWT createEncryptedJWT() { - JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build(); - return new EncryptedJWT(jweHeader, createJwtClaimsSet()); - } - private JWTClaimsSet createJwtClaimsSet() { return new JWTClaimsSet.Builder() .issuer(CLIENT_ID)