@ -16,16 +16,16 @@
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * /
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * /
package org.mitre.openid.connect ;
package org.mitre.openid.connect ;
import java.security.NoSuchAlgorithmException ;
import java.security.spec.InvalidKeySpecException ;
import java.text.ParseException ;
import java.text.ParseException ;
import java.util.Collections ;
import java.util.Collections ;
import java.util.HashMap ;
import java.util.HashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Map ;
import java.util.Set ;
import java.util.Set ;
import net.minidev.json.JSONObject ;
import org.mitre.jwt.signer.service.JwtSigningAndValidationService ;
import org.mitre.jwt.signer.service.JwtSigningAndValidationService ;
import org.mitre.jwt.signer.service.impl.DefaultJwtSigningAndValidationService ;
import org.mitre.jwt.signer.service.impl.JWKSetSigningAndValidationServiceCacheService ;
import org.mitre.jwt.signer.service.impl.JWKSetSigningAndValidationServiceCacheService ;
import org.mitre.oauth2.model.ClientDetailsEntity ;
import org.mitre.oauth2.model.ClientDetailsEntity ;
import org.mitre.oauth2.service.ClientDetailsEntityService ;
import org.mitre.oauth2.service.ClientDetailsEntityService ;
@ -35,7 +35,6 @@ import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.security.authentication.AuthenticationServiceException ;
import org.springframework.security.authentication.AuthenticationServiceException ;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException ;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException ;
import org.springframework.security.oauth2.common.exceptions.InvalidScopeException ;
import org.springframework.security.oauth2.common.util.OAuth2Utils ;
import org.springframework.security.oauth2.common.util.OAuth2Utils ;
import org.springframework.security.oauth2.provider.AuthorizationRequest ;
import org.springframework.security.oauth2.provider.AuthorizationRequest ;
import org.springframework.security.oauth2.provider.ClientDetails ;
import org.springframework.security.oauth2.provider.ClientDetails ;
@ -43,11 +42,18 @@ import org.springframework.security.oauth2.provider.DefaultOAuth2RequestFactory;
import org.springframework.security.oauth2.provider.OAuth2Request ;
import org.springframework.security.oauth2.provider.OAuth2Request ;
import org.springframework.stereotype.Component ;
import org.springframework.stereotype.Component ;
import com.google.common.base.Joiner ;
import com.google.common.base.Splitter ;
import com.google.common.base.Strings ;
import com.google.common.base.Strings ;
import com.google.common.collect.Iterables ;
import com.google.common.collect.ImmutableMap ;
import com.nimbusds.jose.util.JSONObjectUtils ;
import com.nimbusds.jose.Algorithm ;
import com.nimbusds.jose.JWSAlgorithm ;
import com.nimbusds.jose.jwk.JWK ;
import com.nimbusds.jose.jwk.OctetSequenceKey ;
import com.nimbusds.jose.jwk.Use ;
import com.nimbusds.jose.util.Base64URL ;
import com.nimbusds.jwt.JWT ;
import com.nimbusds.jwt.JWTParser ;
import com.nimbusds.jwt.PlainJWT ;
import com.nimbusds.jwt.ReadOnlyJWTClaimsSet ;
import com.nimbusds.jwt.SignedJWT ;
import com.nimbusds.jwt.SignedJWT ;
@Component ( "connectOAuth2RequestFactory" )
@Component ( "connectOAuth2RequestFactory" )
@ -129,31 +135,109 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
// parse the request object
// parse the request object
try {
try {
SignedJWT jwsObject = SignedJWT . parse ( jwtString ) ;
JWT jwt = JWTParser . parse ( jwtString ) ;
JSONObject claims = jwsObject . getPayload ( ) . toJSONObject ( ) ;
/ *
if ( jwt instanceof EncryptedJWT ) {
// TODO: it's an encrypted JWT, decrypt it and use it
} else {
// it's not encrypted...
}
* /
// TODO: check parameter consistency, move keys to constants
// TODO: check parameter consistency, move keys to constants
String clientId = JSONObjectUtils . getString ( claims , "client_id" ) ;
if ( jwt instanceof SignedJWT ) {
if ( clientId ! = null ) {
// it's a signed JWT, check the signature
parameters . put ( "client_id" , clientId ) ;
SignedJWT signedJwt = ( SignedJWT ) jwt ;
String clientId = inputParams . get ( "client_id" ) ;
if ( clientId = = null ) {
clientId = signedJwt . getJWTClaimsSet ( ) . getStringClaim ( "client_id" ) ;
}
}
ClientDetailsEntity client = clientDetailsService . loadClientByClientId ( clientId ) ;
ClientDetailsEntity client = clientDetailsService . loadClientByClientId ( clientId ) ;
if ( client = = null ) {
throw new InvalidClientException ( "Client not found: " + clientId ) ;
}
JWSAlgorithm alg = signedJwt . getHeader ( ) . getAlgorithm ( ) ;
if ( client . getRequestObjectSigningAlg ( ) ! = null ) {
if ( ! client . getRequestObjectSigningAlg ( ) . equals ( alg ) ) {
throw new AuthenticationServiceException ( "Client's registered request object signing algorithm (" + client . getRequestObjectSigningAlg ( ) . getAlgorithmName ( ) + ") does not match request object's actual algorithm (" + alg . getName ( ) + ")" ) ;
}
}
if ( alg . equals ( JWSAlgorithm . RS256 )
| | alg . equals ( JWSAlgorithm . RS384 )
| | alg . equals ( JWSAlgorithm . RS512 ) ) {
// it's RSA, need to find the JWK URI and fetch the key
if ( client . getJwksUri ( ) = = null ) {
if ( client . getJwksUri ( ) = = null ) {
throw new InvalidClientException ( "Client must have a JWKS URI registered to use request objects." ) ;
throw new InvalidClientException ( "Client must have a JWKS URI registered to use signed request objects.") ;
}
}
// check JWT signature
// check JWT signature
JwtSigningAndValidationService validator = validators . get ( client . getJwksUri ( ) ) ;
JwtSigningAndValidationService validator = validators . get ( client . getJwksUri ( ) ) ;
if ( validator = = null ) {
if ( validator = = null ) {
throw new InvalidClientException ( "Unable to create signature validator for client's JWKS URI: " + client . getJwksUri ( ) ) ;
throw new InvalidClientException ( "Unable to create signature validator for client's JWKS URI: " + client . getJwksUri ( ) ) ;
}
}
if ( ! validator . validateSignature ( jwsObject ) ) {
if ( ! validator . validateSignature ( signedJw t) ) {
throw new AuthenticationServiceException ( "Signature did not validate for presented JWT request object." ) ;
throw new AuthenticationServiceException ( "Signature did not validate for presented JWT request object." ) ;
}
}
} else if ( alg . equals ( JWSAlgorithm . HS256 )
| | alg . equals ( JWSAlgorithm . HS384 )
| | alg . equals ( JWSAlgorithm . HS512 ) ) {
// it's HMAC, we need to make a validator based on the client secret
JwtSigningAndValidationService validator = getSymmetricValidtor ( client ) ;
if ( validator = = null ) {
throw new InvalidClientException ( "Unable to create signature validator for client's secret: " + client . getClientSecret ( ) ) ;
}
if ( ! validator . validateSignature ( signedJwt ) ) {
throw new AuthenticationServiceException ( "Signature did not validate for presented JWT request object." ) ;
}
}
} else if ( jwt instanceof PlainJWT ) {
PlainJWT plainJwt = ( PlainJWT ) jwt ;
String clientId = inputParams . get ( "client_id" ) ;
if ( clientId = = null ) {
clientId = plainJwt . getJWTClaimsSet ( ) . getStringClaim ( "client_id" ) ;
}
ClientDetailsEntity client = clientDetailsService . loadClientByClientId ( clientId ) ;
if ( client = = null ) {
throw new InvalidClientException ( "Client not found: " + clientId ) ;
}
if ( client . getRequestObjectSigningAlg ( ) = = null ) {
throw new InvalidClientException ( "Client is not registered for unsigned request objects (no request_object_signing_alg registered)" ) ;
} else if ( ! client . getRequestObjectSigningAlg ( ) . getAlgorithm ( ) . equals ( Algorithm . NONE ) ) {
throw new InvalidClientException ( "Client is not registered for unsigned request objects (request_object_signing_alg is " + client . getRequestObjectSigningAlg ( ) . getAlgorithmName ( ) + ")" ) ;
}
// if we got here, we're OK, keep processing
}
/ *
/ *
* if ( in Claims ) :
* if ( in Claims ) :
@ -168,46 +252,53 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
* we don ' t care
* we don ' t care
* /
* /
String responseTypes = JSONObjectUtils . getString ( claims , "response_type" ) ;
ReadOnlyJWTClaimsSet claims = jwt . getJWTClaimsSet ( ) ;
String clientId = claims . getStringClaim ( "client_id" ) ;
if ( clientId ! = null ) {
parameters . put ( "client_id" , clientId ) ;
}
String responseTypes = claims . getStringClaim ( "response_type" ) ;
if ( responseTypes ! = null ) {
if ( responseTypes ! = null ) {
parameters . put ( "response_type" , responseTypes ) ;
parameters . put ( "response_type" , responseTypes ) ;
}
}
if ( claims . get ( "redirect_uri" ) ! = null ) {
if ( claims . get StringClaim ( "redirect_uri" ) ! = null ) {
if ( inputParams . containsKey ( "redirect_uri" ) = = false ) {
if ( inputParams . containsKey ( "redirect_uri" ) = = false ) {
parameters . put ( "redirect_uri" , JSONObjectUtils . getString ( claims , "redirect_uri" ) ) ;
parameters . put ( "redirect_uri" , claims. getStringClaim ( "redirect_uri" ) ) ;
}
}
}
}
String state = JSONObjectUtils. getString ( claims , "state" ) ;
String state = claims. getStringClaim ( "state" ) ;
if ( state ! = null ) {
if ( state ! = null ) {
if ( inputParams . containsKey ( "state" ) = = false ) {
if ( inputParams . containsKey ( "state" ) = = false ) {
parameters . put ( "state" , state ) ;
parameters . put ( "state" , state ) ;
}
}
}
}
String nonce = JSONObjectUtils. getString ( claims , "nonce" ) ;
String nonce = claims. getStringClaim ( "nonce" ) ;
if ( nonce ! = null ) {
if ( nonce ! = null ) {
if ( inputParams . containsKey ( "nonce" ) = = false ) {
if ( inputParams . containsKey ( "nonce" ) = = false ) {
parameters . put ( "nonce" , nonce ) ;
parameters . put ( "nonce" , nonce ) ;
}
}
}
}
String display = JSONObjectUtils. getString ( claims , "display" ) ;
String display = claims. getStringClaim ( "display" ) ;
if ( display ! = null ) {
if ( display ! = null ) {
if ( inputParams . containsKey ( "display" ) = = false ) {
if ( inputParams . containsKey ( "display" ) = = false ) {
parameters . put ( "display" , display ) ;
parameters . put ( "display" , display ) ;
}
}
}
}
String prompt = JSONObjectUtils. getString ( claims , "prompt" ) ;
String prompt = claims. getStringClaim ( "prompt" ) ;
if ( prompt ! = null ) {
if ( prompt ! = null ) {
if ( inputParams . containsKey ( "prompt" ) = = false ) {
if ( inputParams . containsKey ( "prompt" ) = = false ) {
parameters . put ( "prompt" , prompt ) ;
parameters . put ( "prompt" , prompt ) ;
}
}
}
}
String scope = JSONObjectUtils. getString ( claims , "scope" ) ;
String scope = claims. getStringClaim ( "scope" ) ;
if ( scope ! = null ) {
if ( scope ! = null ) {
if ( inputParams . containsKey ( "scope" ) = = false ) {
if ( inputParams . containsKey ( "scope" ) = = false ) {
parameters . put ( "scope" , scope ) ;
parameters . put ( "scope" , scope ) ;
@ -219,4 +310,42 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
return parameters ;
return parameters ;
}
}
/ * *
* Create a symmetric signing and validation service for the given client
*
* @param client
* @return
* /
private JwtSigningAndValidationService getSymmetricValidtor ( ClientDetailsEntity client ) {
if ( client = = null ) {
logger . error ( "Couldn't create symmetric validator for null client" ) ;
return null ;
}
if ( Strings . isNullOrEmpty ( client . getClientSecret ( ) ) ) {
logger . error ( "Couldn't create symmetric validator for client " + client . getClientId ( ) + " without a client secret" ) ;
return null ;
}
try {
JWK jwk = new OctetSequenceKey ( new Base64URL ( client . getClientSecret ( ) ) , Use . SIGNATURE , null , client . getClientId ( ) , null , null , null ) ;
Map < String , JWK > keys = ImmutableMap . of ( client . getClientId ( ) , jwk ) ;
JwtSigningAndValidationService service = new DefaultJwtSigningAndValidationService ( keys ) ;
return service ;
} catch ( NoSuchAlgorithmException e ) {
// TODO Auto-generated catch block
logger . error ( "Couldn't create symmetric validator for client " + client . getClientId ( ) , e ) ;
} catch ( InvalidKeySpecException e ) {
// TODO Auto-generated catch block
logger . error ( "Couldn't create symmetric validator for client " + client . getClientId ( ) , e ) ;
}
return null ;
}
}
}