fix: 🐛 Fix ACR for implicit and authorization_code flows

BREAKING CHANGE: 🧨 Database needs to be updated: `ALTER TABLE saved_user_auth DROP
source_class; ALTER TABLE saved_user_auth ADD COLUMN acr VARCHAR(1024);`
pull/1580/head
Dominik Frantisek Bucik 2021-11-19 16:14:21 +01:00
parent b4cd6a4642
commit 39bc00a3b0
No known key found for this signature in database
GPG Key ID: 25014C8DB2E7E62D
11 changed files with 99 additions and 57 deletions

View File

@ -86,6 +86,7 @@ CREATE TABLE IF NOT EXISTS authentication_holder_request_parameter (
CREATE TABLE IF NOT EXISTS saved_user_auth ( CREATE TABLE IF NOT EXISTS saved_user_auth (
id BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 1) PRIMARY KEY, id BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 1) PRIMARY KEY,
acr VARCHAR(1024),
name VARCHAR(1024), name VARCHAR(1024),
authenticated BOOLEAN, authenticated BOOLEAN,
source_class VARCHAR(2048) source_class VARCHAR(2048)

View File

@ -85,6 +85,7 @@ CREATE TABLE IF NOT EXISTS authentication_holder_request_parameter (
CREATE TABLE IF NOT EXISTS saved_user_auth ( CREATE TABLE IF NOT EXISTS saved_user_auth (
id BIGINT AUTO_INCREMENT PRIMARY KEY, id BIGINT AUTO_INCREMENT PRIMARY KEY,
acr VARCHAR(1024),
name VARCHAR(1024), name VARCHAR(1024),
authenticated BOOLEAN, authenticated BOOLEAN,
source_class VARCHAR(2048) source_class VARCHAR(2048)

View File

@ -86,6 +86,7 @@ CREATE TABLE IF NOT EXISTS authentication_holder_request_parameter (
CREATE TABLE IF NOT EXISTS saved_user_auth ( CREATE TABLE IF NOT EXISTS saved_user_auth (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
acr VARCHAR(1024),
name VARCHAR(1024), name VARCHAR(1024),
authenticated BOOLEAN, authenticated BOOLEAN,
source_class VARCHAR(2048) source_class VARCHAR(2048)

View File

@ -17,8 +17,10 @@
package cz.muni.ics.oauth2.model; package cz.muni.ics.oauth2.model;
import cz.muni.ics.oauth2.model.convert.SimpleGrantedAuthorityStringConverter; import cz.muni.ics.oauth2.model.convert.SimpleGrantedAuthorityStringConverter;
import cz.muni.ics.oidc.saml.SamlPrincipal;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.stream.Collectors;
import javax.persistence.Basic; import javax.persistence.Basic;
import javax.persistence.CollectionTable; import javax.persistence.CollectionTable;
import javax.persistence.Column; import javax.persistence.Column;
@ -32,8 +34,14 @@ import javax.persistence.Id;
import javax.persistence.JoinColumn; import javax.persistence.JoinColumn;
import javax.persistence.Table; import javax.persistence.Table;
import javax.persistence.Transient; import javax.persistence.Transient;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.opensaml.saml2.core.AuthnContext;
import org.opensaml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml2.core.AuthnStatement;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.providers.ExpiringUsernameAuthenticationToken;
/** /**
* This class stands in for an original Authentication object. * This class stands in for an original Authentication object.
@ -42,6 +50,8 @@ import org.springframework.security.core.GrantedAuthority;
*/ */
@Entity @Entity
@Table(name="saved_user_auth") @Table(name="saved_user_auth")
@Slf4j
@ToString
public class SavedUserAuthentication implements Authentication { public class SavedUserAuthentication implements Authentication {
private static final long serialVersionUID = -1804249963940323488L; private static final long serialVersionUID = -1804249963940323488L;
@ -50,18 +60,21 @@ public class SavedUserAuthentication implements Authentication {
private String name; private String name;
private Collection<GrantedAuthority> authorities; private Collection<GrantedAuthority> authorities;
private boolean authenticated; private boolean authenticated;
private String sourceClass; private String acr;
public SavedUserAuthentication(Authentication src) { public SavedUserAuthentication(Authentication src) {
setName(src.getName()); setName(src.getName());
setAuthorities(new HashSet<>(src.getAuthorities())); setAuthorities(new HashSet<>(src.getAuthorities()));
setAuthenticated(src.isAuthenticated()); setAuthenticated(src.isAuthenticated());
if (src instanceof ExpiringUsernameAuthenticationToken) {
if (src instanceof SavedUserAuthentication) { ExpiringUsernameAuthenticationToken token = (ExpiringUsernameAuthenticationToken) src;
// if we're copying in a saved auth, carry over the original class name this.acr = ((SamlPrincipal) token.getPrincipal()).getSamlCredential()
setSourceClass(((SavedUserAuthentication) src).getSourceClass()); .getAuthenticationAssertion()
} else { .getAuthnStatements().stream()
setSourceClass(src.getClass().getName()); .map(AuthnStatement::getAuthnContext)
.map(AuthnContext::getAuthnContextClassRef)
.map(AuthnContextClassRef::getAuthnContextClassRef)
.collect(Collectors.joining());
} }
} }
@ -85,6 +98,10 @@ public class SavedUserAuthentication implements Authentication {
return name; return name;
} }
public void setName(String name) {
this.name = name;
}
@Override @Override
@ElementCollection(fetch = FetchType.EAGER) @ElementCollection(fetch = FetchType.EAGER)
@CollectionTable(name="saved_user_auth_authority", joinColumns=@JoinColumn(name="owner_id")) @CollectionTable(name="saved_user_auth_authority", joinColumns=@JoinColumn(name="owner_id"))
@ -94,6 +111,32 @@ public class SavedUserAuthentication implements Authentication {
return authorities; return authorities;
} }
public void setAuthorities(Collection<GrantedAuthority> authorities) {
this.authorities = authorities;
}
@Basic
@Column(name = "acr")
public String getAcr() {
return acr;
}
public void setAcr(String acr) {
this.acr = acr;
}
@Override
@Basic
@Column(name="authenticated")
public boolean isAuthenticated() {
return authenticated;
}
@Override
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
this.authenticated = isAuthenticated;
}
@Override @Override
@Transient @Transient
public Object getCredentials() { public Object getCredentials() {
@ -112,34 +155,4 @@ public class SavedUserAuthentication implements Authentication {
return getName(); return getName();
} }
@Override
@Basic
@Column(name="authenticated")
public boolean isAuthenticated() {
return authenticated;
}
@Override
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
this.authenticated = isAuthenticated;
}
@Basic
@Column(name="source_class")
public String getSourceClass() {
return sourceClass;
}
public void setSourceClass(String sourceClass) {
this.sourceClass = sourceClass;
}
public void setName(String name) {
this.name = name;
}
public void setAuthorities(Collection<GrantedAuthority> authorities) {
this.authorities = authorities;
}
} }

View File

@ -22,6 +22,7 @@ import cz.muni.ics.oauth2.model.DeviceCode;
import cz.muni.ics.oauth2.service.DeviceCodeService; import cz.muni.ics.oauth2.service.DeviceCodeService;
import cz.muni.ics.oauth2.web.DeviceEndpoint; import cz.muni.ics.oauth2.web.DeviceEndpoint;
import java.util.Date; import java.util.Date;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.provider.ClientDetails; import org.springframework.security.oauth2.provider.ClientDetails;
@ -42,6 +43,7 @@ import org.springframework.stereotype.Component;
* *
*/ */
@Component("deviceTokenGranter") @Component("deviceTokenGranter")
@Slf4j
public class DeviceTokenGranter extends AbstractTokenGranter { public class DeviceTokenGranter extends AbstractTokenGranter {
public static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"; public static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code";

View File

@ -243,8 +243,9 @@ public class DeviceEndpoint {
@PreAuthorize("hasRole('ROLE_USER')") @PreAuthorize("hasRole('ROLE_USER')")
@RequestMapping(value = "/" + USER_URL + "/approve", method = RequestMethod.POST) @RequestMapping(value = "/" + USER_URL + "/approve", method = RequestMethod.POST)
public String approveDevice(@RequestParam("user_code") String userCode, @RequestParam(value = "user_oauth_approval") Boolean approve, ModelMap model, Authentication auth, HttpSession session) { public String approveDevice(@RequestParam("user_code") String userCode,
@RequestParam(value = "user_oauth_approval") Boolean approve,
ModelMap model, Authentication auth, HttpSession session) {
AuthorizationRequest authorizationRequest = (AuthorizationRequest) session.getAttribute("authorizationRequest"); AuthorizationRequest authorizationRequest = (AuthorizationRequest) session.getAttribute("authorizationRequest");
DeviceCode dc = (DeviceCode) session.getAttribute("deviceCode"); DeviceCode dc = (DeviceCode) session.getAttribute("deviceCode");

View File

@ -30,8 +30,7 @@ public class PerunSamlAuthenticationProvider extends SAMLAuthenticationProvider
@Override @Override
protected Object getPrincipal(SAMLCredential credential, Object userDetail) { protected Object getPrincipal(SAMLCredential credential, Object userDetail) {
PerunUser user = (PerunUser) userDetail; PerunUser user = (PerunUser) userDetail;
return new User(String.valueOf(user.getId()), credential.getRemoteEntityID(), return new SamlPrincipal(user.getId(), credential, getEntitlements(credential, userDetail));
getEntitlements(credential, userDetail));
} }
@Override @Override

View File

@ -0,0 +1,27 @@
package cz.muni.ics.oidc.saml;
import java.util.Collection;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.saml.SAMLCredential;
@Getter
@Setter
@ToString
public class SamlPrincipal extends User {
private Long perunUserId;
private SAMLCredential samlCredential;
public SamlPrincipal(Long perunUserId,
SAMLCredential samlCredential,
Collection<? extends GrantedAuthority> authorities) {
super(String.valueOf(perunUserId), "[PROTECTED]", authorities);
this.perunUserId = perunUserId;
this.samlCredential = samlCredential;
}
}

View File

@ -18,9 +18,6 @@ import lombok.extern.slf4j.Slf4j;
import net.minidev.json.JSONArray; import net.minidev.json.JSONArray;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
/** /**
* Modifies ID Token. * Modifies ID Token.
@ -49,8 +46,11 @@ public class PerunOIDCTokenService extends DefaultOIDCTokenService {
} }
@Override @Override
protected void addCustomIdTokenClaims(JWTClaimsSet.Builder idClaims, ClientDetailsEntity client, OAuth2Request request, protected void addCustomIdTokenClaims(JWTClaimsSet.Builder idClaims,
String sub, OAuth2AccessTokenEntity accessToken) ClientDetailsEntity client,
OAuth2Request request,
String sub,
OAuth2AccessTokenEntity accessToken)
{ {
log.debug("modifying ID token"); log.debug("modifying ID token");
String userId = accessToken.getAuthenticationHolder().getAuthentication().getName(); String userId = accessToken.getAuthenticationHolder().getAuthentication().getName();
@ -73,18 +73,17 @@ public class PerunOIDCTokenService extends DefaultOIDCTokenService {
} }
} }
String acr = getAuthnContextClass(); if (accessToken.getAuthenticationHolder() != null
if (acr != null) { && accessToken.getAuthenticationHolder().getUserAuth() != null)
log.debug("adding to ID token claim acr with value {}", acr); {
idClaims.claim("acr", acr); String acr = accessToken.getAuthenticationHolder().getUserAuth().getAcr();
if (acr != null) {
log.debug("adding to ID token claim acr with value {}", acr);
idClaims.claim("acr", acr);
}
} }
} }
private String getAuthnContextClass() {
ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
return (String) attr.getAttribute(SESSION_PARAM_ACR, RequestAttributes.SCOPE_SESSION);
}
/** /**
* Converts claim values from com.google.gson.JsonElement to net.minidev.json.JSONObject or primitive value * Converts claim values from com.google.gson.JsonElement to net.minidev.json.JSONObject or primitive value
* *

View File

@ -66,7 +66,6 @@ import org.springframework.stereotype.Service;
* @author Amanda Anganes * @author Amanda Anganes
* *
*/ */
@Service
@Slf4j @Slf4j
public class DefaultOIDCTokenService implements OIDCTokenService { public class DefaultOIDCTokenService implements OIDCTokenService {

View File

@ -44,7 +44,6 @@ import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.TokenEnhancer; import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Service
@Slf4j @Slf4j
public class ConnectTokenEnhancer implements TokenEnhancer { public class ConnectTokenEnhancer implements TokenEnhancer {