diff --git a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java index 81dd2f1..8a66d8a 100644 --- a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java +++ b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java @@ -21,6 +21,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.http.ResponseEntity; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.*; import org.springframework.security.oauth2.common.util.OAuth2Utils; @@ -68,6 +69,9 @@ public class OAuthRestController implements InitializingBean, ApplicationContext @Autowired private AuthorizationCodeServices authorizationCodeServices; + @Autowired + private PasswordEncoder passwordEncoder; + private AuthenticationManager authenticationManager; private OAuth2RequestFactory oAuth2RequestFactory; @@ -84,6 +88,16 @@ public class OAuthRestController implements InitializingBean, ApplicationContext String clientId = getClientId(parameters); ClientDetails authenticatedClient = clientDetailsService.loadClientByClientId(clientId); + //validate client_secret + String clientSecret = getClientSecret(parameters); + if (clientSecret == null || clientSecret.equals("")) { + throw new InvalidClientException("Bad client credentials"); + } else { + if (!this.passwordEncoder.matches(clientSecret, authenticatedClient.getClientSecret())) { + throw new InvalidClientException("Bad client credentials"); + } + } + TokenRequest tokenRequest = oAuth2RequestFactory.createTokenRequest(parameters, authenticatedClient); if (clientId != null && !clientId.equals("")) { @@ -96,9 +110,7 @@ public class OAuthRestController implements InitializingBean, ApplicationContext } } - if (authenticatedClient != null) { - oAuth2RequestValidator.validateScope(tokenRequest, authenticatedClient); - } + oAuth2RequestValidator.validateScope(tokenRequest, authenticatedClient); final String grantType = tokenRequest.getGrantType(); if (!StringUtils.hasText(grantType)) { @@ -169,20 +181,24 @@ public class OAuthRestController implements InitializingBean, ApplicationContext } - private boolean isRefreshTokenRequest(Map parameters) { - return "refresh_token".equals(parameters.get("grant_type")) && parameters.get("refresh_token") != null; + return "refresh_token".equals(parameters.get(OAuth2Utils.GRANT_TYPE)) && parameters.get("refresh_token") != null; } private boolean isAuthCodeRequest(Map parameters) { - return "authorization_code".equals(parameters.get("grant_type")) && parameters.get("code") != null; + return "authorization_code".equals(parameters.get(OAuth2Utils.GRANT_TYPE)) && parameters.get("code") != null; } protected String getClientId(Map parameters) { - return parameters.get("client_id"); + return parameters.get(OAuth2Utils.CLIENT_ID); } + protected String getClientSecret(Map parameters) { + return parameters.get("client_secret"); + } + + private AuthenticationManager getAuthenticationManager() { return this.authenticationManager; } @@ -193,6 +209,8 @@ public class OAuthRestController implements InitializingBean, ApplicationContext Assert.state(clientDetailsService != null, "ClientDetailsService must be provided"); Assert.state(authenticationManager != null, "AuthenticationManager must be provided"); + Assert.notNull(this.passwordEncoder, "PasswordEncoder is null"); + oAuth2RequestFactory = new DefaultOAuth2RequestFactory(clientDetailsService); }