updated hmac and rsa signer to use afterPropertiesSet(), abstract oidc auth filter now adds multiple signers to map and then picks the one it needs, and key fetcher now gets jwk

pull/124/head
Mike Derryberry 2012-07-16 11:36:00 -04:00 committed by Justin Richer
parent 8b848af0fb
commit 4deaffd686
4 changed files with 55 additions and 76 deletions

View File

@ -45,9 +45,7 @@ import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.apache.http.client.HttpClient; import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient; import org.apache.http.impl.client.DefaultHttpClient;
import org.mitre.openid.connect.model.IdToken; import org.mitre.jwt.signer.JwsAlgorithm;
import org.mitre.jwt.model.Jwt;
import org.mitre.jwt.model.JwtHeader;
import org.mitre.jwt.signer.JwtSigner; import org.mitre.jwt.signer.JwtSigner;
import org.mitre.jwt.signer.impl.RsaSigner; import org.mitre.jwt.signer.impl.RsaSigner;
import org.mitre.jwt.signer.service.JwtSigningAndValidationService; import org.mitre.jwt.signer.service.JwtSigningAndValidationService;
@ -486,16 +484,20 @@ public class AbstractOIDCAuthenticationFilter extends
throw new AuthenticationServiceException("Problem parsing id_token return from Token endpoint: " + e); throw new AuthenticationServiceException("Problem parsing id_token return from Token endpoint: " + e);
} }
if(jwtValidator.validateSignature(jsonRoot.getAsJsonObject().get("id_token").getAsString()) if(jwtValidator.validateSignature(jsonRoot.getAsJsonObject().get("id_token").getAsString()) == false) {
&& idToken.getClaims().getIssuer() != null throw new AuthenticationServiceException("Signature not validated");
&& idToken.getClaims().getIssuer().equals(serverConfig.getIssuer())
&& !jwtValidator.isJwtExpired(idToken)
&& jwtValidator.validateIssuedAt(idToken)){
} }
else{ if(idToken.getClaims().getIssuer() == null) {
throw new AuthenticationServiceException("Problem verifying id_token"); throw new AuthenticationServiceException("Issuer is null");
}
if(!idToken.getClaims().getIssuer().equals(serverConfig.getIssuer())){
throw new AuthenticationServiceException("Issuers do not match");
}
if(jwtValidator.isJwtExpired(idToken)) {
throw new AuthenticationServiceException("Id Token is expired");
}
if(jwtValidator.validateIssuedAt(idToken) == false) {
throw new AuthenticationServiceException("Id Token issuedAt failed");
} }
} else { } else {
@ -698,10 +700,13 @@ public class AbstractOIDCAuthenticationFilter extends
RSAPublicKey rsaKey = (RSAPublicKey)signingKey; RSAPublicKey rsaKey = (RSAPublicKey)signingKey;
// build an RSA signer // build an RSA signer
// FIXME: where do we get the algorithm name? RsaSigner signer256 = new RsaSigner(JwsAlgorithm.RS256.toString(), rsaKey, null);
RsaSigner signer = new RsaSigner("RS256", rsaKey, null); RsaSigner signer384 = new RsaSigner(JwsAlgorithm.RS384.toString(), rsaKey, null);
RsaSigner signer512 = new RsaSigner(JwsAlgorithm.RS512.toString(), rsaKey, null);
signers.put(serverConfig.getIssuer(), signer);
signers.put(serverConfig.getIssuer(), signer256);
signers.put(serverConfig.getIssuer(), signer384);
signers.put(serverConfig.getIssuer(), signer512);
} }
JwtSigningAndValidationService signingAndValidationService = new JwtSigningAndValidationServiceDefault(signers); JwtSigningAndValidationService signingAndValidationService = new JwtSigningAndValidationServiceDefault(signers);

View File

@ -177,32 +177,16 @@ public class HmacSigner extends AbstractJwtSigner implements InitializingBean {
} }
private void initializeMac() { private void initializeMac() {
// TODO: check if it's already been done if (mac == null) {
try { try {
mac = Mac.getInstance(JwsAlgorithm.getByName(super.getAlgorithm()).getStandardName()); mac = Mac.getInstance(JwsAlgorithm.getByName(super.getAlgorithm()).getStandardName());
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
// TODO Auto-generated catch block // TODO Auto-generated catch block
e.printStackTrace(); e.printStackTrace();
}
} }
} }
// TODO: nuke and clean up
// public void initializeMacJwe(String signatureBase) {
// List<String> parts = Lists.newArrayList(Splitter.on(".").split(signatureBase));
// String header = parts.get(0);
// JsonParser parser = new JsonParser();
// JsonObject object = (JsonObject) parser.parse(header);
//
// try {
// mac = Mac.getInstance(JwsAlgorithm.getByName(object.get("int").getAsString())
// .getStandardName());
// } catch (NoSuchAlgorithmException e) {
// // TODO Auto-generated catch block
// e.printStackTrace();
// }
// }
/* /*
* (non-Javadoc) * (non-Javadoc)
* *

View File

@ -173,7 +173,7 @@ public class RsaSigner extends AbstractJwtSigner implements InitializingBean {
public String generateSignature(String signatureBase) throws NoSuchAlgorithmException { public String generateSignature(String signatureBase) throws NoSuchAlgorithmException {
String sig = null; String sig = null;
initializeSigner(); initializeSigner();
try { try {
@ -234,7 +234,9 @@ public class RsaSigner extends AbstractJwtSigner implements InitializingBean {
loadKeysFromKeystore(); loadKeysFromKeystore();
} }
signer = Signature.getInstance(JwsAlgorithm.getByName(super.getAlgorithm()).getStandardName()); if (signer == null) {
signer = Signature.getInstance(JwsAlgorithm.getByName(super.getAlgorithm()).getStandardName());
}
} }
/* /*

View File

@ -11,15 +11,12 @@ import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException; import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec; import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.client.HttpClient; import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient; import org.apache.http.impl.client.DefaultHttpClient;
import org.mitre.jwk.model.EC;
import org.mitre.jwk.model.Jwk;
import org.mitre.jwk.model.Rsa;
import org.mitre.openid.connect.config.OIDCServerConfiguration; import org.mitre.openid.connect.config.OIDCServerConfiguration;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.AuthenticationServiceException;
@ -36,9 +33,9 @@ public class KeyFetcher {
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient); private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
private RestTemplate restTemplate = new RestTemplate(httpFactory); private RestTemplate restTemplate = new RestTemplate(httpFactory);
public List<Jwk> retrieveJwk(OIDCServerConfiguration serverConfig){ private static Log logger = LogFactory.getLog(KeyFetcher.class);
List<Jwk> keys = new ArrayList<Jwk>(); public JsonArray retrieveJwk(OIDCServerConfiguration serverConfig){
String jsonString = null; String jsonString = null;
@ -54,20 +51,7 @@ public class KeyFetcher {
JsonObject json = (JsonObject) new JsonParser().parse(jsonString); JsonObject json = (JsonObject) new JsonParser().parse(jsonString);
JsonArray getArray = json.getAsJsonArray("jwk"); JsonArray getArray = json.getAsJsonArray("jwk");
for (int i = 0; i < getArray.size(); i++){ return getArray;
JsonObject object = getArray.get(i).getAsJsonObject();
String algorithm = object.get("alg").getAsString();
if (algorithm.equals("RSA")){
Rsa rsa = new Rsa(object);
keys.add(rsa);
} else {
EC ec = new EC(object);
keys.add(ec);
}
}
return keys;
} }
public PublicKey retrieveX509Key(OIDCServerConfiguration serverConfig) { public PublicKey retrieveX509Key(OIDCServerConfiguration serverConfig) {
@ -81,8 +65,7 @@ public class KeyFetcher {
X509Certificate cert = (X509Certificate) factory.generateCertificate(x509Stream); X509Certificate cert = (X509Certificate) factory.generateCertificate(x509Stream);
key = cert.getPublicKey(); key = cert.getPublicKey();
} catch (HttpClientErrorException httpClientErrorException) { } catch (HttpClientErrorException httpClientErrorException) {
// TODO: add to log instead of this logger.error(httpClientErrorException);
httpClientErrorException.printStackTrace();
} catch (CertificateException e) { } catch (CertificateException e) {
// TODO Auto-generated catch block // TODO Auto-generated catch block
e.printStackTrace(); e.printStackTrace();
@ -98,19 +81,24 @@ public class KeyFetcher {
String jwkString = restTemplate.getForObject(serverConfig.getJwkSigningUrl(), String.class); String jwkString = restTemplate.getForObject(serverConfig.getJwkSigningUrl(), String.class);
JsonObject json = (JsonObject) new JsonParser().parse(jwkString); JsonObject json = (JsonObject) new JsonParser().parse(jwkString);
JsonArray getArray = json.getAsJsonArray("keys"); JsonArray getArray = json.getAsJsonArray("keys");
JsonObject object = getArray.get(0).getAsJsonObject(); // TODO: this only does something on the first key and assumes it's RSA... for(int i = 0; i < getArray.size(); i++) {
JsonObject object = getArray.get(i).getAsJsonObject();
String algorithm = object.get("alg").getAsString();
byte[] modulusByte = Base64.decodeBase64(object.get("mod").getAsString()); if(algorithm.equals("RSA")){
BigInteger modulus = new BigInteger(modulusByte); byte[] modulusByte = Base64.decodeBase64(object.get("mod").getAsString());
byte[] exponentByte = Base64.decodeBase64(object.get("exp").getAsString()); BigInteger modulus = new BigInteger(modulusByte);
BigInteger exponent = new BigInteger(exponentByte); byte[] exponentByte = Base64.decodeBase64(object.get("exp").getAsString());
BigInteger exponent = new BigInteger(exponentByte);
RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
KeyFactory factory = KeyFactory.getInstance("RSA"); RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
pub = (RSAPublicKey) factory.generatePublic(spec); KeyFactory factory = KeyFactory.getInstance("RSA");
pub = (RSAPublicKey) factory.generatePublic(spec);
}
}
} catch (HttpClientErrorException httpClientErrorException) { } catch (HttpClientErrorException httpClientErrorException) {
// TODO: add to log logger.error(httpClientErrorException);
httpClientErrorException.printStackTrace();
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
// TODO Auto-generated catch block // TODO Auto-generated catch block
e.printStackTrace(); e.printStackTrace();