diff --git a/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java b/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java index e42bc089e..0b89ac285 100644 --- a/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java +++ b/openid-connect-client/src/main/java/org/mitre/openid/connect/client/OIDCAuthenticationFilter.java @@ -16,21 +16,18 @@ ******************************************************************************/ package org.mitre.openid.connect.client; -import java.io.IOException; -import java.math.BigInteger; -import java.security.SecureRandom; -import java.text.ParseException; -import java.util.Date; - -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; - +import com.google.common.base.Strings; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jwt.ReadOnlyJWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; import org.apache.commons.lang.StringUtils; import org.apache.http.impl.client.DefaultHttpClient; import org.mitre.jwt.signer.service.JwtSigningAndValidationService; import org.mitre.jwt.signer.service.impl.JWKSetSigningAndValidationServiceCacheService; +import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.openid.connect.client.model.IssuerServiceResponse; import org.mitre.openid.connect.client.service.AuthRequestUrlBuilder; import org.mitre.openid.connect.client.service.ClientConfigurationService; @@ -38,6 +35,8 @@ import org.mitre.openid.connect.client.service.IssuerService; import org.mitre.openid.connect.client.service.ServerConfigurationService; import org.mitre.openid.connect.config.ServerConfiguration; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; @@ -49,12 +48,18 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.RestTemplate; -import com.google.common.base.Strings; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.nimbusds.jwt.ReadOnlyJWTClaimsSet; -import com.nimbusds.jwt.SignedJWT; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import java.io.IOException; +import java.math.BigInteger; +import java.net.URI; +import java.security.SecureRandom; +import java.text.ParseException; +import java.util.Date; + +import static org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod.SECRET_BASIC; /** * OpenID Connect Authentication Filter class @@ -224,7 +229,7 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi // pull the configurations based on that issuer ServerConfiguration serverConfig = servers.getServerConfiguration(issuer); - ClientDetails clientConfig = clients.getClientConfiguration(serverConfig); + final ClientDetails clientConfig = clients.getClientConfiguration(serverConfig); MultiValueMap form = new LinkedMultiValueMap(); form.add("grant_type", "authorization_code"); @@ -240,20 +245,30 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi httpClient.getParams().setParameter("http.socket.timeout", new Integer(httpSocketTimeout)); - /* Use these for basic auth: - * - UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(clientConfig.getClientId(), clientConfig.getClientSecret()); - httpClient.getCredentialsProvider().setCredentials(AuthScope.ANY, credentials); - */ - /* Alternatively, use form-based auth: - */ - form.add("client_id", clientConfig.getClientId()); - form.add("client_secret", clientConfig.getClientSecret()); - /**/ - HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(httpClient); - RestTemplate restTemplate = new RestTemplate(factory); + RestTemplate restTemplate; + + if(clientConfig instanceof ClientDetailsEntity && SECRET_BASIC.equals(((ClientDetailsEntity) clientConfig).getTokenEndpointAuthMethod())){ + restTemplate = new RestTemplate(factory){ + + @Override + protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException { + ClientHttpRequest httpRequest = super.createRequest(url, method); + httpRequest.getHeaders().add("Authorization", + String.format("Basic %s", Base64.encode(String.format("%s:%s", clientConfig.getClientId(), clientConfig.getClientSecret())) )); + + + + return httpRequest; + } + }; + }else{ //Alternatively use form based auth + restTemplate = new RestTemplate(factory); + + form.add("client_id", clientConfig.getClientId()); + form.add("client_secret", clientConfig.getClientSecret()); + } logger.debug("tokenEndpointURI = " + serverConfig.getTokenEndpointUri()); logger.debug("form = " + form);