refactoring submodule common - part 1

pull/1580/head
Dominik František Bučík 2020-03-31 09:55:35 +02:00 committed by Dominik Frantisek Bucik
parent 7eba3c12fe
commit 6fe33c1ed7
No known key found for this signature in database
GPG Key ID: 25014C8DB2E7E62D
19 changed files with 395 additions and 637 deletions

View File

@ -1,7 +1,6 @@
language: java
dist: xenial
jdk:
- oraclejdk8
sudo: false
- openjdk8
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -133,8 +133,6 @@ public abstract class AbstractPageOperationTemplate<T> {
finalReport(operationsCompleted, exceptionsSwallowedCount, exceptionsSwallowedClasses);
}
/**
* method responsible for fetching
* a page of items.
@ -188,18 +186,10 @@ public abstract class AbstractPageOperationTemplate<T> {
this.swallowExceptions = swallowExceptions;
}
/**
* @return the operationName
*/
public String getOperationName() {
return operationName;
}
/**
* @param operationName the operationName to set
*/
public void setOperationName(String operationName) {
this.operationName = operationName;
}

View File

@ -47,4 +47,5 @@ public class DefaultPageCriteria implements PageCriteria {
public int getPageSize() {
return pageSize;
}
}

View File

@ -23,6 +23,8 @@ package org.mitre.data;
*/
public interface PageCriteria {
public int getPageNumber();
public int getPageSize();
int getPageNumber();
int getPageSize();
}

View File

@ -32,13 +32,9 @@ import com.google.common.base.Strings;
* Provides utility methods for normalizing and parsing URIs for use with Webfinger Discovery.
*
* @author wkim
*
*/
public class WebfingerURLNormalizer {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(WebfingerURLNormalizer.class);
// pattern used to parse user input; we can't use the built-in java URI parser
@ -55,14 +51,7 @@ public class WebfingerURLNormalizer {
"$"
);
/**
* Private constructor to prevent instantiation.
*/
private WebfingerURLNormalizer() {
// intentionally blank
}
private WebfingerURLNormalizer() { }
/**
* Normalize the resource string as per OIDC Discovery.
@ -73,11 +62,10 @@ public class WebfingerURLNormalizer {
// try to parse the URI
// NOTE: we can't use the Java built-in URI class because it doesn't split the parts appropriately
if (Strings.isNullOrEmpty(identifier)) {
if (StringUtils.isEmpty(identifier)) {
logger.warn("Can't normalize null or empty URI: " + identifier);
return null; // nothing we can do
return null;
} else {
//UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(identifier);
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
@ -87,15 +75,14 @@ public class WebfingerURLNormalizer {
builder.userInfo(m.group(6));
builder.host(m.group(8));
String port = m.group(10);
if (!Strings.isNullOrEmpty(port)) {
if (!StringUtils.isEmpty(port)) {
builder.port(Integer.parseInt(port));
}
builder.path(m.group(11));
builder.query(m.group(13));
builder.fragment(m.group(15)); // we throw away the hash, but this is the group it would be if we kept it
} else {
// doesn't match the pattern, throw it out
logger.warn("Parser couldn't match input: " + identifier);
logger.warn("Parser couldn't match input: {}", identifier);
return null;
}
@ -110,7 +97,6 @@ public class WebfingerURLNormalizer {
// scheme empty, userinfo is not empty, path/query/port are empty
// set to "acct" (rule 2)
builder.scheme("acct");
} else {
// scheme is empty, but rule 2 doesn't apply
// set scheme to "https" (rule 3)
@ -123,39 +109,34 @@ public class WebfingerURLNormalizer {
return builder.build();
}
}
public static String serializeURL(UriComponents uri) {
if (uri.getScheme() != null &&
(uri.getScheme().equals("acct") ||
uri.getScheme().equals("mailto") ||
uri.getScheme().equals("tel") ||
uri.getScheme().equals("device")
)) {
String scheme = uri.getScheme();
if (scheme != null &&
(scheme.equals("acct") || scheme.equals("mailto") || scheme.equals("tel") || scheme.equals("device"))) {
// serializer copied from HierarchicalUriComponents but with "//" removed
StringBuilder uriBuilder = new StringBuilder();
if (uri.getScheme() != null) {
uriBuilder.append(uri.getScheme());
uriBuilder.append(':');
}
uriBuilder.append(scheme);
uriBuilder.append(':');
if (uri.getUserInfo() != null || uri.getHost() != null) {
if (uri.getUserInfo() != null) {
uriBuilder.append(uri.getUserInfo());
String userInfo = uri.getUserInfo();
String host = uri.getHost();
if (userInfo != null || host != null) {
if (userInfo != null) {
uriBuilder.append(userInfo);
uriBuilder.append('@');
}
if (uri.getHost() != null) {
uriBuilder.append(uri.getHost());
if (host != null) {
uriBuilder.append(host);
}
if (uri.getPort() != -1) {
int port = uri.getPort();
if (port != -1) {
uriBuilder.append(':');
uriBuilder.append(uri.getPort());
uriBuilder.append(port);
}
}
@ -173,17 +154,16 @@ public class WebfingerURLNormalizer {
uriBuilder.append(query);
}
if (uri.getFragment() != null) {
String fragment = uri.getFragment();
if (fragment != null) {
uriBuilder.append('#');
uriBuilder.append(uri.getFragment());
uriBuilder.append(fragment);
}
return uriBuilder.toString();
} else {
return uri.toUriString();
}
}
}

View File

@ -15,110 +15,87 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/**
*
*/
package org.mitre.jose.keystore;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.ParseException;
import java.util.List;
import org.springframework.core.io.Resource;
import com.google.common.base.Charsets;
import com.google.common.io.CharStreams;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import org.springframework.core.io.Resource;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.List;
import java.util.stream.Collectors;
/**
* @author jricher
*
*/
public class JWKSetKeyStore {
private JWKSet jwkSet;
private Resource location;
public JWKSetKeyStore() {
}
public JWKSetKeyStore() { }
public JWKSetKeyStore(JWKSet jwkSet) {
this.jwkSet = jwkSet;
this.setJwkSet(jwkSet);
initializeJwkSet();
}
private void initializeJwkSet() {
if (jwkSet == null) {
if (location != null) {
if (location.exists() && location.isReadable()) {
try {
// read in the file from disk
String s = CharStreams.toString(new InputStreamReader(location.getInputStream(), Charsets.UTF_8));
// parse it into a jwkSet object
jwkSet = JWKSet.parse(s);
} catch (IOException e) {
throw new IllegalArgumentException("Key Set resource could not be read: " + location);
} catch (ParseException e) {
throw new IllegalArgumentException("Key Set resource could not be parsed: " + location); }
} else {
throw new IllegalArgumentException("Key Set resource could not be read: " + location);
}
} else {
throw new IllegalArgumentException("Key store must be initialized with at least one of a jwkSet or a location.");
}
}
}
/**
* @return the jwkSet
*/
public JWKSet getJwkSet() {
return jwkSet;
}
/**
* @param jwkSet the jwkSet to set
*/
public void setJwkSet(JWKSet jwkSet) {
if (jwkSet == null) {
throw new IllegalArgumentException("Argument cannot be null");
}
this.jwkSet = jwkSet;
initializeJwkSet();
}
/**
* @return the location
*/
public Resource getLocation() {
return location;
}
/**
* @param location the location to set
*/
public void setLocation(Resource location) {
this.location = location;
initializeJwkSet();
}
/**
* Get the list of keys in this keystore. This is a passthrough to the underlying JWK Set
*/
public List<JWK> getKeys() {
if (jwkSet == null) {
initializeJwkSet();
}
return jwkSet.getKeys();
}
private void initializeJwkSet() {
if (jwkSet != null) {
return;
} else if (location == null) {
return;
}
if (location.exists() && location.isReadable()) {
try (BufferedReader br = new BufferedReader(
new InputStreamReader(location.getInputStream(), StandardCharsets.UTF_8))
) {
String s = br.lines().collect(Collectors.joining());
jwkSet = JWKSet.parse(s);
} catch (IOException e) {
throw new IllegalArgumentException("Key Set resource could not be read: " + location);
} catch (ParseException e) {
throw new IllegalArgumentException("Key Set resource could not be parsed: " + location); }
} else {
throw new IllegalArgumentException("Key Set resource could not be read: " + location);
}
}
}

View File

@ -0,0 +1,35 @@
package org.mitre.jwt.assertion;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.text.ParseException;
public abstract class AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(AbstractAssertionValidator.class);
/**
* Extract issuer from claims present in JWT assertion.
* @param assertion JWT assertion object.
* @return Value of issuer from claims (can be null), NULL in case of error when parsing the assertion.
*/
protected String extractIssuer(JWT assertion) {
if (!(assertion instanceof SignedJWT)) {
return null;
}
JWTClaimsSet claims;
try {
claims = assertion.getJWTClaimsSet();
} catch (ParseException e) {
logger.debug("Invalid assertion claims");
return null;
}
return claims.getIssuer();
}
}

View File

@ -24,6 +24,6 @@ import com.nimbusds.jwt.JWT;
*/
public interface AssertionValidator {
public boolean isValid(JWT assertion);
boolean isValid(JWT assertion);
}

View File

@ -24,7 +24,6 @@ import com.nimbusds.jwt.JWT;
* Reject all assertions passed in.
*
* @author jricher
*
*/
public class NullAssertionValidator implements AssertionValidator {
@ -34,7 +33,6 @@ public class NullAssertionValidator implements AssertionValidator {
@Override
public boolean isValid(JWT assertion) {
return false;
}
}

View File

@ -16,8 +16,7 @@
package org.mitre.jwt.assertion.impl;
import java.text.ParseException;
import org.mitre.jwt.assertion.AbstractAssertionValidator;
import org.mitre.jwt.assertion.AssertionValidator;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
@ -26,62 +25,41 @@ import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.google.common.base.Strings;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.springframework.util.StringUtils;
/**
* Validates all assertions generated by this server
*
* @author jricher
*
*/
@Component("selfAssertionValidator")
public class SelfAssertionValidator implements AssertionValidator {
public class SelfAssertionValidator extends AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(SelfAssertionValidator.class);
@Autowired
private ConfigurationPropertiesBean config;
private final ConfigurationPropertiesBean config;
private final JWTSigningAndValidationService jwtService;
@Autowired
private JWTSigningAndValidationService jwtService;
public SelfAssertionValidator(ConfigurationPropertiesBean config, JWTSigningAndValidationService jwtService) {
this.config = config;
this.jwtService = jwtService;
}
@Override
public boolean isValid(JWT assertion) {
if (!(assertion instanceof SignedJWT)) {
// unsigned assertion
return false;
}
JWTClaimsSet claims;
try {
claims = assertion.getJWTClaimsSet();
} catch (ParseException e) {
logger.debug("Invalid assertion claims");
return false;
}
// make sure the issuer exists
if (Strings.isNullOrEmpty(claims.getIssuer())) {
String issuer = extractIssuer(assertion);
if (StringUtils.isEmpty(issuer)) {
logger.debug("No issuer for assertion, rejecting");
return false;
}
// make sure the issuer is us
if (!claims.getIssuer().equals(config.getIssuer())) {
} else if (!issuer.equals(config.getIssuer())) {
logger.debug("Issuer is not the same as this server, rejecting");
return false;
}
// validate the signature based on our public key
if (jwtService.validateSignature((SignedJWT) assertion)) {
return true;
} else {
return false;
}
return jwtService.validateSignature((SignedJWT) assertion);
}
}

View File

@ -16,28 +16,24 @@
package org.mitre.jwt.assertion.impl;
import java.text.ParseException;
import java.util.HashMap;
import java.util.Map;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.SignedJWT;
import org.mitre.jwt.assertion.AbstractAssertionValidator;
import org.mitre.jwt.assertion.AssertionValidator;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.jwt.signer.service.impl.JWKSetCacheService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.StringUtils;
import com.google.common.base.Strings;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.util.HashMap;
import java.util.Map;
/**
* Checks to see if the assertion was signed by a particular authority available from a whitelist
* Checks to see if the assertion has been signed by a particular authority available from a whitelist
* @author jricher
*
*/
public class WhitelistedIssuerAssertionValidator implements AssertionValidator {
public class WhitelistedIssuerAssertionValidator extends AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(WhitelistedIssuerAssertionValidator.class);
@ -45,60 +41,43 @@ public class WhitelistedIssuerAssertionValidator implements AssertionValidator {
* Map of issuer -> JWKSetUri
*/
private Map<String, String> whitelist = new HashMap<>();
private JWKSetCacheService jwkCache;
/**
* @return the whitelist
*/
public Map<String, String> getWhitelist() {
return whitelist;
}
/**
* @param whitelist the whitelist to set
*/
public void setWhitelist(Map<String, String> whitelist) {
this.whitelist = whitelist;
}
@Autowired
private JWKSetCacheService jwkCache;
public JWKSetCacheService getJwkCache() {
return jwkCache;
}
public void setJwkCache(JWKSetCacheService jwkCache) {
this.jwkCache = jwkCache;
}
@Override
public boolean isValid(JWT assertion) {
if (!(assertion instanceof SignedJWT)) {
// unsigned assertion
return false;
}
JWTClaimsSet claims;
try {
claims = assertion.getJWTClaimsSet();
} catch (ParseException e) {
logger.debug("Invalid assertion claims");
return false;
}
if (Strings.isNullOrEmpty(claims.getIssuer())) {
String issuer = extractIssuer(assertion);
if (StringUtils.isEmpty(issuer)) {
logger.debug("No issuer for assertion, rejecting");
return false;
}
if (!whitelist.containsKey(claims.getIssuer())) {
} else if (!whitelist.containsKey(issuer)) {
logger.debug("Issuer is not in whitelist, rejecting");
return false;
}
String jwksUri = whitelist.get(claims.getIssuer());
JWTSigningAndValidationService validator = jwkCache.getValidator(jwksUri);
if (validator.validateSignature((SignedJWT) assertion)) {
return true;
} else {
String jwksUri = whitelist.getOrDefault(issuer, null);
if (jwksUri == null) {
return false;
}
JWTSigningAndValidationService validator = jwkCache.getValidator(jwksUri);
return validator.validateSignature((SignedJWT) assertion);
}
}

View File

@ -27,7 +27,6 @@ import com.nimbusds.jose.jwk.JWK;
/**
* @author wkim
*
*/
public interface JWTEncryptionAndDecryptionService {
@ -37,7 +36,7 @@ public interface JWTEncryptionAndDecryptionService {
* Otherwise, if JWT claims are the payload, then use the JWEObject subclass EncryptedJWT instead.
* @param jwt
*/
public void encryptJwt(JWEObject jwt);
void encryptJwt(JWEObject jwt);
/**
* Decrypts the JWT in place with the default decrypter.
@ -45,24 +44,24 @@ public interface JWTEncryptionAndDecryptionService {
* Otherwise, if JWT claims are the payload, then use the JWEObject subclass EncryptedJWT instead.
* @param jwt
*/
public void decryptJwt(JWEObject jwt);
void decryptJwt(JWEObject jwt);
/**
* Get all public keys for this service, mapped by their Key ID
*/
public Map<String, JWK> getAllPublicKeys();
Map<String, JWK> getAllPublicKeys();
/**
* Get the list of all encryption algorithms supported by this service.
* @return
* @return List of supported encryption algorithms.
*/
public Collection<JWEAlgorithm> getAllEncryptionAlgsSupported();
Collection<JWEAlgorithm> getAllEncryptionAlgsSupported();
/**
* Get the list of all encryption methods supported by this service.
* @return
* @return List of supported encryption memthods.
*/
public Collection<EncryptionMethod> getAllEncryptionEncsSupported();
Collection<EncryptionMethod> getAllEncryptionEncsSupported();
/**
* TODO add functionality for encrypting and decrypting using a specified key id.

View File

@ -32,7 +32,6 @@ import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Strings;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
@ -50,43 +49,30 @@ import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import org.springframework.util.StringUtils;
/**
* @author wkim
*
*/
public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAndDecryptionService {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class);
// map of identifier to encrypter
private Map<String, JWEEncrypter> encrypters = new HashMap<>();
// map of identifier to decrypter
private Map<String, JWEDecrypter> decrypters = new HashMap<>();
private String defaultEncryptionKeyId;
private String defaultDecryptionKeyId;
private JWEAlgorithm defaultAlgorithm;
// map of identifier to key
private Map<String, JWK> keys = new HashMap<>();
/**
* Build this service based on the keys given. All public keys will be used to make encrypters,
* all private keys will be used to make decrypters.
*
* @param keys
* @throws NoSuchAlgorithmException
* @throws InvalidKeySpecException
* @throws JOSEException
* @param keys Map of keys
* @throws JOSEException Javascript Object Signing and Encryption (JOSE) exception.
*/
public DefaultJWTEncryptionAndDecryptionService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
public DefaultJWTEncryptionAndDecryptionService(Map<String, JWK> keys) throws JOSEException {
this.keys = keys;
buildEncryptersAndDecrypters();
}
@ -95,16 +81,12 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
* Build this service based on the given keystore. All keys must have a key
* id ({@code kid}) field in order to be used.
*
* @param keyStore
* @throws NoSuchAlgorithmException
* @throws InvalidKeySpecException
* @throws JOSEException
* @param keyStore JWK KeyStore
* @throws JOSEException Javascript Object Signing and Encryption (JOSE) exception.
*/
public DefaultJWTEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
// convert all keys in the keystore to a map based on key id
public DefaultJWTEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws JOSEException {
for (JWK key : keyStore.getKeys()) {
if (!Strings.isNullOrEmpty(key.getKeyID())) {
if (!StringUtils.isEmpty(key.getKeyID())) {
this.keys.put(key.getKeyID(), key);
} else {
throw new IllegalArgumentException("Tried to load a key from a keystore without a 'kid' field: " + key);
@ -112,25 +94,6 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
}
buildEncryptersAndDecrypters();
}
@PostConstruct
public void afterPropertiesSet() {
if (keys == null) {
throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
}
try {
buildEncryptersAndDecrypters();
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException("Encryption and decryption service could not find given algorithm.");
} catch (InvalidKeySpecException e) {
throw new IllegalArgumentException("Encryption and decryption service saw an invalid key specification.");
} catch (JOSEException e) {
throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object.");
}
}
public String getDefaultEncryptionKeyId() {
@ -171,9 +134,19 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
this.defaultAlgorithm = defaultAlgorithm;
}
/* (non-Javadoc)
* @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#encryptJwt(com.nimbusds.jwt.EncryptedJWT)
*/
@PostConstruct
public void afterPropertiesSet() {
if (keys == null) {
throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
}
try {
buildEncryptersAndDecrypters();
} catch (JOSEException e) {
throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object.");
}
}
@Override
public void encryptJwt(JWEObject jwt) {
if (getDefaultEncryptionKeyId() == null) {
@ -185,15 +158,10 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
try {
jwt.encrypt(encrypter);
} catch (JOSEException e) {
logger.error("Failed to encrypt JWT, error was: ", e);
}
}
/* (non-Javadoc)
* @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#decryptJwt(com.nimbusds.jwt.EncryptedJWT)
*/
@Override
public void decryptJwt(JWEObject jwt) {
if (getDefaultDecryptionKeyId() == null) {
@ -205,88 +173,8 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
try {
jwt.decrypt(decrypter);
} catch (JOSEException e) {
logger.error("Failed to decrypt JWT, error was: ", e);
}
}
/**
* Builds all the encrypters and decrypters for this service based on the key map.
* @throws
* @throws InvalidKeySpecException
* @throws NoSuchAlgorithmException
* @throws JOSEException
*/
private void buildEncryptersAndDecrypters() throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
String id = jwkEntry.getKey();
JWK jwk = jwkEntry.getValue();
if (jwk instanceof RSAKey) {
// build RSA encrypters and decrypters
RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk); // there should always at least be the public key
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt!
RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #" + jwk.getKeyID());
}
} else if (jwk instanceof ECKey) {
// build EC Encrypters and decrypters
ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt too
ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key # " + jwk.getKeyID());
}
} else if (jwk instanceof OctetSequenceKey) {
// build symmetric encrypters and decrypters
DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
decrypters.put(id, decrypter);
} else {
logger.warn("Unknown key type: " + jwk);
}
}
}
@Override
public Map<String, JWK> getAllPublicKeys() {
Map<String, JWK> pubKeys = new HashMap<>();
// pull out all public keys
for (String keyId : keys.keySet()) {
JWK key = keys.get(keyId);
JWK pub = key.toPublicJWK();
if (pub != null) {
pubKeys.put(keyId, pub);
}
}
return pubKeys;
}
@Override
@ -304,9 +192,6 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
return algs;
}
/* (non-Javadoc)
* @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#getAllEncryptionEncsSupported()
*/
@Override
public Collection<EncryptionMethod> getAllEncryptionEncsSupported() {
Set<EncryptionMethod> encs = new HashSet<>();
@ -322,5 +207,67 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
return encs;
}
@Override
public Map<String, JWK> getAllPublicKeys() {
Map<String, JWK> pubKeys = new HashMap<>();
for (String keyId : keys.keySet()) {
JWK key = keys.get(keyId);
JWK pub = key.toPublicJWK();
if (pub != null) {
pubKeys.put(keyId, pub);
}
}
return pubKeys;
}
/**
* Builds all the encrypters and decrypters for this service based on the key map.
* @throws
* @throws JOSEException
*/
private void buildEncryptersAndDecrypters() throws JOSEException {
for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
String id = jwkEntry.getKey();
JWK jwk = jwkEntry.getValue();
if (jwk instanceof RSAKey) {
RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt!
RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
} else if (jwk instanceof ECKey) {
ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt too
ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
} else if (jwk instanceof OctetSequenceKey) {
DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
decrypters.put(id, decrypter);
} else {
logger.warn("Unknown key type: {}", jwk);
}
}
}
}

View File

@ -30,7 +30,7 @@ public interface JWTSigningAndValidationService {
/**
* Get all public keys for this service, mapped by their Key ID
*/
public Map<String, JWK> getAllPublicKeys();
Map<String, JWK> getAllPublicKeys();
/**
* Checks the signature of the given JWT against all configured signers,
@ -41,7 +41,7 @@ public interface JWTSigningAndValidationService {
* @return true if the signature is valid, false if not
* @throws NoSuchAlgorithmException
*/
public boolean validateSignature(SignedJWT jwtString);
boolean validateSignature(SignedJWT jwtString);
/**
* Called to sign a jwt in place for a client that hasn't registered a preferred signing algorithm.
@ -51,19 +51,19 @@ public interface JWTSigningAndValidationService {
* @return the signed jwt
* @throws NoSuchAlgorithmException
*/
public void signJwt(SignedJWT jwt);
void signJwt(SignedJWT jwt);
/**
* Get the default signing algorithm for use when nothing else has been specified.
* @return
*/
public JWSAlgorithm getDefaultSigningAlgorithm();
JWSAlgorithm getDefaultSigningAlgorithm();
/**
* Get the list of all signing algorithms supported by this service.
* @return
*/
public Collection<JWSAlgorithm> getAllSigningAlgsSupported();
Collection<JWSAlgorithm> getAllSigningAlgsSupported();
/**
* Sign a jwt using the selected algorithm. The algorithm is selected using the String parameter values specified
@ -73,9 +73,9 @@ public interface JWTSigningAndValidationService {
* @param alg the name of the algorithm to use, as specified in JWS s.6
* @return the signed jwt
*/
public void signJwt(SignedJWT jwt, JWSAlgorithm alg);
void signJwt(SignedJWT jwt, JWSAlgorithm alg);
public String getDefaultSignerKeyId();
String getDefaultSignerKeyId();
/**
* TODO: method to sign a jwt using a specified algorithm and a key id

View File

@ -16,6 +16,10 @@
package org.mitre.jwt.signer.service.impl;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@ -36,58 +40,57 @@ import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.JWKSet;
import org.springframework.util.StringUtils;
/**
*
* Takes in a client and returns the appropriate validator or encrypter for
* that client's registered key types.
*
* @author jricher
*
*/
@Service
public class ClientKeyCacheService {
private static Logger logger = LoggerFactory.getLogger(ClientKeyCacheService.class);
@Autowired
private JWKSetCacheService jwksUriCache = new JWKSetCacheService();
@Autowired
private SymmetricKeyJWTValidatorCacheService symmetricCache = new SymmetricKeyJWTValidatorCacheService();
// cache of validators for by-value JWKs
private JWKSetCacheService jwksUriCache;
private SymmetricKeyJWTValidatorCacheService symmetricCache;
private LoadingCache<JWKSet, JWTSigningAndValidationService> jwksValidators;
// cache of encryptors for by-value JWKs
private LoadingCache<JWKSet, JWTEncryptionAndDecryptionService> jwksEncrypters;
public ClientKeyCacheService() {
@Autowired
public ClientKeyCacheService(JWKSetCacheService jwksUriCache, SymmetricKeyJWTValidatorCacheService symmetricCache) {
this.jwksValidators = CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch
.expireAfterWrite(1, TimeUnit.HOURS)
.maximumSize(100)
.build(new JWKSetVerifierBuilder());
this.jwksEncrypters = CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch
.expireAfterWrite(1, TimeUnit.HOURS)
.maximumSize(100)
.build(new JWKSetEncryptorBuilder());
if (jwksUriCache == null) {
this.jwksUriCache = new JWKSetCacheService();
} else {
this.jwksUriCache = jwksUriCache;
}
if (symmetricCache == null) {
this.symmetricCache = new SymmetricKeyJWTValidatorCacheService();
} else {
this.symmetricCache = symmetricCache;
}
}
public JWTSigningAndValidationService getValidator(ClientDetailsEntity client, JWSAlgorithm alg) {
Set<JWSAlgorithm> asymmetric = new HashSet<>(Arrays.asList(JWSAlgorithm.RS256, JWSAlgorithm.RS384,
JWSAlgorithm.RS512, JWSAlgorithm.ES256, JWSAlgorithm.ES384, JWSAlgorithm.ES512, JWSAlgorithm.PS256,
JWSAlgorithm.PS384, JWSAlgorithm.PS512));
Set<JWSAlgorithm> symmetric = new HashSet<>(Arrays.asList(JWSAlgorithm.HS256, JWSAlgorithm.HS384,
JWSAlgorithm.HS512));
try {
if (alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512)
|| alg.equals(JWSAlgorithm.ES256)
|| alg.equals(JWSAlgorithm.ES384)
|| alg.equals(JWSAlgorithm.ES512)
|| alg.equals(JWSAlgorithm.PS256)
|| alg.equals(JWSAlgorithm.PS384)
|| alg.equals(JWSAlgorithm.PS512)) {
// asymmetric key
if (asymmetric.contains(alg)) {
if (client.getJwks() != null) {
return jwksValidators.get(client.getJwks());
} else if (!Strings.isNullOrEmpty(client.getJwksUri())) {
@ -95,32 +98,22 @@ public class ClientKeyCacheService {
} else {
return null;
}
} else if (alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512)) {
// symmetric key
} else if (symmetric.contains(alg)) {
return symmetricCache.getSymmetricValidtor(client);
} else {
return null;
}
} catch (UncheckedExecutionException | ExecutionException e) {
logger.error("Problem loading client validator", e);
return null;
}
}
public JWTEncryptionAndDecryptionService getEncrypter(ClientDetailsEntity client) {
try {
if (client.getJwks() != null) {
return jwksEncrypters.get(client.getJwks());
} else if (!Strings.isNullOrEmpty(client.getJwksUri())) {
} else if (!StringUtils.isEmpty(client.getJwksUri())) {
return jwksUriCache.getEncrypter(client.getJwksUri());
} else {
return null;
@ -129,27 +122,20 @@ public class ClientKeyCacheService {
logger.error("Problem loading client encrypter", e);
return null;
}
}
private class JWKSetEncryptorBuilder extends CacheLoader<JWKSet, JWTEncryptionAndDecryptionService> {
private static class JWKSetEncryptorBuilder extends CacheLoader<JWKSet, JWTEncryptionAndDecryptionService> {
@Override
public JWTEncryptionAndDecryptionService load(JWKSet key) throws Exception {
return new DefaultJWTEncryptionAndDecryptionService(new JWKSetKeyStore(key));
}
}
private class JWKSetVerifierBuilder extends CacheLoader<JWKSet, JWTSigningAndValidationService> {
private static class JWKSetVerifierBuilder extends CacheLoader<JWKSet, JWTSigningAndValidationService> {
@Override
public JWTSigningAndValidationService load(JWKSet key) throws Exception {
return new DefaultJWTSigningAndValidationService(new JWKSetKeyStore(key));
}
}
}

View File

@ -17,23 +17,9 @@
*******************************************************************************/
package org.mitre.jwt.signer.service.impl;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Strings;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSProvider;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSASigner;
@ -47,40 +33,36 @@ import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.SignedJWT;
import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
public class DefaultJWTSigningAndValidationService implements JWTSigningAndValidationService {
// map of identifier to signer
private Map<String, JWSSigner> signers = new HashMap<>();
// map of identifier to verifier
private Map<String, JWSVerifier> verifiers = new HashMap<>();
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(DefaultJWTSigningAndValidationService.class);
private Map<String, JWSSigner> signers = new HashMap<>();
private Map<String, JWSVerifier> verifiers = new HashMap<>();
private String defaultSignerKeyId;
private JWSAlgorithm defaultAlgorithm;
// map of identifier to key
private Map<String, JWK> keys = new HashMap<>();
/**
* Build this service based on the keys given. All public keys will be used
* to make verifiers, all private keys will be used to make signers.
*
* @param keys
* A map of key identifier to key
*
* @throws InvalidKeySpecException
* If the keys in the JWKs are not valid
* @throws NoSuchAlgorithmException
* If there is no appropriate algorithm to tie the keys to.
* @param keys A map of key identifier to key.
*/
public DefaultJWTSigningAndValidationService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException {
public DefaultJWTSigningAndValidationService(Map<String, JWK> keys) {
this.keys = keys;
buildSignersAndVerifiers();
}
@ -89,23 +71,14 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
* Build this service based on the given keystore. All keys must have a key
* id ({@code kid}) field in order to be used.
*
* @param keyStore
* the keystore to load all keys from
*
* @throws InvalidKeySpecException
* If the keys in the JWKs are not valid
* @throws NoSuchAlgorithmException
* If there is no appropriate algorithm to tie the keys to.
* @param keyStore The keystore to load all keys from.
*/
public DefaultJWTSigningAndValidationService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException {
// convert all keys in the keystore to a map based on key id
public DefaultJWTSigningAndValidationService(JWKSetKeyStore keyStore) {
if (keyStore!= null && keyStore.getJwkSet() != null) {
for (JWK key : keyStore.getKeys()) {
if (!Strings.isNullOrEmpty(key.getKeyID())) {
// use the key ID that's built into the key itself
if (!StringUtils.isEmpty(key.getKeyID())) {
this.keys.put(key.getKeyID(), key);
} else {
// create a random key id
String fakeKid = UUID.randomUUID().toString();
this.keys.put(fakeKid, key);
}
@ -114,25 +87,15 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
buildSignersAndVerifiers();
}
/**
* @return the defaultSignerKeyId
*/
@Override
public String getDefaultSignerKeyId() {
return defaultSignerKeyId;
}
/**
* @param defaultSignerKeyId the defaultSignerKeyId to set
*/
public void setDefaultSignerKeyId(String defaultSignerId) {
this.defaultSignerKeyId = defaultSignerId;
}
/**
* @return
*/
@Override
public JWSAlgorithm getDefaultSigningAlgorithm() {
return defaultAlgorithm;
@ -150,65 +113,6 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
}
}
/**
* Build all of the signers and verifiers for this based on the key map.
* @throws InvalidKeySpecException If the keys in the JWKs are not valid
* @throws NoSuchAlgorithmException If there is no appropriate algorithm to tie the keys to.
*/
private void buildSignersAndVerifiers() throws NoSuchAlgorithmException, InvalidKeySpecException {
for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
String id = jwkEntry.getKey();
JWK jwk = jwkEntry.getValue();
try {
if (jwk instanceof RSAKey) {
// build RSA signers & verifiers
if (jwk.isPrivate()) { // only add the signer if there's a private key
RSASSASigner signer = new RSASSASigner((RSAKey) jwk);
signers.put(id, signer);
}
RSASSAVerifier verifier = new RSASSAVerifier((RSAKey) jwk);
verifiers.put(id, verifier);
} else if (jwk instanceof ECKey) {
// build EC signers & verifiers
if (jwk.isPrivate()) {
ECDSASigner signer = new ECDSASigner((ECKey) jwk);
signers.put(id, signer);
}
ECDSAVerifier verifier = new ECDSAVerifier((ECKey) jwk);
verifiers.put(id, verifier);
} else if (jwk instanceof OctetSequenceKey) {
// build HMAC signers & verifiers
if (jwk.isPrivate()) { // technically redundant check because all HMAC keys are private
MACSigner signer = new MACSigner((OctetSequenceKey) jwk);
signers.put(id, signer);
}
MACVerifier verifier = new MACVerifier((OctetSequenceKey) jwk);
verifiers.put(id, verifier);
} else {
logger.warn("Unknown key type: " + jwk);
}
} catch (JOSEException e) {
logger.warn("Exception loading signer/verifier", e);
}
}
if (defaultSignerKeyId == null && keys.size() == 1) {
// if there's only one key, it's the default
setDefaultSignerKeyId(keys.keySet().iterator().next());
}
}
/**
* Sign a jwt in place using the configured default signer.
*/
@ -223,15 +127,12 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
try {
jwt.sign(signer);
} catch (JOSEException e) {
logger.error("Failed to sign JWT, error was: ", e);
}
}
@Override
public void signJwt(SignedJWT jwt, JWSAlgorithm alg) {
JWSSigner signer = null;
for (JWSSigner s : signers.values()) {
@ -244,31 +145,27 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
if (signer == null) {
//If we can't find an algorithm that matches, we can't sign
logger.error("No matching algirthm found for alg=" + alg);
} else {
try {
jwt.sign(signer);
} catch (JOSEException e) {
logger.error("Failed to sign JWT, error was: ", e);
}
}
try {
jwt.sign(signer);
} catch (JOSEException e) {
logger.error("Failed to sign JWT, error was: ", e);
}
}
@Override
public boolean validateSignature(SignedJWT jwt) {
for (JWSVerifier verifier : verifiers.values()) {
try {
if (jwt.verify(verifier)) {
return true;
}
} catch (JOSEException e) {
logger.error("Failed to validate signature with " + verifier + " error message: " + e.getMessage());
logger.error("Failed to validate signature with {} error message: {}", verifier, e.getMessage());
}
}
return false;
}
@ -276,36 +173,84 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
public Map<String, JWK> getAllPublicKeys() {
Map<String, JWK> pubKeys = new HashMap<>();
// pull all keys out of the verifiers if we know how
for (String keyId : keys.keySet()) {
keys.keySet().forEach(keyId -> {
JWK key = keys.get(keyId);
JWK pub = key.toPublicJWK();
if (pub != null) {
pubKeys.put(keyId, pub);
}
}
});
return pubKeys;
}
/* (non-Javadoc)
* @see org.mitre.jwt.signer.service.JwtSigningAndValidationService#getAllSigningAlgsSupported()
*/
@Override
public Collection<JWSAlgorithm> getAllSigningAlgsSupported() {
Set<JWSAlgorithm> algs = new HashSet<>();
for (JWSSigner signer : signers.values()) {
algs.addAll(signer.supportedJWSAlgorithms());
}
for (JWSVerifier verifier : verifiers.values()) {
algs.addAll(verifier.supportedJWSAlgorithms());
}
signers.values().stream().map(JWSProvider::supportedJWSAlgorithms).forEach(algs::addAll);
verifiers.values().stream().map(JWSProvider::supportedJWSAlgorithms).forEach(algs::addAll);
return algs;
}
private void buildSignersAndVerifiers() {
for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
String id = jwkEntry.getKey();
JWK jwk = jwkEntry.getValue();
try {
if (jwk instanceof RSAKey) {
processRSAKey(signers, verifiers, jwk, id);
} else if (jwk instanceof ECKey) {
processECKey(signers, verifiers, jwk, id);
} else if (jwk instanceof OctetSequenceKey) {
processOctetKey(signers, verifiers, jwk, id);
} else {
logger.warn("Unknown key type: {}", jwk);
}
} catch (JOSEException e) {
logger.warn("Exception loading signer/verifier", e);
}
}
if (defaultSignerKeyId == null && keys.size() == 1) {
setDefaultSignerKeyId(keys.keySet().iterator().next());
}
}
private void processOctetKey(Map<String, JWSSigner> signers, Map<String, JWSVerifier> verifiers, JWK jwk, String id)
throws JOSEException
{
if (jwk.isPrivate()) {
MACSigner signer = new MACSigner((OctetSequenceKey) jwk);
signers.put(id, signer);
}
MACVerifier verifier = new MACVerifier((OctetSequenceKey) jwk);
verifiers.put(id, verifier);
}
private void processECKey(Map<String, JWSSigner> signers, Map<String, JWSVerifier> verifiers, JWK jwk, String id)
throws JOSEException
{
if (jwk.isPrivate()) {
ECDSASigner signer = new ECDSASigner((ECKey) jwk);
signers.put(id, signer);
}
ECDSAVerifier verifier = new ECDSAVerifier((ECKey) jwk);
verifiers.put(id, verifier);
}
private void processRSAKey(Map<String, JWSSigner> signers, Map<String, JWSVerifier> verifiers, JWK jwk, String id)
throws JOSEException
{
if (jwk.isPrivate()) {
RSASSASigner signer = new RSASSASigner((RSAKey) jwk);
signers.put(id, signer);
}
RSASSAVerifier verifier = new RSASSAVerifier((RSAKey) jwk);
verifiers.put(id, verifier);
}
}

View File

@ -44,25 +44,17 @@ import com.google.gson.JsonParseException;
import com.nimbusds.jose.jwk.JWKSet;
/**
*
* Creates a caching map of JOSE signers/validators and encrypters/decryptors
* keyed on the JWK Set URI. Dynamically loads JWK Sets to create the services.
*
* @author jricher
*
*/
@Service
public class JWKSetCacheService {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(JWKSetCacheService.class);
// map of jwk set uri -> signing/validation service built on the keys found in that jwk set
private LoadingCache<String, JWTSigningAndValidationService> validators;
// map of jwk set uri -> encryption/decryption service built on the keys found in that jwk set
private LoadingCache<String, JWTEncryptionAndDecryptionService> encrypters;
public JWKSetCacheService() {
@ -80,7 +72,6 @@ public class JWKSetCacheService {
* @param jwksUri
* @return
* @throws ExecutionException
* @see com.google.common.cache.Cache#get(java.lang.Object)
*/
public JWTSigningAndValidationService getValidator(String jwksUri) {
try {
@ -100,22 +91,14 @@ public class JWKSetCacheService {
}
}
/**
* @author jricher
*
*/
private class JWKSetVerifierFetcher extends CacheLoader<String, JWTSigningAndValidationService> {
private HttpComponentsClientHttpRequestFactory httpFactory;
private static class JWKSetVerifierFetcher extends CacheLoader<String, JWTSigningAndValidationService> {
private RestTemplate restTemplate;
JWKSetVerifierFetcher(HttpClient httpClient) {
this.httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
this.restTemplate = new RestTemplate(httpFactory);
}
/**
* Load the JWK Set and build the appropriate signing service.
*/
@Override
public JWTSigningAndValidationService load(String key) throws Exception {
String jsonString = restTemplate.getForObject(key, String.class);
@ -123,29 +106,18 @@ public class JWKSetCacheService {
JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet);
JWTSigningAndValidationService service = new DefaultJWTSigningAndValidationService(keyStore);
return service;
return new DefaultJWTSigningAndValidationService(keyStore);
}
}
/**
* @author jricher
*
*/
private class JWKSetEncryptorFetcher extends CacheLoader<String, JWTEncryptionAndDecryptionService> {
private HttpComponentsClientHttpRequestFactory httpFactory;
private static class JWKSetEncryptorFetcher extends CacheLoader<String, JWTEncryptionAndDecryptionService> {
private RestTemplate restTemplate;
public JWKSetEncryptorFetcher(HttpClient httpClient) {
this.httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
this.restTemplate = new RestTemplate(httpFactory);
}
/* (non-Javadoc)
* @see com.google.common.cache.CacheLoader#load(java.lang.Object)
*/
@Override
public JWTEncryptionAndDecryptionService load(String key) throws Exception {
try {
@ -154,9 +126,7 @@ public class JWKSetCacheService {
JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet);
JWTEncryptionAndDecryptionService service = new DefaultJWTEncryptionAndDecryptionService(keyStore);
return service;
return new DefaultJWTEncryptionAndDecryptionService(keyStore);
} catch (JsonParseException | RestClientException e) {
throw new IllegalArgumentException("Unable to load JWK Set");
}

View File

@ -15,19 +15,6 @@
*******************************************************************************/
package org.mitre.jwt.signer.service.impl;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
@ -37,24 +24,29 @@ import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.util.Base64URL;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
/**
* Creates and caches symmetrical validators for clients based on client secrets.
*
* @author jricher
*
*/
@Service
public class SymmetricKeyJWTValidatorCacheService {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(SymmetricKeyJWTValidatorCacheService.class);
private LoadingCache<String, JWTSigningAndValidationService> validators;
public SymmetricKeyJWTValidatorCacheService() {
validators = CacheBuilder.newBuilder()
.expireAfterAccess(24, TimeUnit.HOURS)
@ -62,59 +54,38 @@ public class SymmetricKeyJWTValidatorCacheService {
.build(new SymmetricValidatorBuilder());
}
/**
* Create a symmetric signing and validation service for the given client
*
* @param client
* @return
*/
public JWTSigningAndValidationService getSymmetricValidtor(ClientDetailsEntity client) {
if (client == null) {
logger.error("Couldn't create symmetric validator for null client");
return null;
}
if (Strings.isNullOrEmpty(client.getClientSecret())) {
logger.error("Couldn't create symmetric validator for client " + client.getClientId() + " without a client secret");
if (StringUtils.isEmpty(client.getClientSecret())) {
logger.error("Couldn't create symmetric validator for client {} without a client secret", client.getClientId());
return null;
}
try {
return validators.get(client.getClientSecret());
} catch (UncheckedExecutionException ue) {
} catch (UncheckedExecutionException | ExecutionException ue) {
logger.error("Problem loading client validator", ue);
return null;
} catch (ExecutionException e) {
logger.error("Problem loading client validator", e);
return null;
}
}
public class SymmetricValidatorBuilder extends CacheLoader<String, JWTSigningAndValidationService> {
public static class SymmetricValidatorBuilder extends CacheLoader<String, JWTSigningAndValidationService> {
@Override
public JWTSigningAndValidationService load(String key) throws Exception {
try {
public JWTSigningAndValidationService load(String key) {
String id = "SYMMETRIC-KEY";
JWK jwk = new OctetSequenceKey.Builder(Base64URL.encode(key))
.keyUse(KeyUse.SIGNATURE)
.keyID(id)
.build();
Map<String, JWK> keys = ImmutableMap.of(id, jwk);
JWTSigningAndValidationService service = new DefaultJWTSigningAndValidationService(keys);
String id = "SYMMETRIC-KEY";
JWK jwk = new OctetSequenceKey.Builder(Base64URL.encode(key))
.keyUse(KeyUse.SIGNATURE)
.keyID(id)
.build();
Map<String, JWK> keys = ImmutableMap.of(id, jwk);
return service;
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
logger.error("Couldn't create symmetric validator for client", e);
}
throw new IllegalArgumentException("Couldn't create symmetric validator for client");
return new DefaultJWTSigningAndValidationService(keys);
}
}
}

View File

@ -36,6 +36,7 @@ import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.util.Base64URL;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
@ -121,7 +122,7 @@ public class TestJWKSetKeyStore {
assertEquals(ks.getJwkSet(), jwkSet);
JWKSetKeyStore ks_empty= new JWKSetKeyStore();
assertEquals(ks_empty.getJwkSet(), null);
assertNull(ks_empty.getJwkSet());
boolean thrown = false;
try {