diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/signer/service/impl/ClientKeyCacheService.java b/openid-connect-common/src/main/java/org/mitre/jwt/signer/service/impl/ClientKeyCacheService.java new file mode 100644 index 000000000..02381f24c --- /dev/null +++ b/openid-connect-common/src/main/java/org/mitre/jwt/signer/service/impl/ClientKeyCacheService.java @@ -0,0 +1,156 @@ +/******************************************************************************* + * Copyright 2015 The MITRE Corporation + * and the MIT Kerberos and Internet Trust Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +package org.mitre.jwt.signer.service.impl; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import org.mitre.jose.keystore.JWKSetKeyStore; +import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; +import org.mitre.jwt.encryption.service.impl.DefaultJWTEncryptionAndDecryptionService; +import org.mitre.jwt.signer.service.JWTSigningAndValidationService; +import org.mitre.oauth2.model.ClientDetailsEntity; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import com.google.common.base.Strings; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.util.concurrent.UncheckedExecutionException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWKSet; + +/** + * + * Takes in a client and returns the appropriate validator or encrypter for + * that client's registered key types. + * + * @author jricher + * + */ +@Service +public class ClientKeyCacheService { + + private static Logger logger = LoggerFactory.getLogger(ClientKeyCacheService.class); + + @Autowired + private JWKSetCacheService jwksUriCache = new JWKSetCacheService(); + + @Autowired + private SymmetricKeyJWTValidatorCacheService symmetricCache = new SymmetricKeyJWTValidatorCacheService(); + + // cache of validators for by-value JWKs + private LoadingCache<JWKSet, JWTSigningAndValidationService> jwksValidators; + + // cache of encryptors for by-value JWKs + private LoadingCache<JWKSet, JWTEncryptionAndDecryptionService> jwksEncrypters; + + public ClientKeyCacheService() { + this.jwksValidators = CacheBuilder.newBuilder() + .expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch + .maximumSize(100) + .build(new JWKSetVerifierBuilder()); + this.jwksEncrypters = CacheBuilder.newBuilder() + .expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch + .maximumSize(100) + .build(new JWKSetEncryptorBuilder()); + } + + + public JWTSigningAndValidationService getValidator(ClientDetailsEntity client, JWSAlgorithm alg) { + + try { + if (alg.equals(JWSAlgorithm.RS256) + || alg.equals(JWSAlgorithm.RS384) + || alg.equals(JWSAlgorithm.RS512) + || alg.equals(JWSAlgorithm.ES256) + || alg.equals(JWSAlgorithm.ES384) + || alg.equals(JWSAlgorithm.ES512) + || alg.equals(JWSAlgorithm.PS256) + || alg.equals(JWSAlgorithm.PS384) + || alg.equals(JWSAlgorithm.PS512)) { + + // asymmetric key + if (client.getJwks() != null) { + return jwksValidators.get(client.getJwks()); + } else if (!Strings.isNullOrEmpty(client.getJwksUri())) { + return jwksUriCache.getValidator(client.getJwksUri()); + } else { + return null; + } + + } else if (alg.equals(JWSAlgorithm.HS256) + || alg.equals(JWSAlgorithm.HS384) + || alg.equals(JWSAlgorithm.HS512)) { + + // symmetric key + + return symmetricCache.getSymmetricValidtor(client); + + } else { + + return null; + } + } catch (UncheckedExecutionException | ExecutionException e) { + logger.error("Problem loading client validator", e); + return null; + } + + } + + public JWTEncryptionAndDecryptionService getEncrypter(ClientDetailsEntity client) { + + try { + if (client.getJwks() != null) { + return jwksEncrypters.get(client.getJwks()); + } else if (!Strings.isNullOrEmpty(client.getJwksUri())) { + return jwksUriCache.getEncrypter(client.getJwksUri()); + } else { + return null; + } + } catch (UncheckedExecutionException | ExecutionException e) { + logger.error("Problem loading client encrypter", e); + return null; + } + + } + + + private class JWKSetEncryptorBuilder extends CacheLoader<JWKSet, JWTEncryptionAndDecryptionService> { + + @Override + public JWTEncryptionAndDecryptionService load(JWKSet key) throws Exception { + return new DefaultJWTEncryptionAndDecryptionService(new JWKSetKeyStore(key)); + } + + } + + private class JWKSetVerifierBuilder extends CacheLoader<JWKSet, JWTSigningAndValidationService> { + + @Override + public JWTSigningAndValidationService load(JWKSet key) throws Exception { + return new DefaultJWTSigningAndValidationService(new JWKSetKeyStore(key)); + } + + } + + +} diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java index 8e453b6de..7ef8ab718 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/assertion/JWTBearerAuthenticationProvider.java @@ -25,8 +25,7 @@ import java.util.HashSet; import java.util.Set; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; -import org.mitre.jwt.signer.service.impl.JWKSetCacheService; -import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService; +import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod; import org.mitre.oauth2.service.ClientDetailsEntityService; @@ -63,11 +62,7 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { // map of verifiers, load keys for clients @Autowired - private JWKSetCacheService validators; - - // map of symmetric verifiers for client secrets - @Autowired - private SymmetricKeyJWTValidatorCacheService symmetricCacheService; + private ClientKeyCacheService validators; // Allow for time sync issues by having a window of X seconds. private int timeSkewAllowance = 300; @@ -114,7 +109,7 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { // this client doesn't support this type of authentication throw new AuthenticationServiceException("Client does not support this authentication method."); - } else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) && + } else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) && (alg.equals(JWSAlgorithm.RS256) || alg.equals(JWSAlgorithm.RS384) || alg.equals(JWSAlgorithm.RS512) @@ -123,36 +118,23 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider { || alg.equals(JWSAlgorithm.ES512) || alg.equals(JWSAlgorithm.PS256) || alg.equals(JWSAlgorithm.PS384) - || alg.equals(JWSAlgorithm.PS512))) { - - // it's a known public/private key algorithm - - JWTSigningAndValidationService validator = validators.getValidator(client.getJwksUri()); - - if (validator == null) { - throw new AuthenticationServiceException("Unable to create signature validator for client's JWKS URI: " + client.getJwksUri()); - } - - if (!validator.validateSignature(jws)) { - throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication."); - } - } else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) && + || alg.equals(JWSAlgorithm.PS512))) + || (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) && (alg.equals(JWSAlgorithm.HS256) || alg.equals(JWSAlgorithm.HS384) - || alg.equals(JWSAlgorithm.HS512))) { + || alg.equals(JWSAlgorithm.HS512)))) { - // it's HMAC, we need to make a validator based on the client secret - - JWTSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client); + JWTSigningAndValidationService validator = validators.getValidator(client, alg); if (validator == null) { - throw new AuthenticationServiceException("Unable to create signature validator for client's secret: " + client.getClientSecret()); + throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg); } if (!validator.validateSignature(jws)) { throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication."); } - + } else { + throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg); } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java index 1e93765de..74e843d5c 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java @@ -25,8 +25,7 @@ import java.util.UUID; import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; -import org.mitre.jwt.signer.service.impl.JWKSetCacheService; -import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService; +import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.oauth2.service.ClientDetailsEntityService; import org.mitre.oauth2.service.SystemScopeService; @@ -79,10 +78,7 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { private ClientDetailsEntityService clientDetailsService; @Autowired - private JWKSetCacheService validators; - - @Autowired - private SymmetricKeyJWTValidatorCacheService symmetricCacheService; + private ClientKeyCacheService validators; @Autowired private SystemScopeService systemScopes; @@ -205,51 +201,15 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")"); } - if (alg.equals(JWSAlgorithm.RS256) - || alg.equals(JWSAlgorithm.RS384) - || alg.equals(JWSAlgorithm.RS512) - || alg.equals(JWSAlgorithm.ES256) - || alg.equals(JWSAlgorithm.ES384) - || alg.equals(JWSAlgorithm.ES512) - || alg.equals(JWSAlgorithm.PS256) - || alg.equals(JWSAlgorithm.PS384) - || alg.equals(JWSAlgorithm.PS512)) { - - // it's a public key, need to find the JWK URI and fetch the key - - if (Strings.isNullOrEmpty(client.getJwksUri()) && client.getJwks() == null) { - throw new InvalidClientException("Client must have a JWKS registered to use signed request objects with a public key."); - } - - // check JWT signature - JWTSigningAndValidationService validator = validators.getValidator(client.getJwksUri()); - - if (validator == null) { - throw new InvalidClientException("Unable to create signature validator for client's JWKS URI: " + client.getJwksUri()); - } - - if (!validator.validateSignature(signedJwt)) { - throw new InvalidClientException("Signature did not validate for presented JWT request object."); - } - } else if (alg.equals(JWSAlgorithm.HS256) - || alg.equals(JWSAlgorithm.HS384) - || alg.equals(JWSAlgorithm.HS512)) { - - // it's HMAC, we need to make a validator based on the client secret - - JWTSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client); - - if (validator == null) { - throw new InvalidClientException("Unable to create signature validator for client's secret: " + client.getClientSecret()); - } - - if (!validator.validateSignature(signedJwt)) { - throw new InvalidClientException("Signature did not validate for presented JWT request object."); - } - + JWTSigningAndValidationService validator = validators.getValidator(client, alg); + if (validator == null) { + throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg); } + if (!validator.validateSignature(signedJwt)) { + throw new InvalidClientException("Signature did not validate for presented JWT request object."); + } } else if (jwt instanceof PlainJWT) { PlainJWT plainJwt = (PlainJWT)jwt; diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultOIDCTokenService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultOIDCTokenService.java index e0d736f0b..fc6b8679a 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultOIDCTokenService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultOIDCTokenService.java @@ -23,7 +23,7 @@ import java.util.UUID; import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; -import org.mitre.jwt.signer.service.impl.JWKSetCacheService; +import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService; import org.mitre.oauth2.model.AuthenticationHolderEntity; import org.mitre.oauth2.model.ClientDetailsEntity; @@ -83,7 +83,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService { private ConfigurationPropertiesBean configBean; @Autowired - private JWKSetCacheService encrypters; + private ClientKeyCacheService encrypters; @Autowired private SymmetricKeyJWTValidatorCacheService symmetricCacheService; @@ -144,7 +144,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService { && client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE) && (!Strings.isNullOrEmpty(client.getJwksUri()) || client.getJwks() != null)) { - JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri()); + JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client); if (encrypter != null) { diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/view/UserInfoJWTView.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/view/UserInfoJWTView.java index 1ebfb5a02..fa822bfff 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/view/UserInfoJWTView.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/view/UserInfoJWTView.java @@ -32,7 +32,7 @@ import javax.servlet.http.HttpServletResponse; import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; -import org.mitre.jwt.signer.service.impl.JWKSetCacheService; +import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService; import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.openid.connect.config.ConfigurationPropertiesBean; @@ -80,7 +80,7 @@ public class UserInfoJWTView extends UserInfoView { private ConfigurationPropertiesBean config; @Autowired - private JWKSetCacheService encrypters; + private ClientKeyCacheService encrypters; @Autowired private SymmetricKeyJWTValidatorCacheService symmetricCacheService; @@ -115,7 +115,7 @@ public class UserInfoJWTView extends UserInfoView { // encrypt it to the client's key - JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri()); + JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client); if (encrypter != null) {