Throwing exception on all other JWT types than SignedJWT

pull/1355/head
Tomasz Borowiec 2018-02-07 10:45:28 +01:00
parent c38b9d7a42
commit 37fba622b9
2 changed files with 77 additions and 141 deletions

View File

@ -46,7 +46,6 @@ import org.springframework.security.oauth2.common.exceptions.InvalidClientExcept
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.SignedJWT;
/** /**
@ -92,13 +91,12 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
JWT jwt = jwtAuth.getJwt(); JWT jwt = jwtAuth.getJwt();
JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet(); JWTClaimsSet jwtClaims = jwt.getJWTClaimsSet();
if (jwt instanceof PlainJWT) { if (!(jwt instanceof SignedJWT)) {
if (!AuthMethod.NONE.equals(client.getTokenEndpointAuthMethod())) { throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName());
throw new AuthenticationServiceException("Client does not support this authentication method.");
} }
} else if (jwt instanceof SignedJWT) {
// check the signature with nimbus // check the signature with nimbus
SignedJWT jws = (SignedJWT)jwt; SignedJWT jws = (SignedJWT) jwt;
JWSAlgorithm alg = jws.getHeader().getAlgorithm(); JWSAlgorithm alg = jws.getHeader().getAlgorithm();
@ -148,9 +146,6 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
} else { } else {
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg); throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
} }
} else {
throw new AuthenticationServiceException("Unsupported JWT type: " + jwt.getClass().getName());
}
// check the issuer // check the issuer
if (jwtClaims.getIssuer() == null) { if (jwtClaims.getIssuer() == null) {

View File

@ -2,8 +2,8 @@ package org.mitre.openid.connect.assertion;
import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -110,20 +110,23 @@ public class TestJWTBearerAuthenticationProvider {
} }
@Test @Test
public void should_throw_AuthenticationServiceException_for_PlainJWT_when_AuthMethod_is_different_than_NONE() { public void should_throw_AuthenticationServiceException_for_PlainJWT() {
mockPlainJWTAuthAttempt(); mockPlainJWTAuthAttempt();
List<AuthMethod> unsupportedAuthMethods = Arrays.asList(
null, AuthMethod.PRIVATE_KEY, AuthMethod.PRIVATE_KEY, AuthMethod.SECRET_BASIC, AuthMethod.SECRET_JWT, AuthMethod.SECRET_POST
);
for (AuthMethod authMethod : unsupportedAuthMethods) {
when(client.getTokenEndpointAuthMethod()).thenReturn(authMethod);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class)); assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Client does not support this authentication method.")); assertThat(thrown.getMessage(), is("Unsupported JWT type: " + PlainJWT.class.getName()));
} }
@Test
public void should_throw_AuthenticationServiceException_for_EncryptedJWT() {
mockEncryptedJWTAuthAttempt();
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName()));
} }
@Test @Test
@ -228,21 +231,10 @@ public class TestJWTBearerAuthenticationProvider {
assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication.")); assertThat(thrown.getMessage(), is("Signature did not validate for presented JWT authentication."));
} }
@Test
public void should_throw_AuthenticationServiceException_for_EncryptedJWT() {
EncryptedJWT encryptedJWT = createEncryptedJWT();
when(token.getJwt()).thenReturn(encryptedJWT);
Throwable thrown = authenticateAndReturnThrownException();
assertThat(thrown, instanceOf(AuthenticationServiceException.class));
assertThat(thrown.getMessage(), is("Unsupported JWT type: " + EncryptedJWT.class.getName()));
}
@Test @Test
public void should_throw_AuthenticationServiceException_when_null_issuer() { public void should_throw_AuthenticationServiceException_when_null_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(null).build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -253,7 +245,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test @Test
public void should_throw_AuthenticationServiceException_when_not_matching_issuer() { public void should_throw_AuthenticationServiceException_when_not_matching_issuer() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer("not matching").build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -264,7 +256,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test @Test
public void should_throw_AuthenticationServiceException_when_null_expiration_time() { public void should_throw_AuthenticationServiceException_when_null_expiration_time() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(null).build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -276,7 +268,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_expired_jwt() { public void should_throw_AuthenticationServiceException_when_expired_jwt() {
Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500)); Date expiredDate = new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(expiredDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -288,7 +280,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() { public void should_throw_AuthenticationServiceException_when_jwt_valid_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).notBeforeTime(futureDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -300,7 +292,7 @@ public class TestJWTBearerAuthenticationProvider {
public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() { public void should_throw_AuthenticationServiceException_when_jwt_issued_in_future() {
Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500)); Date futureDate = new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(500));
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(futureDate).issueTime(futureDate).build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -311,7 +303,7 @@ public class TestJWTBearerAuthenticationProvider {
@Test @Test
public void should_throw_AuthenticationServiceException_when_unmatching_audience() { public void should_throw_AuthenticationServiceException_when_unmatching_audience() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().issuer(CLIENT_ID).expirationTime(new Date()).audience("invalid").build();
mockPlainJWTAuthAttempt(jwtClaimsSet); mockSignedJWTAuthAttempt(jwtClaimsSet);
Throwable thrown = authenticateAndReturnThrownException(); Throwable thrown = authenticateAndReturnThrownException();
@ -320,28 +312,7 @@ public class TestJWTBearerAuthenticationProvider {
} }
@Test @Test
public void should_return_valid_token_for_PlainJWT_when_audience_contains_token_endpoint() { public void should_return_valid_token_when_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token", "invalid"))
.build();
PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet);
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
@Test
public void should_return_valid_token_for_SignedJWT_when_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID) .issuer(CLIENT_ID)
.subject(SUBJECT) .subject(SUBJECT)
@ -362,29 +333,7 @@ public class TestJWTBearerAuthenticationProvider {
} }
@Test @Test
public void should_return_valid_token_for_PlainJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() { public void should_return_valid_token_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID)
.subject(SUBJECT)
.expirationTime(new Date())
.audience(ImmutableList.of("http://issuer.com/token"))
.build();
PlainJWT jwt = mockPlainJWTAuthAttempt(jwtClaimsSet);
when(config.getIssuer()).thenReturn("http://issuer.com/");
Authentication authentication = jwtBearerAuthenticationProvider.authenticate(token);
assertThat(authentication, instanceOf(JWTBearerAssertionAuthenticationToken.class));
JWTBearerAssertionAuthenticationToken token = (JWTBearerAssertionAuthenticationToken) authentication;
assertThat(token.getName(), is(SUBJECT));
assertThat(token.getJwt(), is(jwt));
assertThat(token.getAuthorities(), hasItems(authority1, authority2, authority3));
assertThat(token.getAuthorities().size(), is(4));
}
@Test
public void should_return_valid_token_for_SignedJWT_when_issuer_does_not_end_with_slash_and_audience_contains_token_endpoint() {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder() JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
.issuer(CLIENT_ID) .issuer(CLIENT_ID)
.subject(SUBJECT) .subject(SUBJECT)
@ -405,14 +354,15 @@ public class TestJWTBearerAuthenticationProvider {
assertThat(token.getAuthorities().size(), is(4)); assertThat(token.getAuthorities().size(), is(4));
} }
private PlainJWT mockPlainJWTAuthAttempt() { private void mockPlainJWTAuthAttempt() {
return mockPlainJWTAuthAttempt(createJwtClaimsSet()); PlainJWT plainJWT = new PlainJWT(createJwtClaimsSet());
}
private PlainJWT mockPlainJWTAuthAttempt(JWTClaimsSet jwtClaimsSet) {
PlainJWT plainJWT = createPlainJWT(jwtClaimsSet);
when(token.getJwt()).thenReturn(plainJWT); when(token.getJwt()).thenReturn(plainJWT);
return plainJWT; }
private void mockEncryptedJWTAuthAttempt() {
JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build();
EncryptedJWT encryptedJWT = new EncryptedJWT(jweHeader, createJwtClaimsSet());
when(token.getJwt()).thenReturn(encryptedJWT);
} }
private SignedJWT mockSignedJWTAuthAttempt() { private SignedJWT mockSignedJWTAuthAttempt() {
@ -436,10 +386,6 @@ public class TestJWTBearerAuthenticationProvider {
throw new AssertionError("No exception thrown when expected"); throw new AssertionError("No exception thrown when expected");
} }
private PlainJWT createPlainJWT(JWTClaimsSet jwtClaimsSet) {
return new PlainJWT(jwtClaimsSet);
}
private SignedJWT createSignedJWT() { private SignedJWT createSignedJWT() {
return createSignedJWT(JWSAlgorithm.RS256); return createSignedJWT(JWSAlgorithm.RS256);
} }
@ -457,11 +403,6 @@ public class TestJWTBearerAuthenticationProvider {
return new SignedJWT(jwsHeader, jwtClaimsSet); return new SignedJWT(jwsHeader, jwtClaimsSet);
} }
private EncryptedJWT createEncryptedJWT() {
JWEHeader jweHeader = new JWEHeader.Builder(JWEAlgorithm.A128GCMKW, EncryptionMethod.A256GCM).build();
return new EncryptedJWT(jweHeader, createJwtClaimsSet());
}
private JWTClaimsSet createJwtClaimsSet() { private JWTClaimsSet createJwtClaimsSet() {
return new JWTClaimsSet.Builder() return new JWTClaimsSet.Builder()
.issuer(CLIENT_ID) .issuer(CLIENT_ID)