Generalized client key handling into a single cache service
parent
032d41e5ed
commit
642942b5cf
|
@ -0,0 +1,156 @@
|
|||
/*******************************************************************************
|
||||
* Copyright 2015 The MITRE Corporation
|
||||
* and the MIT Kerberos and Internet Trust Consortium
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*******************************************************************************/
|
||||
|
||||
package org.mitre.jwt.signer.service.impl;
|
||||
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import org.mitre.jose.keystore.JWKSetKeyStore;
|
||||
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
|
||||
import org.mitre.jwt.encryption.service.impl.DefaultJWTEncryptionAndDecryptionService;
|
||||
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.cache.CacheLoader;
|
||||
import com.google.common.cache.LoadingCache;
|
||||
import com.google.common.util.concurrent.UncheckedExecutionException;
|
||||
import com.nimbusds.jose.JWSAlgorithm;
|
||||
import com.nimbusds.jose.jwk.JWKSet;
|
||||
|
||||
/**
|
||||
*
|
||||
* Takes in a client and returns the appropriate validator or encrypter for
|
||||
* that client's registered key types.
|
||||
*
|
||||
* @author jricher
|
||||
*
|
||||
*/
|
||||
@Service
|
||||
public class ClientKeyCacheService {
|
||||
|
||||
private static Logger logger = LoggerFactory.getLogger(ClientKeyCacheService.class);
|
||||
|
||||
@Autowired
|
||||
private JWKSetCacheService jwksUriCache = new JWKSetCacheService();
|
||||
|
||||
@Autowired
|
||||
private SymmetricKeyJWTValidatorCacheService symmetricCache = new SymmetricKeyJWTValidatorCacheService();
|
||||
|
||||
// cache of validators for by-value JWKs
|
||||
private LoadingCache<JWKSet, JWTSigningAndValidationService> jwksValidators;
|
||||
|
||||
// cache of encryptors for by-value JWKs
|
||||
private LoadingCache<JWKSet, JWTEncryptionAndDecryptionService> jwksEncrypters;
|
||||
|
||||
public ClientKeyCacheService() {
|
||||
this.jwksValidators = CacheBuilder.newBuilder()
|
||||
.expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch
|
||||
.maximumSize(100)
|
||||
.build(new JWKSetVerifierBuilder());
|
||||
this.jwksEncrypters = CacheBuilder.newBuilder()
|
||||
.expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch
|
||||
.maximumSize(100)
|
||||
.build(new JWKSetEncryptorBuilder());
|
||||
}
|
||||
|
||||
|
||||
public JWTSigningAndValidationService getValidator(ClientDetailsEntity client, JWSAlgorithm alg) {
|
||||
|
||||
try {
|
||||
if (alg.equals(JWSAlgorithm.RS256)
|
||||
|| alg.equals(JWSAlgorithm.RS384)
|
||||
|| alg.equals(JWSAlgorithm.RS512)
|
||||
|| alg.equals(JWSAlgorithm.ES256)
|
||||
|| alg.equals(JWSAlgorithm.ES384)
|
||||
|| alg.equals(JWSAlgorithm.ES512)
|
||||
|| alg.equals(JWSAlgorithm.PS256)
|
||||
|| alg.equals(JWSAlgorithm.PS384)
|
||||
|| alg.equals(JWSAlgorithm.PS512)) {
|
||||
|
||||
// asymmetric key
|
||||
if (client.getJwks() != null) {
|
||||
return jwksValidators.get(client.getJwks());
|
||||
} else if (!Strings.isNullOrEmpty(client.getJwksUri())) {
|
||||
return jwksUriCache.getValidator(client.getJwksUri());
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
} else if (alg.equals(JWSAlgorithm.HS256)
|
||||
|| alg.equals(JWSAlgorithm.HS384)
|
||||
|| alg.equals(JWSAlgorithm.HS512)) {
|
||||
|
||||
// symmetric key
|
||||
|
||||
return symmetricCache.getSymmetricValidtor(client);
|
||||
|
||||
} else {
|
||||
|
||||
return null;
|
||||
}
|
||||
} catch (UncheckedExecutionException | ExecutionException e) {
|
||||
logger.error("Problem loading client validator", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public JWTEncryptionAndDecryptionService getEncrypter(ClientDetailsEntity client) {
|
||||
|
||||
try {
|
||||
if (client.getJwks() != null) {
|
||||
return jwksEncrypters.get(client.getJwks());
|
||||
} else if (!Strings.isNullOrEmpty(client.getJwksUri())) {
|
||||
return jwksUriCache.getEncrypter(client.getJwksUri());
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
} catch (UncheckedExecutionException | ExecutionException e) {
|
||||
logger.error("Problem loading client encrypter", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
private class JWKSetEncryptorBuilder extends CacheLoader<JWKSet, JWTEncryptionAndDecryptionService> {
|
||||
|
||||
@Override
|
||||
public JWTEncryptionAndDecryptionService load(JWKSet key) throws Exception {
|
||||
return new DefaultJWTEncryptionAndDecryptionService(new JWKSetKeyStore(key));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private class JWKSetVerifierBuilder extends CacheLoader<JWKSet, JWTSigningAndValidationService> {
|
||||
|
||||
@Override
|
||||
public JWTSigningAndValidationService load(JWKSet key) throws Exception {
|
||||
return new DefaultJWTSigningAndValidationService(new JWKSetKeyStore(key));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -25,8 +25,7 @@ import java.util.HashSet;
|
|||
import java.util.Set;
|
||||
|
||||
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
|
||||
import org.mitre.jwt.signer.service.impl.JWKSetCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod;
|
||||
import org.mitre.oauth2.service.ClientDetailsEntityService;
|
||||
|
@ -63,11 +62,7 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
|
|||
|
||||
// map of verifiers, load keys for clients
|
||||
@Autowired
|
||||
private JWKSetCacheService validators;
|
||||
|
||||
// map of symmetric verifiers for client secrets
|
||||
@Autowired
|
||||
private SymmetricKeyJWTValidatorCacheService symmetricCacheService;
|
||||
private ClientKeyCacheService validators;
|
||||
|
||||
// Allow for time sync issues by having a window of X seconds.
|
||||
private int timeSkewAllowance = 300;
|
||||
|
@ -114,7 +109,7 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
|
|||
// this client doesn't support this type of authentication
|
||||
throw new AuthenticationServiceException("Client does not support this authentication method.");
|
||||
|
||||
} else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
|
||||
} else if ((client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
|
||||
(alg.equals(JWSAlgorithm.RS256)
|
||||
|| alg.equals(JWSAlgorithm.RS384)
|
||||
|| alg.equals(JWSAlgorithm.RS512)
|
||||
|
@ -123,36 +118,23 @@ public class JWTBearerAuthenticationProvider implements AuthenticationProvider {
|
|||
|| alg.equals(JWSAlgorithm.ES512)
|
||||
|| alg.equals(JWSAlgorithm.PS256)
|
||||
|| alg.equals(JWSAlgorithm.PS384)
|
||||
|| alg.equals(JWSAlgorithm.PS512))) {
|
||||
|
||||
// it's a known public/private key algorithm
|
||||
|
||||
JWTSigningAndValidationService validator = validators.getValidator(client.getJwksUri());
|
||||
|
||||
if (validator == null) {
|
||||
throw new AuthenticationServiceException("Unable to create signature validator for client's JWKS URI: " + client.getJwksUri());
|
||||
}
|
||||
|
||||
if (!validator.validateSignature(jws)) {
|
||||
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
|
||||
}
|
||||
} else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
|
||||
|| alg.equals(JWSAlgorithm.PS512)))
|
||||
|| (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
|
||||
(alg.equals(JWSAlgorithm.HS256)
|
||||
|| alg.equals(JWSAlgorithm.HS384)
|
||||
|| alg.equals(JWSAlgorithm.HS512))) {
|
||||
|| alg.equals(JWSAlgorithm.HS512)))) {
|
||||
|
||||
// it's HMAC, we need to make a validator based on the client secret
|
||||
|
||||
JWTSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client);
|
||||
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
|
||||
|
||||
if (validator == null) {
|
||||
throw new AuthenticationServiceException("Unable to create signature validator for client's secret: " + client.getClientSecret());
|
||||
throw new AuthenticationServiceException("Unable to create signature validator for client " + client + " and algorithm " + alg);
|
||||
}
|
||||
|
||||
if (!validator.validateSignature(jws)) {
|
||||
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new AuthenticationServiceException("Unable to create signature validator for method " + client.getTokenEndpointAuthMethod() + " and algorithm " + alg);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,8 +25,7 @@ import java.util.UUID;
|
|||
|
||||
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
|
||||
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
|
||||
import org.mitre.jwt.signer.service.impl.JWKSetCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
import org.mitre.oauth2.service.ClientDetailsEntityService;
|
||||
import org.mitre.oauth2.service.SystemScopeService;
|
||||
|
@ -79,10 +78,7 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
private ClientDetailsEntityService clientDetailsService;
|
||||
|
||||
@Autowired
|
||||
private JWKSetCacheService validators;
|
||||
|
||||
@Autowired
|
||||
private SymmetricKeyJWTValidatorCacheService symmetricCacheService;
|
||||
private ClientKeyCacheService validators;
|
||||
|
||||
@Autowired
|
||||
private SystemScopeService systemScopes;
|
||||
|
@ -205,51 +201,15 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
|
||||
}
|
||||
|
||||
if (alg.equals(JWSAlgorithm.RS256)
|
||||
|| alg.equals(JWSAlgorithm.RS384)
|
||||
|| alg.equals(JWSAlgorithm.RS512)
|
||||
|| alg.equals(JWSAlgorithm.ES256)
|
||||
|| alg.equals(JWSAlgorithm.ES384)
|
||||
|| alg.equals(JWSAlgorithm.ES512)
|
||||
|| alg.equals(JWSAlgorithm.PS256)
|
||||
|| alg.equals(JWSAlgorithm.PS384)
|
||||
|| alg.equals(JWSAlgorithm.PS512)) {
|
||||
|
||||
// it's a public key, need to find the JWK URI and fetch the key
|
||||
|
||||
if (Strings.isNullOrEmpty(client.getJwksUri()) && client.getJwks() == null) {
|
||||
throw new InvalidClientException("Client must have a JWKS registered to use signed request objects with a public key.");
|
||||
}
|
||||
|
||||
// check JWT signature
|
||||
JWTSigningAndValidationService validator = validators.getValidator(client.getJwksUri());
|
||||
|
||||
if (validator == null) {
|
||||
throw new InvalidClientException("Unable to create signature validator for client's JWKS URI: " + client.getJwksUri());
|
||||
}
|
||||
|
||||
if (!validator.validateSignature(signedJwt)) {
|
||||
throw new InvalidClientException("Signature did not validate for presented JWT request object.");
|
||||
}
|
||||
} else if (alg.equals(JWSAlgorithm.HS256)
|
||||
|| alg.equals(JWSAlgorithm.HS384)
|
||||
|| alg.equals(JWSAlgorithm.HS512)) {
|
||||
|
||||
// it's HMAC, we need to make a validator based on the client secret
|
||||
|
||||
JWTSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client);
|
||||
|
||||
if (validator == null) {
|
||||
throw new InvalidClientException("Unable to create signature validator for client's secret: " + client.getClientSecret());
|
||||
}
|
||||
|
||||
if (!validator.validateSignature(signedJwt)) {
|
||||
throw new InvalidClientException("Signature did not validate for presented JWT request object.");
|
||||
}
|
||||
|
||||
JWTSigningAndValidationService validator = validators.getValidator(client, alg);
|
||||
|
||||
if (validator == null) {
|
||||
throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg);
|
||||
}
|
||||
|
||||
if (!validator.validateSignature(signedJwt)) {
|
||||
throw new InvalidClientException("Signature did not validate for presented JWT request object.");
|
||||
}
|
||||
|
||||
} else if (jwt instanceof PlainJWT) {
|
||||
PlainJWT plainJwt = (PlainJWT)jwt;
|
||||
|
|
|
@ -23,7 +23,7 @@ import java.util.UUID;
|
|||
|
||||
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
|
||||
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
|
||||
import org.mitre.jwt.signer.service.impl.JWKSetCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService;
|
||||
import org.mitre.oauth2.model.AuthenticationHolderEntity;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
|
@ -83,7 +83,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
|
|||
private ConfigurationPropertiesBean configBean;
|
||||
|
||||
@Autowired
|
||||
private JWKSetCacheService encrypters;
|
||||
private ClientKeyCacheService encrypters;
|
||||
|
||||
@Autowired
|
||||
private SymmetricKeyJWTValidatorCacheService symmetricCacheService;
|
||||
|
@ -144,7 +144,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
|
|||
&& client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE)
|
||||
&& (!Strings.isNullOrEmpty(client.getJwksUri()) || client.getJwks() != null)) {
|
||||
|
||||
JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri());
|
||||
JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client);
|
||||
|
||||
if (encrypter != null) {
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ import javax.servlet.http.HttpServletResponse;
|
|||
|
||||
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
|
||||
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
|
||||
import org.mitre.jwt.signer.service.impl.JWKSetCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
|
||||
import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
|
||||
|
@ -80,7 +80,7 @@ public class UserInfoJWTView extends UserInfoView {
|
|||
private ConfigurationPropertiesBean config;
|
||||
|
||||
@Autowired
|
||||
private JWKSetCacheService encrypters;
|
||||
private ClientKeyCacheService encrypters;
|
||||
|
||||
@Autowired
|
||||
private SymmetricKeyJWTValidatorCacheService symmetricCacheService;
|
||||
|
@ -115,7 +115,7 @@ public class UserInfoJWTView extends UserInfoView {
|
|||
|
||||
// encrypt it to the client's key
|
||||
|
||||
JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri());
|
||||
JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client);
|
||||
|
||||
if (encrypter != null) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue