diff --git a/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java b/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java index bbb1b1c01..4c27b0a72 100644 --- a/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java +++ b/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java @@ -23,6 +23,9 @@ import static org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod.SECRET_JWT; import java.io.IOException; import java.math.BigInteger; import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.text.ParseException; import java.util.Date; @@ -40,6 +43,7 @@ import org.apache.http.impl.client.HttpClientBuilder; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; import org.mitre.jwt.signer.service.impl.JWKSetCacheService; import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService; +import org.mitre.oauth2.model.PKCEAlgorithm; import org.mitre.oauth2.model.RegisteredClient; import org.mitre.openid.connect.client.model.IssuerServiceResponse; import org.mitre.openid.connect.client.service.AuthRequestOptionsService; @@ -75,6 +79,7 @@ import com.nimbusds.jose.Algorithm; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; @@ -90,10 +95,11 @@ import com.nimbusds.jwt.SignedJWT; public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFilter { protected final static String REDIRECT_URI_SESION_VARIABLE = "redirect_uri"; + protected final static String CODE_VERIFIER_SESSION_VARIABLE = "code_verifier"; protected final static String STATE_SESSION_VARIABLE = "state"; protected final static String NONCE_SESSION_VARIABLE = "nonce"; protected final static String ISSUER_SESSION_VARIABLE = "issuer"; - protected static final String TARGET_SESSION_VARIABLE = "target"; + protected final static String TARGET_SESSION_VARIABLE = "target"; protected final static int HTTP_SOCKET_TIMEOUT = 30000; public final static String FILTER_PROCESSES_URL = "/openid_connect_login"; @@ -262,6 +268,26 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi String state = createState(session); Map options = authOptions.getOptions(serverConfig, clientConfig, request); + + // if we're using PKCE, handle the challenge here + if (clientConfig.getCodeChallengeMethod() != null) { + String codeVerifier = createCodeVerifier(session); + options.put("code_challenge_method", clientConfig.getCodeChallengeMethod().getName()); + if (clientConfig.getCodeChallengeMethod().equals(PKCEAlgorithm.plain)) { + options.put("code_challenge", codeVerifier); + } else if (clientConfig.getCodeChallengeMethod().equals(PKCEAlgorithm.S256)) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + String hash = Base64URL.encode(digest.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII))).toString(); + options.put("code_challenge", hash); + } catch (NoSuchAlgorithmException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + + } + } String authRequest = authRequestBuilder.buildAuthRequestUrl(serverConfig, clientConfig, redirectUri, nonce, state, options, issResp.getLoginHint()); @@ -302,6 +328,11 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi form.add("grant_type", "authorization_code"); form.add("code", authorizationCode); form.setAll(authOptions.getTokenOptions(serverConfig, clientConfig, request)); + + String codeVerifier = getStoredCodeVerifier(session); + if (codeVerifier != null) { + form.add("code_verifier", codeVerifier); + } String redirectUri = getStoredSessionString(session, REDIRECT_URI_SESION_VARIABLE); if (redirectUri != null) { @@ -675,6 +706,26 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi protected static String getStoredState(HttpSession session) { return getStoredSessionString(session, STATE_SESSION_VARIABLE); } + + /** + * Create a random code challenge and store it in the session + * @param session + * @return + */ + protected static String createCodeVerifier(HttpSession session) { + String challenge = new BigInteger(50, new SecureRandom()).toString(16); + session.setAttribute(CODE_VERIFIER_SESSION_VARIABLE, challenge); + return challenge; + } + + /** + * Retrieve the stored challenge from our session + * @param session + * @return + */ + protected static String getStoredCodeVerifier(HttpSession session) { + return getStoredSessionString(session, CODE_VERIFIER_SESSION_VARIABLE); + } @Override