From 26792d2fba721457a9756aed692588709098b3f0 Mon Sep 17 00:00:00 2001 From: Mike Derryberry Date: Tue, 10 Jul 2012 14:28:07 -0400 Subject: [PATCH] updated decryption to generate cik and cek based off of key derivation --- .../jwt/encryption/AbstractJweDecrypter.java | 45 ++++++++++++++++++ .../mitre/jwt/encryption/JwtAlgorithm.java | 17 ++++--- .../mitre/jwt/encryption/JwtDecrypter.java | 4 +- .../jwt/encryption/impl/RsaDecrypter.java | 46 ++++++++----------- .../jwt/encryption/impl/RsaEncrypter.java | 4 +- 5 files changed, 76 insertions(+), 40 deletions(-) diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/AbstractJweDecrypter.java b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/AbstractJweDecrypter.java index db4f85b3f..711c23590 100644 --- a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/AbstractJweDecrypter.java +++ b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/AbstractJweDecrypter.java @@ -1,5 +1,6 @@ package org.mitre.jwt.encryption; +import java.security.MessageDigest; import java.security.PrivateKey; import java.security.PublicKey; @@ -9,6 +10,8 @@ public abstract class AbstractJweDecrypter implements JwtDecrypter { private PublicKey publicKey; + public MessageDigest md; + public PrivateKey getPrivateKey() { return privateKey; } @@ -24,4 +27,46 @@ public abstract class AbstractJweDecrypter implements JwtDecrypter { public void setPublicKey(PublicKey publicKey) { this.publicKey = publicKey; } + + public byte[] generateContentKey(byte[] cmk, int keyDataLen, byte[] type) { + + long MAX_HASH_INPUTLEN = Long.MAX_VALUE; + long UNSIGNED_INT_MAX_VALUE = 4294967395L; + + keyDataLen = keyDataLen / 8; + byte[] key = new byte[keyDataLen]; + int hashLen = md.getDigestLength(); + int reps = keyDataLen / hashLen; + if (reps > UNSIGNED_INT_MAX_VALUE) { + throw new IllegalArgumentException("Key derivation failed"); + } + int counter = 1; + byte[] counterInBytes = intToFourBytes(counter); + if ((counterInBytes.length + cmk.length + type.length) * 8 > MAX_HASH_INPUTLEN) { + throw new IllegalArgumentException("Key derivation failed"); + } + for (int i = 0; i <= reps; i++) { + md.reset(); + md.update(intToFourBytes(i + 1)); + md.update(cmk); + md.update(type); + byte[] hash = md.digest(); + if (i < reps) { + System.arraycopy(hash, 0, key, hashLen * i, hashLen); + } else { + System.arraycopy(hash, 0, key, hashLen * i, keyDataLen % hashLen); + } + } + return key; + + } + + public byte[] intToFourBytes(int i) { + byte[] res = new byte[4]; + res[0] = (byte) (i >>> 24); + res[1] = (byte) ((i >>> 16) & 0xFF); + res[2] = (byte) ((i >>> 8) & 0xFF); + res[3] = (byte) (i & 0xFF); + return res; + } } diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtAlgorithm.java b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtAlgorithm.java index cbc491b8e..ca5c96ea3 100644 --- a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtAlgorithm.java +++ b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtAlgorithm.java @@ -6,15 +6,18 @@ public enum JwtAlgorithm { //TODO:Fill in values for each standard name // RSA - RSA1_5(""), - RSA_OAEP(""), + RSA1_5("RSA"), + RSA_OAEP("RSA"), //EC - ECDH_ES(""), + ECDH_ES("EC"), //AES - A128KW(""), - A256KW(""), - A128GCM(""), - A256GCM(""); + A128KW("AES"), + A256KW("AES"), + A128CBC("AES"), + A256CBC("AES"), + A128GCM("AES"), + A256GCM("AES"); + /** diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtDecrypter.java b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtDecrypter.java index 5f877751d..e441256f4 100644 --- a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtDecrypter.java +++ b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/JwtDecrypter.java @@ -1,14 +1,12 @@ package org.mitre.jwt.encryption; -import java.security.Key; - import org.mitre.jwe.model.Jwe; public interface JwtDecrypter { public Jwe decrypt(String encryptedJwe); - public String decryptCipherText(Jwe jwe, Key cek); + public byte[] decryptCipherText(Jwe jwe, byte[] cek); public byte[] decryptEncryptionKey(Jwe jwe); diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaDecrypter.java b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaDecrypter.java index e4986c8a2..48cbf8112 100644 --- a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaDecrypter.java +++ b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaDecrypter.java @@ -1,20 +1,17 @@ package org.mitre.jwt.encryption.impl; import java.security.InvalidKeyException; -import java.security.Key; -import java.security.KeyPair; -import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.PublicKey; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; +import javax.crypto.spec.SecretKeySpec; import org.mitre.jwe.model.Jwe; import org.mitre.jwt.encryption.AbstractJweDecrypter; +import org.mitre.jwt.encryption.AlgorithmLength; import org.mitre.jwt.signer.impl.HmacSigner; @@ -31,29 +28,23 @@ public class RsaDecrypter extends AbstractJweDecrypter { String alg = jwe.getHeader().getAlgorithm(); if(alg.equals("RS256") || alg.equals("RS384") || alg.equals("RS512")) { - - PrivateKey contentEncryptionKey = null; - PublicKey contentIntegrityKey = null; - try { - - KeyPairGenerator keyGen = KeyPairGenerator.getInstance(jwe.getHeader().getKeyDerivationFunction()); - KeyPair keyPair = keyGen.genKeyPair(); - contentEncryptionKey = keyPair.getPrivate(); - contentIntegrityKey = keyPair.getPublic(); - - } catch (NoSuchAlgorithmException e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } - - jwe.setCiphertext(decryptCipherText(jwe, contentEncryptionKey).getBytes()); + //decrypt to get cmk to be used for cek and cik jwe.setEncryptedKey(decryptEncryptionKey(jwe)); + //generation of cek and cik + String algorithmLength = AlgorithmLength.getByName(jwe.getHeader().getEncryptionMethod()).getStandardName(); + int keyLength = Integer.parseInt(algorithmLength); + byte[] contentEncryptionKey = generateContentKey(jwe.getEncryptedKey(), keyLength, new String("Encryption").getBytes()); + byte[] contentIntegrityKey = generateContentKey(jwe.getEncryptedKey(), keyLength, new String("Integrity").getBytes()); + + //decrypt ciphertext to get claims + jwe.setCiphertext(decryptCipherText(jwe, contentEncryptionKey)); + //generate signature for decrypted signature base in order to verify that decryption worked String signature = null; try { - HmacSigner hmacSigner = new HmacSigner(contentIntegrityKey.getEncoded()); + HmacSigner hmacSigner = new HmacSigner(contentIntegrityKey); signature = hmacSigner.generateSignature(jwe.getSignatureBase()); } catch (NoSuchAlgorithmException e) { // TODO Auto-generated catch block @@ -76,15 +67,14 @@ public class RsaDecrypter extends AbstractJweDecrypter { } @Override - public String decryptCipherText(Jwe jwe, Key cek) { + public byte[] decryptCipherText(Jwe jwe, byte[] cek) { Cipher cipher; - String clearTextString = null; + byte[] clearText = null; try { cipher = Cipher.getInstance("RSA"); - cipher.init(Cipher.DECRYPT_MODE, cek); - byte[] clearText = cipher.doFinal(jwe.getCiphertext()); - clearTextString = new String(clearText); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(cek, "RSA")); + clearText = cipher.doFinal(jwe.getCiphertext()); } catch (NoSuchAlgorithmException e) { // TODO Auto-generated catch block @@ -103,7 +93,7 @@ public class RsaDecrypter extends AbstractJweDecrypter { e.printStackTrace(); } - return clearTextString; + return clearText; } diff --git a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaEncrypter.java b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaEncrypter.java index 130744a7f..4334ad174 100644 --- a/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaEncrypter.java +++ b/openid-connect-common/src/main/java/org/mitre/jwt/encryption/impl/RsaEncrypter.java @@ -115,9 +115,9 @@ public class RsaEncrypter extends AbstractJweEncrypter { Cipher cipher; try { - cipher = Cipher.getInstance(JwtAlgorithm.getByName(jwe.getHeader().getAlgorithm()).getStandardName()); + cipher = Cipher.getInstance(JwtAlgorithm.getByName(jwe.getHeader().getEncryptionMethod()).getStandardName()); - cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(contentEncryptionKey, 0, contentEncryptionKey.length, "RSA")); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(contentEncryptionKey, "RSA")); cipherText = cipher.doFinal(jwe.getClaims().toString().getBytes()); } catch (NoSuchAlgorithmException e) {