Merged modules and updated dependencies

pull/1580/head
Dominik František Bučík 2020-05-29 14:14:03 +02:00 committed by Dominik Frantisek Bucik
parent e58df0e51e
commit 734cba256a
No known key found for this signature in database
GPG Key ID: 25014C8DB2E7E62D
38 changed files with 532 additions and 716 deletions

View File

@ -114,18 +114,6 @@
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>jcl-over-slf4j</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.hsqldb</groupId> <groupId>org.hsqldb</groupId>
<artifactId>hsqldb</artifactId> <artifactId>hsqldb</artifactId>
@ -139,8 +127,8 @@
<artifactId>spring-security-taglibs</artifactId> <artifactId>spring-security-taglibs</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>javax.servlet</groupId> <groupId>javax.servlet.jsp.jstl</groupId>
<artifactId>jstl</artifactId> <artifactId>jstl-api</artifactId>
</dependency> </dependency>
<dependency> <dependency>

View File

@ -40,8 +40,8 @@
</build> </build>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.mitre</groupId> <groupId>org.springframework.security.oauth</groupId>
<artifactId>openid-connect-common</artifactId> <artifactId>spring-security-oauth2</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework</groupId> <groupId>org.springframework</groupId>
@ -70,13 +70,56 @@
<dependency> <dependency>
<groupId>org.eclipse.persistence</groupId> <groupId>org.eclipse.persistence</groupId>
<artifactId>org.eclipse.persistence.jpa</artifactId> <artifactId>org.eclipse.persistence.jpa</artifactId>
<scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>commons-io</groupId>
<artifactId>commons-io</artifactId> <artifactId>commons-io</artifactId>
</dependency> </dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>servlet-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
</dependency>
</dependencies> </dependencies>
<description>OpenID Connect server libraries for Spring and Spring Security.</description> <description>OpenID Connect server libraries for Spring and Spring Security.</description>
<url /> <url />

View File

@ -33,135 +33,23 @@ public abstract class AbstractPageOperationTemplate<T> {
private static final Logger logger = LoggerFactory.getLogger(AbstractPageOperationTemplate.class); private static final Logger logger = LoggerFactory.getLogger(AbstractPageOperationTemplate.class);
private static int DEFAULT_MAX_PAGES = 1000; private static final int DEFAULT_MAX_PAGES = 1000;
private static long DEFAULT_MAX_TIME_MILLIS = 600000L; //10 Minutes private static final long DEFAULT_MAX_TIME_MILLIS = 600000L; //10 Minutes
/**
* int specifying the maximum number of
* pages which should be fetched before
* execution should terminate
*/
private int maxPages; private int maxPages;
/**
* long specifying the maximum execution time
* in milliseconds
*/
private long maxTime; private long maxTime;
/**
* boolean specifying whether or not Exceptions
* incurred performing the operation should be
* swallowed during execution default true.
*/
private boolean swallowExceptions = true; private boolean swallowExceptions = true;
private String operationName;
/**
* String that is used for logging in final tallies.
*/
private String operationName = "";
/**
* default constructor which sets the value of
* maxPages and maxTime to DEFAULT_MAX_PAGES and
* DEFAULT_MAX_TIME_MILLIS respectively
*/
public AbstractPageOperationTemplate(String operationName){ public AbstractPageOperationTemplate(String operationName){
this(DEFAULT_MAX_PAGES, DEFAULT_MAX_TIME_MILLIS, operationName); this(DEFAULT_MAX_PAGES, DEFAULT_MAX_TIME_MILLIS, operationName);
} }
/**
* Instantiates a new AbstractPageOperationTemplate with the
* given maxPages and maxTime
*
* @param maxPages the maximum number of pages to fetch.
* @param maxTime the maximum execution time.
*/
public AbstractPageOperationTemplate(int maxPages, long maxTime, String operationName){ public AbstractPageOperationTemplate(int maxPages, long maxTime, String operationName){
this.maxPages = maxPages; this.maxPages = maxPages;
this.maxTime = maxTime; this.maxTime = maxTime;
this.operationName = operationName; this.operationName = operationName;
} }
/**
* Execute the operation on each member of a page of results
* retrieved through the fetch method. the method will execute
* until either the maxPages or maxTime limit is reached or until
* the fetch method returns no more results. Exceptions thrown
* performing the operation on the item will be swallowed if the
* swallowException (default true) field is set true.
*/
public void execute(){
logger.debug("[" + getOperationName() + "] Starting execution of paged operation. maximum time: " + maxTime + ", maximum pages: " + maxPages);
long startTime = System.currentTimeMillis();
long executionTime = 0;
int i = 0;
int exceptionsSwallowedCount = 0;
int operationsCompleted = 0;
Set<String> exceptionsSwallowedClasses = new HashSet<String>();
while (i< maxPages && executionTime < maxTime){
Collection<T> page = fetchPage();
if(page == null || page.size() == 0){
break;
}
for (T item : page) {
try {
doOperation(item);
operationsCompleted++;
} catch (Exception e){
if(swallowExceptions){
exceptionsSwallowedCount++;
exceptionsSwallowedClasses.add(e.getClass().getName());
logger.debug("Swallowing exception " + e.getMessage(), e);
} else {
logger.debug("Rethrowing exception " + e.getMessage());
throw e;
}
}
}
i++;
executionTime = System.currentTimeMillis() - startTime;
}
finalReport(operationsCompleted, exceptionsSwallowedCount, exceptionsSwallowedClasses);
}
/**
* method responsible for fetching
* a page of items.
*
* @return the collection of items
*/
public abstract Collection<T> fetchPage();
/**
* method responsible for performing desired
* operation on a fetched page item.
*
* @param item the item
*/
protected abstract void doOperation(T item);
/**
* Method responsible for final report of progress.
* @return
*/
protected void finalReport(int operationsCompleted, int exceptionsSwallowedCount, Set<String> exceptionsSwallowedClasses) {
if (operationsCompleted > 0 || exceptionsSwallowedCount > 0) {
logger.info("[" + getOperationName() + "] Paged operation run: completed " + operationsCompleted + "; swallowed " + exceptionsSwallowedCount + " exceptions");
}
for(String className: exceptionsSwallowedClasses) {
logger.warn("[" + getOperationName() + "] Paged operation swallowed at least one exception of type " + className);
}
}
public int getMaxPages() { public int getMaxPages() {
return maxPages; return maxPages;
} }
@ -193,4 +81,79 @@ public abstract class AbstractPageOperationTemplate<T> {
public void setOperationName(String operationName) { public void setOperationName(String operationName) {
this.operationName = operationName; this.operationName = operationName;
} }
/**
* Execute the operation on each member of a page of results
* retrieved through the fetch method. the method will execute
* until either the maxPages or maxTime limit is reached or until
* the fetch method returns no more results. Exceptions thrown
* performing the operation on the item will be swallowed if the
* swallowException (default true) field is set true.
*/
public void execute(){
logger.debug("[{}] Starting execution of paged operation. max time: {}, max pages: {}", getOperationName(), maxTime, maxPages);
long startTime = System.currentTimeMillis();
long executionTime = 0;
int i = 0;
int exceptionsSwallowedCount = 0;
int operationsCompleted = 0;
Set<String> exceptionsSwallowedClasses = new HashSet<>();
while (i < maxPages && executionTime < maxTime){
Collection<T> page = fetchPage();
if (page == null || page.size() == 0){
break;
}
for (T item : page) {
try {
doOperation(item);
operationsCompleted++;
} catch (Exception e){
if(swallowExceptions){
exceptionsSwallowedCount++;
exceptionsSwallowedClasses.add(e.getClass().getName());
logger.debug("Swallowing exception " + e.getMessage(), e);
} else {
logger.debug("Rethrowing exception " + e.getMessage());
throw e;
}
}
}
i++;
executionTime = System.currentTimeMillis() - startTime;
}
finalReport(operationsCompleted, exceptionsSwallowedCount, exceptionsSwallowedClasses);
}
/**
* Fetch a page of items.
*
* @return the collection of items
*/
public abstract Collection<T> fetchPage();
/**
* Perform operation of fetched page of items.
*
* @param item the item
*/
protected abstract void doOperation(T item);
/**
* Method responsible for final report of progress.
*/
protected void finalReport(int operationsCompleted, int exceptionsSwallowedCount, Set<String> exceptionsSwallowedClasses) {
if (operationsCompleted > 0 || exceptionsSwallowedCount > 0) {
logger.info("[{}] Paged operation run: completed {}; swallowed {} exceptions",
getOperationName(), operationsCompleted, exceptionsSwallowedCount);
}
for(String className: exceptionsSwallowedClasses) {
logger.warn("[{}] Paged operation swallowed at least one exception of type {}", getOperationName(), className);
}
}
} }

View File

@ -26,8 +26,8 @@ public class DefaultPageCriteria implements PageCriteria {
private static final int DEFAULT_PAGE_NUMBER = 0; private static final int DEFAULT_PAGE_NUMBER = 0;
private static final int DEFAULT_PAGE_SIZE = 100; private static final int DEFAULT_PAGE_SIZE = 100;
private int pageNumber; private final int pageNumber;
private int pageSize; private final int pageSize;
public DefaultPageCriteria(){ public DefaultPageCriteria(){
this(DEFAULT_PAGE_NUMBER, DEFAULT_PAGE_SIZE); this(DEFAULT_PAGE_NUMBER, DEFAULT_PAGE_SIZE);

View File

@ -66,7 +66,6 @@ public class WebfingerURLNormalizer {
logger.warn("Can't normalize null or empty URI: " + identifier); logger.warn("Can't normalize null or empty URI: " + identifier);
return null; return null;
} else { } else {
//UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(identifier);
UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
Matcher m = pattern.matcher(identifier); Matcher m = pattern.matcher(identifier);

View File

@ -48,30 +48,20 @@ import com.google.gson.JsonObject;
*/ */
@Component("webfingerView") @Component("webfingerView")
public class WebfingerView extends AbstractView { public class WebfingerView extends AbstractView {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(WebfingerView.class); private static final Logger logger = LoggerFactory.getLogger(WebfingerView.class);
private Gson gson = new GsonBuilder() private final Gson gson = new GsonBuilder()
.setExclusionStrategies(new ExclusionStrategy() { .setExclusionStrategies(new ExclusionStrategy() {
@Override @Override
public boolean shouldSkipField(FieldAttributes f) { public boolean shouldSkipField(FieldAttributes f) {
return false; return false;
} }
@Override @Override
public boolean shouldSkipClass(Class<?> clazz) { public boolean shouldSkipClass(Class<?> clazz) {
// skip the JPA binding wrapper // skip the JPA binding wrapper
if (clazz.equals(BeanPropertyBindingResult.class)) { return clazz.equals(BeanPropertyBindingResult.class);
return true;
}
return false;
} }
}) })
.serializeNulls() .serializeNulls()
.setDateFormat("yyyy-MM-dd'T'HH:mm:ssZ") .setDateFormat("yyyy-MM-dd'T'HH:mm:ssZ")
@ -79,21 +69,17 @@ public class WebfingerView extends AbstractView {
@Override @Override
protected void renderMergedOutputModel(Map<String, Object> model, HttpServletRequest request, HttpServletResponse response) { protected void renderMergedOutputModel(Map<String, Object> model, HttpServletRequest request, HttpServletResponse response) {
response.setContentType("application/jrd+json"); response.setContentType("application/jrd+json");
HttpStatus code = (HttpStatus) model.get(HttpCodeView.CODE); HttpStatus code = (HttpStatus) model.get(HttpCodeView.CODE);
if (code == null) { if (code == null) {
code = HttpStatus.OK; // default to 200 code = HttpStatus.OK;
} }
response.setStatus(code.value()); response.setStatus(code.value());
try { try {
String resource = (String) model.get("resource");
String resource = (String)model.get("resource"); String issuer = (String) model.get("issuer");
String issuer = (String)model.get("issuer");
JsonObject obj = new JsonObject(); JsonObject obj = new JsonObject();
obj.addProperty("subject", resource); obj.addProperty("subject", resource);
@ -108,11 +94,8 @@ public class WebfingerView extends AbstractView {
Writer out = response.getWriter(); Writer out = response.getWriter();
gson.toJson(obj, out); gson.toJson(obj, out);
} catch (IOException e) { } catch (IOException e) {
logger.error("IOException in WebfingerView.java: ", e);
logger.error("IOException in JsonEntityView.java: ", e);
} }
} }

View File

@ -68,88 +68,61 @@ import com.nimbusds.jose.JWSAlgorithm;
@Controller @Controller
public class DiscoveryEndpoint { public class DiscoveryEndpoint {
private static final Logger logger = LoggerFactory.getLogger(DiscoveryEndpoint.class);
public static final String WELL_KNOWN_URL = ".well-known"; public static final String WELL_KNOWN_URL = ".well-known";
public static final String OPENID_CONFIGURATION_URL = WELL_KNOWN_URL + "/openid-configuration"; public static final String OPENID_CONFIGURATION_URL = WELL_KNOWN_URL + "/openid-configuration";
public static final String WEBFINGER_URL = WELL_KNOWN_URL + "/webfinger"; public static final String WEBFINGER_URL = WELL_KNOWN_URL + "/webfinger";
private static final String ISSUER_STRING = "http://openid.net/specs/connect/1.0/issuer";
/** private final ConfigurationPropertiesBean config;
* Logger for this class private final SystemScopeService scopeService;
*/ private final JWTSigningAndValidationService signService;
private static final Logger logger = LoggerFactory.getLogger(DiscoveryEndpoint.class); private final JWTEncryptionAndDecryptionService encService;
private final UserInfoService userService;
@Autowired
private ConfigurationPropertiesBean config;
@Autowired
private SystemScopeService scopeService;
@Autowired
private JWTSigningAndValidationService signService;
@Autowired
private JWTEncryptionAndDecryptionService encService;
@Autowired
private UserInfoService userService;
// used to map JWA algorithms objects to strings // used to map JWA algorithms objects to strings
private Function<Algorithm, String> toAlgorithmName = new Function<Algorithm, String>() { private final Function<Algorithm, String> toAlgorithmName = alg -> alg == null ? null : alg.getName();
@Override
public String apply(Algorithm alg) {
if (alg == null) {
return null;
} else {
return alg.getName();
}
}
};
@RequestMapping(value={"/" + WEBFINGER_URL}, produces = MediaType.APPLICATION_JSON_VALUE) @Autowired
public String webfinger(@RequestParam("resource") String resource, @RequestParam(value = "rel", required = false) String rel, Model model) { public DiscoveryEndpoint(UserInfoService userService, ConfigurationPropertiesBean config,
SystemScopeService scopeService, JWTSigningAndValidationService signService,
JWTEncryptionAndDecryptionService encService) {
this.userService = userService;
this.config = config;
this.scopeService = scopeService;
this.signService = signService;
this.encService = encService;
}
if (!Strings.isNullOrEmpty(rel) && !rel.equals("http://openid.net/specs/connect/1.0/issuer")) { @RequestMapping(value = '/' + WEBFINGER_URL, produces = MediaType.APPLICATION_JSON_VALUE)
logger.warn("Responding to webfinger request for non-OIDC relation: " + rel); public String webfinger(@RequestParam("resource") String resource,
@RequestParam(value = "rel", required = false) String rel,
Model model) {
if (!Strings.isNullOrEmpty(rel) && !rel.equals(ISSUER_STRING)) {
logger.warn("Responding to webfinger request for non-OIDC relation: {}", rel);
} }
if (!resource.equals(config.getIssuer())) { if (!resource.equals(config.getIssuer())) {
// it's not the issuer directly, need to check other methods // it's not the issuer directly, need to check other methods
UriComponents resourceUri = WebfingerURLNormalizer.normalizeResource(resource); UriComponents resourceUri = WebfingerURLNormalizer.normalizeResource(resource);
if (resourceUri != null if (resourceUri != null
&& resourceUri.getScheme() != null && resourceUri.getScheme() != null
&& resourceUri.getScheme().equals("acct")) { && resourceUri.getScheme().equals("acct")) {
// acct: URI (email address format) UserInfo user = extractUser(resourceUri);
// check on email addresses first
UserInfo user = userService.getByEmailAddress(resourceUri.getUserInfo() + "@" + resourceUri.getHost());
if (user == null) { if (user == null) {
// user wasn't found, see if the local part of the username matches, plus our issuer host logger.info("User not found: {}", resource);
model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
user = userService.getByUsername(resourceUri.getUserInfo()); // first part is the username return HttpCodeView.VIEWNAME;
if (user != null) {
// username matched, check the host component
UriComponents issuerComponents = UriComponentsBuilder.fromHttpUrl(config.getIssuer()).build();
if (!Strings.nullToEmpty(issuerComponents.getHost())
.equals(Strings.nullToEmpty(resourceUri.getHost()))) {
logger.info("Host mismatch, expected " + issuerComponents.getHost() + " got " + resourceUri.getHost());
model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
return HttpCodeView.VIEWNAME;
}
} else {
// if the user's still null, punt and say we didn't find them
logger.info("User not found: " + resource);
model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
return HttpCodeView.VIEWNAME;
}
} }
UriComponents issuerComponents = UriComponentsBuilder.fromHttpUrl(config.getIssuer()).build();
if (!Strings.nullToEmpty(issuerComponents.getHost())
.equals(Strings.nullToEmpty(resourceUri.getHost()))) {
logger.info("Host mismatch, expected " + issuerComponents.getHost() + " got " + resourceUri.getHost());
model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
return HttpCodeView.VIEWNAME;
}
} else { } else {
logger.info("Unknown URI format: " + resource); logger.info("Unknown URI format: " + resource);
model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND); model.addAttribute(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
@ -157,13 +130,20 @@ public class DiscoveryEndpoint {
} }
} }
// if we got here, then we're good, return ourselves
model.addAttribute("resource", resource); model.addAttribute("resource", resource);
model.addAttribute("issuer", config.getIssuer()); model.addAttribute("issuer", config.getIssuer());
return "webfingerView"; return "webfingerView";
} }
private UserInfo extractUser(UriComponents resourceUri) {
UserInfo user = userService.getByEmailAddress(resourceUri.getUserInfo() + "@" + resourceUri.getHost());
if (user == null) {
user = userService.getByUsername(resourceUri.getUserInfo()); // first part is the username
}
return user;
}
@RequestMapping("/" + OPENID_CONFIGURATION_URL) @RequestMapping("/" + OPENID_CONFIGURATION_URL)
public String providerConfiguration(Model model) { public String providerConfiguration(Model model) {

View File

@ -91,11 +91,9 @@ public class JWKSetKeyStore {
throw new IllegalArgumentException("Key Set resource could not be read: " + location); throw new IllegalArgumentException("Key Set resource could not be read: " + location);
} catch (ParseException e) { } catch (ParseException e) {
throw new IllegalArgumentException("Key Set resource could not be parsed: " + location); } throw new IllegalArgumentException("Key Set resource could not be parsed: " + location); }
} else { } else {
throw new IllegalArgumentException("Key Set resource could not be read: " + location); throw new IllegalArgumentException("Key Set resource could not be read: " + location);
} }
} }
} }

View File

@ -10,7 +10,7 @@ import java.text.ParseException;
public abstract class AbstractAssertionValidator implements AssertionValidator { public abstract class AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(AbstractAssertionValidator.class); private static final Logger logger = LoggerFactory.getLogger(AbstractAssertionValidator.class);
/** /**
* Extract issuer from claims present in JWT assertion. * Extract issuer from claims present in JWT assertion.

View File

@ -37,7 +37,7 @@ import org.springframework.util.StringUtils;
@Component("selfAssertionValidator") @Component("selfAssertionValidator")
public class SelfAssertionValidator extends AbstractAssertionValidator implements AssertionValidator { public class SelfAssertionValidator extends AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(SelfAssertionValidator.class); private static final Logger logger = LoggerFactory.getLogger(SelfAssertionValidator.class);
private final ConfigurationPropertiesBean config; private final ConfigurationPropertiesBean config;
private final JWTSigningAndValidationService jwtService; private final JWTSigningAndValidationService jwtService;

View File

@ -35,12 +35,9 @@ import java.util.Map;
*/ */
public class WhitelistedIssuerAssertionValidator extends AbstractAssertionValidator implements AssertionValidator { public class WhitelistedIssuerAssertionValidator extends AbstractAssertionValidator implements AssertionValidator {
private static Logger logger = LoggerFactory.getLogger(WhitelistedIssuerAssertionValidator.class); private static final Logger logger = LoggerFactory.getLogger(WhitelistedIssuerAssertionValidator.class);
/** private Map<String, String> whitelist = new HashMap<>(); //Map of issuer -> JWKSetUri
* Map of issuer -> JWKSetUri
*/
private Map<String, String> whitelist = new HashMap<>();
private JWKSetCacheService jwkCache; private JWKSetCacheService jwkCache;
public Map<String, String> getWhitelist() { public Map<String, String> getWhitelist() {

View File

@ -17,8 +17,6 @@
*******************************************************************************/ *******************************************************************************/
package org.mitre.jwt.encryption.service.impl; package org.mitre.jwt.encryption.service.impl;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -27,6 +25,7 @@ import java.util.Set;
import javax.annotation.PostConstruct; import javax.annotation.PostConstruct;
import com.nimbusds.jose.KeyLengthException;
import org.mitre.jose.keystore.JWKSetKeyStore; import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -58,8 +57,8 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class); private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class);
private Map<String, JWEEncrypter> encrypters = new HashMap<>(); private final Map<String, JWEEncrypter> encrypters = new HashMap<>();
private Map<String, JWEDecrypter> decrypters = new HashMap<>(); private final Map<String, JWEDecrypter> decrypters = new HashMap<>();
private String defaultEncryptionKeyId; private String defaultEncryptionKeyId;
private String defaultDecryptionKeyId; private String defaultDecryptionKeyId;
private JWEAlgorithm defaultAlgorithm; private JWEAlgorithm defaultAlgorithm;
@ -233,41 +232,53 @@ public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAn
JWK jwk = jwkEntry.getValue(); JWK jwk = jwkEntry.getValue();
if (jwk instanceof RSAKey) { if (jwk instanceof RSAKey) {
RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk); handleRSAKey(id, jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt!
RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
} else if (jwk instanceof ECKey) { } else if (jwk instanceof ECKey) {
ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk); handleECKey(id, jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt too
ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
} else if (jwk instanceof OctetSequenceKey) { } else if (jwk instanceof OctetSequenceKey) {
DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk); handleOctetSeqKey(id, jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
decrypters.put(id, decrypter);
} else { } else {
logger.warn("Unknown key type: {}", jwk); logger.warn("Unknown key type: {}", jwk);
} }
} }
} }
private void handleOctetSeqKey(String id, JWK jwk) throws KeyLengthException {
DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
decrypters.put(id, decrypter);
}
private void handleECKey(String id, JWK jwk) throws JOSEException {
ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt too
ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
}
private void handleRSAKey(String id, JWK jwk) throws JOSEException {
RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk);
encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
encrypters.put(id, encrypter);
if (jwk.isPrivate()) { // we can decrypt!
RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk);
decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
decrypters.put(id, decrypter);
} else {
logger.warn("No private key for key #{}", jwk.getKeyID());
}
}
} }

View File

@ -27,17 +27,17 @@ import com.nimbusds.jwt.SignedJWT;
public interface JWTSigningAndValidationService { public interface JWTSigningAndValidationService {
/**
* Get all public keys for this service, mapped by their Key ID
*/
Map<String, JWK> getAllPublicKeys(); Map<String, JWK> getAllPublicKeys();
JWSAlgorithm getDefaultSigningAlgorithm();
Collection<JWSAlgorithm> getAllSigningAlgsSupported();
/** /**
* Checks the signature of the given JWT against all configured signers, * Checks the signature of the given JWT against all configured signers,
* returns true if at least one of the signers validates it. * returns true if at least one of the signers validates it.
* *
* @param jwtString * @param jwtString the string representation of the JWT as sent on the wire
* the string representation of the JWT as sent on the wire
* @return true if the signature is valid, false if not * @return true if the signature is valid, false if not
* @throws NoSuchAlgorithmException * @throws NoSuchAlgorithmException
*/ */
@ -53,18 +53,6 @@ public interface JWTSigningAndValidationService {
*/ */
void signJwt(SignedJWT jwt); void signJwt(SignedJWT jwt);
/**
* Get the default signing algorithm for use when nothing else has been specified.
* @return
*/
JWSAlgorithm getDefaultSigningAlgorithm();
/**
* Get the list of all signing algorithms supported by this service.
* @return
*/
Collection<JWSAlgorithm> getAllSigningAlgsSupported();
/** /**
* Sign a jwt using the selected algorithm. The algorithm is selected using the String parameter values specified * Sign a jwt using the selected algorithm. The algorithm is selected using the String parameter values specified
* in the JWT spec, section 6. I.E., "HS256" means HMAC with SHA-256 and corresponds to our HmacSigner class. * in the JWT spec, section 6. I.E., "HS256" means HMAC with SHA-256 and corresponds to our HmacSigner class.

View File

@ -17,7 +17,6 @@
package org.mitre.jwt.signer.service.impl; package org.mitre.jwt.signer.service.impl;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -99,7 +98,7 @@ public class ClientKeyCacheService {
return null; return null;
} }
} else if (symmetric.contains(alg)) { } else if (symmetric.contains(alg)) {
return symmetricCache.getSymmetricValidtor(client); return symmetricCache.getSymmetricValidator(client);
} else { } else {
return null; return null;
} }

View File

@ -18,6 +18,7 @@
package org.mitre.jwt.signer.service.impl; package org.mitre.jwt.signer.service.impl;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSProvider; import com.nimbusds.jose.JWSProvider;
import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.JWSSigner;
@ -50,8 +51,9 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
private static final Logger logger = LoggerFactory.getLogger(DefaultJWTSigningAndValidationService.class); private static final Logger logger = LoggerFactory.getLogger(DefaultJWTSigningAndValidationService.class);
private Map<String, JWSSigner> signers = new HashMap<>(); private final Map<String, JWSSigner> signers = new HashMap<>();
private Map<String, JWSVerifier> verifiers = new HashMap<>(); private final Map<String, JWSVerifier> verifiers = new HashMap<>();
private String defaultSignerKeyId; private String defaultSignerKeyId;
private JWSAlgorithm defaultAlgorithm; private JWSAlgorithm defaultAlgorithm;
private Map<String, JWK> keys = new HashMap<>(); private Map<String, JWK> keys = new HashMap<>();
@ -113,9 +115,6 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
} }
} }
/**
* Sign a jwt in place using the configured default signer.
*/
@Override @Override
public void signJwt(SignedJWT jwt) { public void signJwt(SignedJWT jwt) {
if (getDefaultSignerKeyId() == null) { if (getDefaultSignerKeyId() == null) {
@ -143,8 +142,7 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
} }
if (signer == null) { if (signer == null) {
//If we can't find an algorithm that matches, we can't sign logger.error("No matching algorithm found for alg={}", alg);
logger.error("No matching algirthm found for alg=" + alg);
} else { } else {
try { try {
jwt.sign(signer); jwt.sign(signer);
@ -158,9 +156,7 @@ public class DefaultJWTSigningAndValidationService implements JWTSigningAndValid
public boolean validateSignature(SignedJWT jwt) { public boolean validateSignature(SignedJWT jwt) {
for (JWSVerifier verifier : verifiers.values()) { for (JWSVerifier verifier : verifiers.values()) {
try { try {
if (jwt.verify(verifier)) { return jwt.verify(verifier);
return true;
}
} catch (JOSEException e) { } catch (JOSEException e) {
logger.error("Failed to validate signature with {} error message: {}", verifier, e.getMessage()); logger.error("Failed to validate signature with {} error message: {}", verifier, e.getMessage());
} }

View File

@ -54,8 +54,8 @@ public class JWKSetCacheService {
private static final Logger logger = LoggerFactory.getLogger(JWKSetCacheService.class); private static final Logger logger = LoggerFactory.getLogger(JWKSetCacheService.class);
private LoadingCache<String, JWTSigningAndValidationService> validators; private final LoadingCache<String, JWTSigningAndValidationService> validators;
private LoadingCache<String, JWTEncryptionAndDecryptionService> encrypters; private final LoadingCache<String, JWTEncryptionAndDecryptionService> encrypters;
public JWKSetCacheService() { public JWKSetCacheService() {
this.validators = CacheBuilder.newBuilder() this.validators = CacheBuilder.newBuilder()
@ -68,16 +68,11 @@ public class JWKSetCacheService {
.build(new JWKSetEncryptorFetcher(HttpClientBuilder.create().useSystemProperties().build())); .build(new JWKSetEncryptorFetcher(HttpClientBuilder.create().useSystemProperties().build()));
} }
/**
* @param jwksUri
* @return
* @throws ExecutionException
*/
public JWTSigningAndValidationService getValidator(String jwksUri) { public JWTSigningAndValidationService getValidator(String jwksUri) {
try { try {
return validators.get(jwksUri); return validators.get(jwksUri);
} catch (UncheckedExecutionException | ExecutionException e) { } catch (UncheckedExecutionException | ExecutionException e) {
logger.warn("Couldn't load JWK Set from " + jwksUri + ": " + e.getMessage()); logger.warn("Couldn't load JWK Set from {}: {}", jwksUri, e.getMessage());
return null; return null;
} }
} }
@ -86,13 +81,13 @@ public class JWKSetCacheService {
try { try {
return encrypters.get(jwksUri); return encrypters.get(jwksUri);
} catch (UncheckedExecutionException | ExecutionException e) { } catch (UncheckedExecutionException | ExecutionException e) {
logger.warn("Couldn't load JWK Set from " + jwksUri + ": " + e.getMessage()); logger.warn("Couldn't load JWK Set from {}: {}", jwksUri, e.getMessage());
return null; return null;
} }
} }
private static class JWKSetVerifierFetcher extends CacheLoader<String, JWTSigningAndValidationService> { private static class JWKSetVerifierFetcher extends CacheLoader<String, JWTSigningAndValidationService> {
private RestTemplate restTemplate; private final RestTemplate restTemplate;
JWKSetVerifierFetcher(HttpClient httpClient) { JWKSetVerifierFetcher(HttpClient httpClient) {
HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient); HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
@ -103,15 +98,13 @@ public class JWKSetCacheService {
public JWTSigningAndValidationService load(String key) throws Exception { public JWTSigningAndValidationService load(String key) throws Exception {
String jsonString = restTemplate.getForObject(key, String.class); String jsonString = restTemplate.getForObject(key, String.class);
JWKSet jwkSet = JWKSet.parse(jsonString); JWKSet jwkSet = JWKSet.parse(jsonString);
JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet); JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet);
return new DefaultJWTSigningAndValidationService(keyStore); return new DefaultJWTSigningAndValidationService(keyStore);
} }
} }
private static class JWKSetEncryptorFetcher extends CacheLoader<String, JWTEncryptionAndDecryptionService> { private static class JWKSetEncryptorFetcher extends CacheLoader<String, JWTEncryptionAndDecryptionService> {
private RestTemplate restTemplate; private final RestTemplate restTemplate;
public JWKSetEncryptorFetcher(HttpClient httpClient) { public JWKSetEncryptorFetcher(HttpClient httpClient) {
HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient); HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
@ -123,9 +116,7 @@ public class JWKSetCacheService {
try { try {
String jsonString = restTemplate.getForObject(key, String.class); String jsonString = restTemplate.getForObject(key, String.class);
JWKSet jwkSet = JWKSet.parse(jsonString); JWKSet jwkSet = JWKSet.parse(jsonString);
JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet); JWKSetKeyStore keyStore = new JWKSetKeyStore(jwkSet);
return new DefaultJWTEncryptionAndDecryptionService(keyStore); return new DefaultJWTEncryptionAndDecryptionService(keyStore);
} catch (JsonParseException | RestClientException e) { } catch (JsonParseException | RestClientException e) {
throw new IllegalArgumentException("Unable to load JWK Set"); throw new IllegalArgumentException("Unable to load JWK Set");

View File

@ -45,7 +45,7 @@ public class SymmetricKeyJWTValidatorCacheService {
private static final Logger logger = LoggerFactory.getLogger(SymmetricKeyJWTValidatorCacheService.class); private static final Logger logger = LoggerFactory.getLogger(SymmetricKeyJWTValidatorCacheService.class);
private LoadingCache<String, JWTSigningAndValidationService> validators; private final LoadingCache<String, JWTSigningAndValidationService> validators;
public SymmetricKeyJWTValidatorCacheService() { public SymmetricKeyJWTValidatorCacheService() {
validators = CacheBuilder.newBuilder() validators = CacheBuilder.newBuilder()
@ -54,13 +54,11 @@ public class SymmetricKeyJWTValidatorCacheService {
.build(new SymmetricValidatorBuilder()); .build(new SymmetricValidatorBuilder());
} }
public JWTSigningAndValidationService getSymmetricValidtor(ClientDetailsEntity client) { public JWTSigningAndValidationService getSymmetricValidator(ClientDetailsEntity client) {
if (client == null) { if (client == null) {
logger.error("Couldn't create symmetric validator for null client"); logger.error("Couldn't create symmetric validator for null client");
return null; return null;
} } else if (StringUtils.isEmpty(client.getClientSecret())) {
if (StringUtils.isEmpty(client.getClientSecret())) {
logger.error("Couldn't create symmetric validator for client {} without a client secret", client.getClientId()); logger.error("Couldn't create symmetric validator for client {} without a client secret", client.getClientId());
return null; return null;
} }
@ -76,7 +74,6 @@ public class SymmetricKeyJWTValidatorCacheService {
public static class SymmetricValidatorBuilder extends CacheLoader<String, JWTSigningAndValidationService> { public static class SymmetricValidatorBuilder extends CacheLoader<String, JWTSigningAndValidationService> {
@Override @Override
public JWTSigningAndValidationService load(String key) { public JWTSigningAndValidationService load(String key) {
String id = "SYMMETRIC-KEY"; String id = "SYMMETRIC-KEY";
JWK jwk = new OctetSequenceKey.Builder(Base64URL.encode(key)) JWK jwk = new OctetSequenceKey.Builder(Base64URL.encode(key))
.keyUse(KeyUse.SIGNATURE) .keyUse(KeyUse.SIGNATURE)

View File

@ -26,7 +26,6 @@ import com.nimbusds.jwt.JWT;
* Take in an assertion and token request and generate an OAuth2Request from it, including scopes and other important components * Take in an assertion and token request and generate an OAuth2Request from it, including scopes and other important components
* *
* @author jricher * @author jricher
*
*/ */
public interface AssertionOAuth2RequestFactory { public interface AssertionOAuth2RequestFactory {

View File

@ -36,27 +36,23 @@ import com.nimbusds.jwt.JWTClaimsSet;
* - aud, array of audience IDs * - aud, array of audience IDs
* *
* @author jricher * @author jricher
*
*/ */
public class DirectCopyRequestFactory implements AssertionOAuth2RequestFactory { public class DirectCopyRequestFactory implements AssertionOAuth2RequestFactory {
/* (non-Javadoc)
* @see org.mitre.oauth2.assertion.AssertionOAuth2RequestFactory#createOAuth2Request(org.springframework.security.oauth2.provider.ClientDetails, org.springframework.security.oauth2.provider.TokenRequest, com.nimbusds.jwt.JWT)
*/
@Override @Override
public OAuth2Request createOAuth2Request(ClientDetails client, TokenRequest tokenRequest, JWT assertion) { public OAuth2Request createOAuth2Request(ClientDetails client, TokenRequest tokenRequest, JWT assertion) {
try { try {
JWTClaimsSet claims = assertion.getJWTClaimsSet(); JWTClaimsSet claims = assertion.getJWTClaimsSet();
Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope")); Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope"));
Set<String> resources = Sets.newHashSet(claims.getAudience()); Set<String> resources = Sets.newHashSet(claims.getAudience());
return new OAuth2Request(tokenRequest.getRequestParameters(), client.getClientId(), client.getAuthorities(), true, scope, resources, null, null, null); return new OAuth2Request(tokenRequest.getRequestParameters(), client.getClientId(),
client.getAuthorities(), true, scope, resources, null,
null, null);
} catch (ParseException e) { } catch (ParseException e) {
return null; return null;
} }
} }
} }

View File

@ -20,30 +20,18 @@ import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
/** /**
* @author jricher * @author jricher
*
*/ */
public class AuthorizationPendingException extends OAuth2Exception { public class AuthorizationPendingException extends OAuth2Exception {
/** private static final long serialVersionUID = -7078098692596870940L;
* @param msg
*/
public AuthorizationPendingException(String msg) { public AuthorizationPendingException(String msg) {
super(msg); super(msg);
} }
/**
*
*/
private static final long serialVersionUID = -7078098692596870940L;
/* (non-Javadoc)
* @see org.springframework.security.oauth2.common.exceptions.OAuth2Exception#getOAuth2ErrorCode()
*/
@Override @Override
public String getOAuth2ErrorCode() { public String getOAuth2ErrorCode() {
return "authorization_pending"; return "authorization_pending";
} }
} }

View File

@ -24,21 +24,12 @@ import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
*/ */
public class DeviceCodeExpiredException extends OAuth2Exception { public class DeviceCodeExpiredException extends OAuth2Exception {
/** private static final long serialVersionUID = -7078098692596870940L;
* @param msg
*/
public DeviceCodeExpiredException(String msg) { public DeviceCodeExpiredException(String msg) {
super(msg); super(msg);
} }
/**
*
*/
private static final long serialVersionUID = -7078098692596870940L;
/* (non-Javadoc)
* @see org.springframework.security.oauth2.common.exceptions.OAuth2Exception#getOAuth2ErrorCode()
*/
@Override @Override
public String getOAuth2ErrorCode() { public String getOAuth2ErrorCode() {
return "expired_token"; return "expired_token";

View File

@ -19,14 +19,10 @@ package org.mitre.oauth2.exception;
public class DuplicateClientIdException extends RuntimeException { public class DuplicateClientIdException extends RuntimeException {
private static final long serialVersionUID = 1L;
public DuplicateClientIdException(String clientId) { public DuplicateClientIdException(String clientId) {
super("Duplicate client id: " + clientId); super("Duplicate client id: " + clientId);
} }
/**
*
*/
private static final long serialVersionUID = 1L;
} }

View File

@ -50,8 +50,6 @@ import org.mitre.oauth2.model.convert.JWTStringConverter;
import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.ApprovedSite;
import org.mitre.uma.model.Permission; import org.mitre.uma.model.Permission;
import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson1Deserializer;
import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson1Serializer;
import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson2Deserializer; import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson2Deserializer;
import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson2Serializer; import org.springframework.security.oauth2.common.OAuth2AccessTokenJackson2Serializer;
import org.springframework.security.oauth2.common.OAuth2RefreshToken; import org.springframework.security.oauth2.common.OAuth2RefreshToken;
@ -74,8 +72,6 @@ import com.nimbusds.jwt.JWT;
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID),
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2AccessTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2AccessTokenEntity.PARAM_NAME) @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2AccessTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2AccessTokenEntity.PARAM_NAME)
}) })
@org.codehaus.jackson.map.annotate.JsonSerialize(using = OAuth2AccessTokenJackson1Serializer.class)
@org.codehaus.jackson.map.annotate.JsonDeserialize(using = OAuth2AccessTokenJackson1Deserializer.class)
@com.fasterxml.jackson.databind.annotation.JsonSerialize(using = OAuth2AccessTokenJackson2Serializer.class) @com.fasterxml.jackson.databind.annotation.JsonSerialize(using = OAuth2AccessTokenJackson2Serializer.class)
@com.fasterxml.jackson.databind.annotation.JsonDeserialize(using = OAuth2AccessTokenJackson2Deserializer.class) @com.fasterxml.jackson.databind.annotation.JsonDeserialize(using = OAuth2AccessTokenJackson2Deserializer.class)
public class OAuth2AccessTokenEntity implements OAuth2AccessToken { public class OAuth2AccessTokenEntity implements OAuth2AccessToken {

View File

@ -21,13 +21,28 @@ package org.mitre.oauth2.service.impl;
import org.mitre.openid.connect.config.ConfigurationPropertiesBean; import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
import org.mitre.openid.connect.service.BlacklistedSiteService; import org.mitre.openid.connect.service.BlacklistedSiteService;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException; import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.common.exceptions.RedirectMismatchException;
import org.springframework.security.oauth2.provider.ClientDetails; import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.endpoint.DefaultRedirectResolver; import org.springframework.security.oauth2.provider.endpoint.DefaultRedirectResolver;
import org.springframework.security.oauth2.provider.endpoint.RedirectResolver;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
/** /**
* *
@ -38,7 +53,7 @@ import com.google.common.base.Strings;
* *
*/ */
@Component("blacklistAwareRedirectResolver") @Component("blacklistAwareRedirectResolver")
public class BlacklistAwareRedirectResolver extends DefaultRedirectResolver { public class BlacklistAwareRedirectResolver implements RedirectResolver {
@Autowired @Autowired
private BlacklistedSiteService blacklistService; private BlacklistedSiteService blacklistService;
@ -46,37 +61,36 @@ public class BlacklistAwareRedirectResolver extends DefaultRedirectResolver {
@Autowired @Autowired
private ConfigurationPropertiesBean config; private ConfigurationPropertiesBean config;
private Collection<String> redirectGrantTypes = Arrays.asList("implicit", "authorization_code");
private boolean matchSubdomains = false;
private boolean matchPorts = true;
private boolean strictMatch = true; private boolean strictMatch = true;
/* (non-Javadoc) /**
* @see org.springframework.security.oauth2.provider.endpoint.RedirectResolver#resolveRedirect(java.lang.String, org.springframework.security.oauth2.provider.ClientDetails) * Flag to indicate that requested URIs will match if they are a subdomain of the registered value.
*
* @param matchSubdomains the flag value to set (default true)
*/ */
@Override public void setMatchSubdomains(boolean matchSubdomains) {
public String resolveRedirect(String requestedRedirect, ClientDetails client) throws OAuth2Exception { this.matchSubdomains = matchSubdomains;
String redirect = super.resolveRedirect(requestedRedirect, client);
if (blacklistService.isBlacklisted(redirect)) {
// don't let it go through
throw new InvalidRequestException("The supplied redirect_uri is not allowed on this server.");
} else {
// not blacklisted, passed the parent test, we're fine
return redirect;
}
} }
/* (non-Javadoc) /**
* @see org.springframework.security.oauth2.provider.endpoint.DefaultRedirectResolver#redirectMatches(java.lang.String, java.lang.String) * Flag that enables/disables port matching between the requested redirect URI and the registered redirect URI(s).
*
* @param matchPorts true to enable port matching, false to disable (defaults to true)
*/ */
@Override public void setMatchPorts(boolean matchPorts) {
protected boolean redirectMatches(String requestedRedirect, String redirectUri) { this.matchPorts = matchPorts;
}
if (isStrictMatch()) {
// we're doing a strict string match for all clients
return Strings.nullToEmpty(requestedRedirect).equals(redirectUri);
} else {
// otherwise do the prefix-match from the library
return super.redirectMatches(requestedRedirect, redirectUri);
}
/**
* Grant types that are permitted to have a redirect uri.
*
* @param redirectGrantTypes the redirect grant types to set
*/
public void setRedirectGrantTypes(Collection<String> redirectGrantTypes) {
this.redirectGrantTypes = new HashSet<String>(redirectGrantTypes);
} }
/** /**
@ -100,6 +114,172 @@ public class BlacklistAwareRedirectResolver extends DefaultRedirectResolver {
this.strictMatch = strictMatch; this.strictMatch = strictMatch;
} }
/* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.endpoint.RedirectResolver#resolveRedirect(java.lang.String, org.springframework.security.oauth2.provider.ClientDetails)
*/
@Override
public String resolveRedirect(String requestedRedirect, ClientDetails client) throws OAuth2Exception {
Set<String> authorizedGrantTypes = client.getAuthorizedGrantTypes();
if (authorizedGrantTypes.isEmpty()) {
throw new InvalidGrantException("A client must have at least one authorized grant type.");
}
if (!containsRedirectGrantType(authorizedGrantTypes)) {
throw new InvalidGrantException(
"A redirect_uri can only be used by implicit or authorization_code grant types.");
}
Set<String> registeredRedirectUris = client.getRegisteredRedirectUri();
if (registeredRedirectUris == null || registeredRedirectUris.isEmpty()) {
throw new InvalidRequestException("At least one redirect_uri must be registered with the client.");
}
String redirect = obtainMatchingRedirect(registeredRedirectUris, requestedRedirect);
if (blacklistService.isBlacklisted(redirect)) {
// don't let it go through
throw new InvalidRequestException("The supplied redirect_uri is not allowed on this server.");
} else {
// not blacklisted, passed the parent test, we're fine
return redirect;
}
}
/**
* Whether the requested redirect URI "matches" the specified redirect URI. For a URL, this implementation tests if
* the user requested redirect starts with the registered redirect, so it would have the same host and root path if
* it is an HTTP URL. The port, userinfo, query params also matched. Request redirect uri path can include
* additional parameters which are ignored for the match
* <p>
* For other (non-URL) cases, such as for some implicit clients, the redirect_uri must be an exact match.
*
* @param requestedRedirect The requested redirect URI.
* @param redirectUri The registered redirect URI.
* @return Whether the requested redirect URI "matches" the specified redirect URI.
*/
protected boolean redirectMatches(String requestedRedirect, String redirectUri) {
UriComponents requestedRedirectUri = UriComponentsBuilder.fromUriString(requestedRedirect).build();
UriComponents registeredRedirectUri = UriComponentsBuilder.fromUriString(redirectUri).build();
boolean schemeMatch = isEqual(registeredRedirectUri.getScheme(), requestedRedirectUri.getScheme());
boolean userInfoMatch = isEqual(registeredRedirectUri.getUserInfo(), requestedRedirectUri.getUserInfo());
boolean hostMatch = hostMatches(registeredRedirectUri.getHost(), requestedRedirectUri.getHost());
boolean portMatch = !matchPorts || registeredRedirectUri.getPort() == requestedRedirectUri.getPort();
boolean pathMatch = true;
boolean queryParamMatch = true;
if (strictMatch) {
pathMatch = isEqual(registeredRedirectUri.getPath(),
StringUtils.cleanPath(requestedRedirectUri.getPath()));
queryParamMatch = matchQueryParams(registeredRedirectUri.getQueryParams(),
requestedRedirectUri.getQueryParams());
}
return schemeMatch && userInfoMatch && hostMatch && portMatch && pathMatch && queryParamMatch;
}
/**
* @param grantTypes some grant types
* @return true if the supplied grant types includes one or more of the redirect types
*/
private boolean containsRedirectGrantType(Set<String> grantTypes) {
for (String type : grantTypes) {
if (redirectGrantTypes.contains(type)) {
return true;
}
}
return false;
}
/**
* Attempt to match one of the registered URIs to the that of the requested one.
*
* @param redirectUris the set of the registered URIs to try and find a match. This cannot be null or empty.
* @param requestedRedirect the URI used as part of the request
* @return redirect uri
* @throws RedirectMismatchException if no match was found
*/
private String obtainMatchingRedirect(Set<String> redirectUris, String requestedRedirect) {
Assert.notEmpty(redirectUris, "Redirect URIs cannot be empty");
if (redirectUris.size() == 1 && requestedRedirect == null) {
return redirectUris.iterator().next();
}
for (String redirectUri : redirectUris) {
if (requestedRedirect != null && redirectMatches(requestedRedirect, redirectUri)) {
// Initialize with the registered redirect-uri
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri);
UriComponents requestedRedirectUri = UriComponentsBuilder.fromUriString(requestedRedirect).build();
if (this.matchSubdomains) {
redirectUriBuilder.host(requestedRedirectUri.getHost());
}
if (!this.matchPorts) {
redirectUriBuilder.port(requestedRedirectUri.getPort());
}
if (!this.strictMatch) {
redirectUriBuilder.path(requestedRedirectUri.getPath());
}
redirectUriBuilder.replaceQuery(requestedRedirectUri.getQuery()); // retain additional params (if any)
redirectUriBuilder.fragment(null);
return redirectUriBuilder.build().toUriString();
}
}
throw new RedirectMismatchException("Invalid redirect: " + requestedRedirect
+ " does not match one of the registered values.");
}
/**
* Compares two strings but treats empty string or null equal
*
* @param str1
* @param str2
* @return true if strings are equal, false otherwise
*/
private boolean isEqual(String str1, String str2) {
if (StringUtils.isEmpty(str1)) {
return StringUtils.isEmpty(str2);
} else {
return str1.equals(str2);
}
}
/**
* Check if host matches the registered value.
*
* @param registered the registered host. Can be null.
* @param requested the requested host. Can be null.
* @return true if they match
*/
protected boolean hostMatches(String registered, String requested) {
if (matchSubdomains) {
return isEqual(registered, requested) || (requested != null && requested.endsWith("." + registered));
}
return isEqual(registered, requested);
}
/**
* Checks whether the registered redirect uri query params key and values contains match the requested set
*
* The requested redirect uri query params are allowed to contain additional params which will be retained
*
* @param registeredRedirectUriQueryParams
* @param requestedRedirectUriQueryParams
* @return whether the params match
*/
private boolean matchQueryParams(MultiValueMap<String, String> registeredRedirectUriQueryParams,
MultiValueMap<String, String> requestedRedirectUriQueryParams)
{
for (String key : registeredRedirectUriQueryParams.keySet()) {
List<String> registeredRedirectUriQueryParamsValues = registeredRedirectUriQueryParams.get(key);
List<String> requestedRedirectUriQueryParamsValues = requestedRedirectUriQueryParams.get(key);
if (!registeredRedirectUriQueryParamsValues.equals(requestedRedirectUriQueryParamsValues)) {
return false;
}
}
return true;
}
} }

View File

@ -188,7 +188,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
null, null); null, null);
idToken = new SignedJWT(header, idClaims.build()); idToken = new SignedJWT(header, idClaims.build());
JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client); JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidator(client);
// sign it with the client's secret // sign it with the client's secret
signer.signJwt((SignedJWT) idToken); signer.signJwt((SignedJWT) idToken);

View File

@ -142,7 +142,7 @@ public class UserInfoJWTView extends UserInfoView {
|| signingAlg.equals(JWSAlgorithm.HS512)) { || signingAlg.equals(JWSAlgorithm.HS512)) {
// sign it with the client's secret // sign it with the client's secret
JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client); JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidator(client);
signer.signJwt(signed); signer.signJwt(signed);
} else { } else {

View File

@ -20,11 +20,6 @@
*/ */
package org.mitre.openid.connect.web; package org.mitre.openid.connect.web;
import java.lang.reflect.Type;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.mitre.openid.connect.model.OIDCAuthenticationToken; import org.mitre.openid.connect.model.OIDCAuthenticationToken;
import org.mitre.openid.connect.model.UserInfo; import org.mitre.openid.connect.model.UserInfo;
import org.mitre.openid.connect.service.UserInfoService; import org.mitre.openid.connect.service.UserInfoService;
@ -38,11 +33,12 @@ import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.GsonBuilder; import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonPrimitive; import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer; import com.google.gson.JsonSerializer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/** /**
* Injects the UserInfo object for the current user into the current model's context, if both exist. Allows JSPs and the like to call "userInfo.name" and other fields. * Injects the UserInfo object for the current user into the current model's context, if both exist. Allows JSPs and the like to call "userInfo.name" and other fields.
* *

View File

@ -16,6 +16,9 @@
package org.mitre.uma.model; package org.mitre.uma.model;
import com.google.gson.JsonElement;
import org.mitre.oauth2.model.convert.JsonElementStringConverter;
import java.util.Set; import java.util.Set;
import javax.persistence.Basic; import javax.persistence.Basic;
@ -31,10 +34,6 @@ import javax.persistence.Id;
import javax.persistence.JoinColumn; import javax.persistence.JoinColumn;
import javax.persistence.Table; import javax.persistence.Table;
import org.mitre.oauth2.model.convert.JsonElementStringConverter;
import com.google.gson.JsonElement;
/** /**
* @author jricher * @author jricher
*/ */

View File

@ -74,8 +74,6 @@ public class TestBlacklistAwareRedirectResolver {
when(client.getAuthorizedGrantTypes()).thenReturn(ImmutableSet.of("authorization_code")); when(client.getAuthorizedGrantTypes()).thenReturn(ImmutableSet.of("authorization_code"));
when(client.getRegisteredRedirectUri()).thenReturn(ImmutableSet.of(goodUri, blacklistedUri)); when(client.getRegisteredRedirectUri()).thenReturn(ImmutableSet.of(goodUri, blacklistedUri));
when(config.isHeartMode()).thenReturn(false);
} }
@Test @Test
@ -141,8 +139,6 @@ public class TestBlacklistAwareRedirectResolver {
@Test @Test
public void testHeartMode() { public void testHeartMode() {
when(config.isHeartMode()).thenReturn(true);
// this is not an exact match // this is not an exact match
boolean res1 = resolver.redirectMatches(pathUri, goodUri); boolean res1 = resolver.redirectMatches(pathUri, goodUri);

View File

@ -17,15 +17,11 @@
*******************************************************************************/ *******************************************************************************/
package org.mitre.oauth2.service.impl; package org.mitre.oauth2.service.impl;
import java.util.HashSet; import com.google.common.collect.Sets;
import java.util.LinkedHashSet;
import java.util.Set;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod;
import org.mitre.oauth2.model.SystemScope; import org.mitre.oauth2.model.SystemScope;
import org.mitre.oauth2.repository.OAuth2ClientRepository; import org.mitre.oauth2.repository.OAuth2ClientRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository; import org.mitre.oauth2.repository.OAuth2TokenRepository;
@ -40,22 +36,23 @@ import org.mitre.uma.model.ResourceSet;
import org.mitre.uma.service.ResourceSetService; import org.mitre.uma.service.ResourceSetService;
import org.mockito.AdditionalAnswers; import org.mockito.AdditionalAnswers;
import org.mockito.InjectMocks; import org.mockito.InjectMocks;
import org.mockito.Matchers; import org.mockito.ArgumentMatchers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException; import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import com.google.common.collect.Sets; import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.core.Is.is;
import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.core.IsEqual.equalTo;
import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.core.IsNull.notNullValue;
import static org.hamcrest.core.IsNull.nullValue;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
/** /**
@ -99,7 +96,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
public void prepare() { public void prepare() {
Mockito.reset(clientRepository, tokenRepository, approvedSiteService, whitelistedSiteService, blacklistedSiteService, scopeService, statsService); Mockito.reset(clientRepository, tokenRepository, approvedSiteService, whitelistedSiteService, blacklistedSiteService, scopeService, statsService);
Mockito.when(clientRepository.saveClient(Matchers.any(ClientDetailsEntity.class))).thenAnswer(new Answer<ClientDetailsEntity>() { Mockito.when(clientRepository.saveClient(ArgumentMatchers.any(ClientDetailsEntity.class))).thenAnswer(new Answer<ClientDetailsEntity>() {
@Override @Override
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
@ -107,15 +104,10 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
} }
}); });
Mockito.when(clientRepository.updateClient(Matchers.anyLong(), Matchers.any(ClientDetailsEntity.class))).thenAnswer(new Answer<ClientDetailsEntity>() { Mockito.when(clientRepository.updateClient(ArgumentMatchers.nullable(Long.class), ArgumentMatchers.any(ClientDetailsEntity.class)))
@Override .then(a -> a.getArgument(1));
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
return (ClientDetailsEntity) args[1];
}
});
Mockito.when(scopeService.fromStrings(Matchers.anySet())).thenAnswer(new Answer<Set<SystemScope>>() { Mockito.when(scopeService.fromStrings(ArgumentMatchers.anySet())).thenAnswer(new Answer<Set<SystemScope>>() {
@Override @Override
public Set<SystemScope> answer(InvocationOnMock invocation) throws Throwable { public Set<SystemScope> answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
@ -128,7 +120,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
} }
}); });
Mockito.when(scopeService.toStrings(Matchers.anySet())).thenAnswer(new Answer<Set<String>>() { Mockito.when(scopeService.toStrings(ArgumentMatchers.anySet())).thenAnswer(new Answer<Set<String>>() {
@Override @Override
public Set<String> answer(InvocationOnMock invocation) throws Throwable { public Set<String> answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
@ -142,7 +134,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
}); });
// we're not testing reserved scopes here, just pass through when it's called // we're not testing reserved scopes here, just pass through when it's called
Mockito.when(scopeService.removeReservedScopes(Matchers.anySet())).then(AdditionalAnswers.returnsFirstArg()); Mockito.when(scopeService.removeReservedScopes(ArgumentMatchers.anySet())).then(AdditionalAnswers.returnsFirstArg());
Mockito.when(config.isHeartMode()).thenReturn(false); Mockito.when(config.isHeartMode()).thenReturn(false);
@ -187,7 +179,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
service.saveNewClient(client); service.saveNewClient(client);
Mockito.verify(client).setClientId(Matchers.anyString()); Mockito.verify(client).setClientId(ArgumentMatchers.anyString());
} }
/** /**
@ -217,7 +209,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
client = service.saveNewClient(client); client = service.saveNewClient(client);
Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(Matchers.anySet()); Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(ArgumentMatchers.anySet());
assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(false))); assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(false)));
} }
@ -343,7 +335,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
client = service.updateClient(oldClient, client); client = service.updateClient(oldClient, client);
Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(Matchers.anySet()); Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(ArgumentMatchers.anySet());
assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(true))); assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(true)));
} }
@ -359,7 +351,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
client = service.updateClient(oldClient, client); client = service.updateClient(oldClient, client);
Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(Matchers.anySet()); Mockito.verify(scopeService, Mockito.atLeastOnce()).removeReservedScopes(ArgumentMatchers.anySet());
assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(false))); assertThat(client.getScope().contains(SystemScopeService.OFFLINE_ACCESS), is(equalTo(false)));
} }
@ -375,7 +367,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("client_credentials"); grantTypes.add("client_credentials");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -396,7 +388,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("client_credentials"); grantTypes.add("client_credentials");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.NONE); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.NONE);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -417,7 +409,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("implicit"); grantTypes.add("implicit");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setJwksUri("https://foo.bar/jwks"); client.setJwksUri("https://foo.bar/jwks");
@ -434,7 +426,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("authorization_code"); grantTypes.add("authorization_code");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.SECRET_POST); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.SECRET_POST);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -453,7 +445,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("implicit"); grantTypes.add("implicit");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -472,7 +464,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("client_credentials"); grantTypes.add("client_credentials");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.SECRET_BASIC); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.SECRET_BASIC);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -491,7 +483,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("authorization_code"); grantTypes.add("authorization_code");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
service.saveNewClient(client); service.saveNewClient(client);
@ -506,7 +498,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("implicit"); grantTypes.add("implicit");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.NONE); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.NONE);
service.saveNewClient(client); service.saveNewClient(client);
@ -521,7 +513,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("client_credentials"); grantTypes.add("client_credentials");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("http://foo.bar/")); client.setRedirectUris(Sets.newHashSet("http://foo.bar/"));
@ -538,7 +530,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("authorization_code"); grantTypes.add("authorization_code");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("http://foo.bar/")); client.setRedirectUris(Sets.newHashSet("http://foo.bar/"));
@ -557,7 +549,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("authorization_code"); grantTypes.add("authorization_code");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -578,7 +570,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("refresh_token"); grantTypes.add("refresh_token");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("https://foo.bar/")); client.setRedirectUris(Sets.newHashSet("https://foo.bar/"));
@ -600,7 +592,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("refresh_token"); grantTypes.add("refresh_token");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("http://foo.bar/")); client.setRedirectUris(Sets.newHashSet("http://foo.bar/"));
@ -620,7 +612,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
grantTypes.add("refresh_token"); grantTypes.add("refresh_token");
client.setGrantTypes(grantTypes); client.setGrantTypes(grantTypes);
client.setTokenEndpointAuthMethod(AuthMethod.PRIVATE_KEY); client.setTokenEndpointAuthMethod(ClientDetailsEntity.AuthMethod.PRIVATE_KEY);
client.setRedirectUris(Sets.newHashSet("http://localhost/", "https://foo.bar", "foo://bar")); client.setRedirectUris(Sets.newHashSet("http://localhost/", "https://foo.bar", "foo://bar"));

View File

@ -191,8 +191,6 @@ public class TestDefaultOAuth2ProviderTokenService {
// we're not testing restricted or reserved scopes here, just pass through // we're not testing restricted or reserved scopes here, just pass through
when(scopeService.removeReservedScopes(anySet())).then(returnsFirstArg()); when(scopeService.removeReservedScopes(anySet())).then(returnsFirstArg());
when(scopeService.removeRestrictedAndReservedScopes(anySet())).then(returnsFirstArg());
when(tokenEnhancer.enhance(any(OAuth2AccessTokenEntity.class), any(OAuth2Authentication.class))) when(tokenEnhancer.enhance(any(OAuth2AccessTokenEntity.class), any(OAuth2Authentication.class)))
.thenAnswer(new Answer<OAuth2AccessTokenEntity>(){ .thenAnswer(new Answer<OAuth2AccessTokenEntity>(){
@Override @Override

View File

@ -123,7 +123,6 @@ public class TestDefaultApprovedSiteService {
String otherId = "a different id"; String otherId = "a different id";
client.setClientId(otherId); client.setClientId(otherId);
service.clearApprovedSitesForClient(client); service.clearApprovedSitesForClient(client);
Mockito.when(repository.getByClientId(otherId)).thenReturn(new HashSet<ApprovedSite>());
Mockito.verify(repository, never()).remove(any(ApprovedSite.class)); Mockito.verify(repository, never()).remove(any(ApprovedSite.class));
} }

View File

@ -63,11 +63,6 @@ public class TestDefaultStatsService {
private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class);
private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class);
private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class);
private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class);
private ClientDetailsEntity client3 = Mockito.mock(ClientDetailsEntity.class);
private ClientDetailsEntity client4 = Mockito.mock(ClientDetailsEntity.class);
@Mock @Mock
private ApprovedSiteService approvedSiteService; private ApprovedSiteService approvedSiteService;
@ -102,12 +97,6 @@ public class TestDefaultStatsService {
Mockito.when(ap6.getClientId()).thenReturn(clientId4); Mockito.when(ap6.getClientId()).thenReturn(clientId4);
Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4)); Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4));
Mockito.when(client1.getId()).thenReturn(1L);
Mockito.when(client2.getId()).thenReturn(2L);
Mockito.when(client3.getId()).thenReturn(3L);
Mockito.when(client4.getId()).thenReturn(4L);
Mockito.when(approvedSiteService.getByClientId(clientId1)).thenReturn(Sets.newHashSet(ap1, ap2)); Mockito.when(approvedSiteService.getByClientId(clientId1)).thenReturn(Sets.newHashSet(ap1, ap2));
Mockito.when(approvedSiteService.getByClientId(clientId2)).thenReturn(Sets.newHashSet(ap3)); Mockito.when(approvedSiteService.getByClientId(clientId2)).thenReturn(Sets.newHashSet(ap3));
Mockito.when(approvedSiteService.getByClientId(clientId3)).thenReturn(Sets.newHashSet(ap4)); Mockito.when(approvedSiteService.getByClientId(clientId3)).thenReturn(Sets.newHashSet(ap4));

View File

@ -61,7 +61,7 @@ import org.mockito.InjectMocks;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.springframework.format.annotation.DateTimeFormat.ISO; import org.springframework.format.annotation.DateTimeFormat.ISO;
import org.springframework.format.datetime.DateFormatter; import org.springframework.format.datetime.DateFormatter;
@ -150,7 +150,6 @@ public class TestMITREidDataService_1_0 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -165,7 +164,6 @@ public class TestMITREidDataService_1_0 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -229,7 +227,6 @@ public class TestMITREidDataService_1_0 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -267,7 +264,6 @@ public class TestMITREidDataService_1_0 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -285,10 +281,8 @@ public class TestMITREidDataService_1_0 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class);
when(mockRefreshToken2.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -359,7 +353,6 @@ public class TestMITREidDataService_1_0 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -554,13 +547,6 @@ public class TestMITREidDataService_1_0 {
return _site; return _site;
} }
}); });
when(wlSiteRepository.getById(anyLong())).thenAnswer(new Answer<WhitelistedSite>() {
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
Long _id = (Long) invocation.getArguments()[0];
return fakeDb.get(_id);
}
});
dataService.importData(reader); dataService.importData(reader);
verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture()); verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture());
@ -580,8 +566,6 @@ public class TestMITREidDataService_1_0 {
Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH);
OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class);
when(mockToken1.getId()).thenReturn(1L);
ApprovedSite site1 = new ApprovedSite(); ApprovedSite site1 = new ApprovedSite();
site1.setId(1L); site1.setId(1L);
site1.setClientId("foo"); site1.setClientId("foo");
@ -589,7 +573,6 @@ public class TestMITREidDataService_1_0 {
site1.setAccessDate(accessDate1); site1.setAccessDate(accessDate1);
site1.setUserId("user1"); site1.setUserId("user1");
site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); site1.setAllowedScopes(ImmutableSet.of("openid", "phone"));
when(mockToken1.getApprovedSite()).thenReturn(site1);
Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH);
Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH);
@ -648,25 +631,13 @@ public class TestMITREidDataService_1_0 {
return fakeDb.get(_id); return fakeDb.get(_id);
} }
}); });
when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer<WhitelistedSite>() {
Long id = 244L;
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
WhitelistedSite _site = mock(WhitelistedSite.class);
when(_site.getId()).thenReturn(id++);
return _site;
}
});
when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() { when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
Long id = 221L; Long id = 221L;
@Override @Override
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); return mock(OAuth2AccessTokenEntity.class);
when(_token.getId()).thenReturn(id++);
return _token;
} }
}); });
when(tokenRepository.getAccessTokensForApprovedSite(site1)).thenReturn(Lists.newArrayList(mockToken1));
dataService.importData(reader); dataService.importData(reader);
//2 for sites, 1 for updating access token ref on #1 //2 for sites, 1 for updating access token ref on #1
@ -835,7 +806,6 @@ public class TestMITREidDataService_1_0 {
Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH);
ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class);
when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(), OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://foo.com", true, new HashSet<String>(), new HashSet<String>(), "http://foo.com",
@ -858,7 +828,6 @@ public class TestMITREidDataService_1_0 {
Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH);
ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class);
when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(), OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://bar.com", true, new HashSet<String>(), new HashSet<String>(), "http://bar.com",
@ -929,7 +898,6 @@ public class TestMITREidDataService_1_0 {
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
String _clientId = (String) invocation.getArguments()[0]; String _clientId = (String) invocation.getArguments()[0];
ClientDetailsEntity _client = mock(ClientDetailsEntity.class); ClientDetailsEntity _client = mock(ClientDetailsEntity.class);
when(_client.getClientId()).thenReturn(_clientId);
return _client; return _client;
} }
}); });
@ -967,4 +935,4 @@ public class TestMITREidDataService_1_0 {
dataService.exportData(writer); dataService.exportData(writer);
} }
} }

View File

@ -152,7 +152,6 @@ public class TestMITREidDataService_1_1 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -168,7 +167,6 @@ public class TestMITREidDataService_1_1 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -232,7 +230,6 @@ public class TestMITREidDataService_1_1 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -271,7 +268,6 @@ public class TestMITREidDataService_1_1 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -289,10 +285,8 @@ public class TestMITREidDataService_1_1 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class);
when(mockRefreshToken2.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -363,7 +357,6 @@ public class TestMITREidDataService_1_1 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -557,13 +550,6 @@ public class TestMITREidDataService_1_1 {
return _site; return _site;
} }
}); });
when(wlSiteRepository.getById(anyLong())).thenAnswer(new Answer<WhitelistedSite>() {
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
Long _id = (Long) invocation.getArguments()[0];
return fakeDb.get(_id);
}
});
dataService.importData(reader); dataService.importData(reader);
verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture()); verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture());
@ -583,7 +569,6 @@ public class TestMITREidDataService_1_1 {
Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH);
OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class);
when(mockToken1.getId()).thenReturn(1L);
ApprovedSite site1 = new ApprovedSite(); ApprovedSite site1 = new ApprovedSite();
site1.setId(1L); site1.setId(1L);
@ -592,7 +577,6 @@ public class TestMITREidDataService_1_1 {
site1.setAccessDate(accessDate1); site1.setAccessDate(accessDate1);
site1.setUserId("user1"); site1.setUserId("user1");
site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); site1.setAllowedScopes(ImmutableSet.of("openid", "phone"));
when(mockToken1.getApprovedSite()).thenReturn(site1);
Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH);
Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH);
@ -651,21 +635,11 @@ public class TestMITREidDataService_1_1 {
return fakeDb.get(_id); return fakeDb.get(_id);
} }
}); });
when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer<WhitelistedSite>() {
Long id = 432L;
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
WhitelistedSite _site = mock(WhitelistedSite.class);
when(_site.getId()).thenReturn(id++);
return _site;
}
});
when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() { when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
Long id = 245L; Long id = 245L;
@Override @Override
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class);
when(_token.getId()).thenReturn(id++);
return _token; return _token;
} }
}); });
@ -837,7 +811,6 @@ public class TestMITREidDataService_1_1 {
Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH);
ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class);
when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(), OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://foo.com", true, new HashSet<String>(), new HashSet<String>(), "http://foo.com",
@ -860,7 +833,6 @@ public class TestMITREidDataService_1_1 {
Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH);
ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class);
when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(), OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://bar.com", true, new HashSet<String>(), new HashSet<String>(), "http://bar.com",
@ -931,7 +903,6 @@ public class TestMITREidDataService_1_1 {
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
String _clientId = (String) invocation.getArguments()[0]; String _clientId = (String) invocation.getArguments()[0];
ClientDetailsEntity _client = mock(ClientDetailsEntity.class); ClientDetailsEntity _client = mock(ClientDetailsEntity.class);
when(_client.getClientId()).thenReturn(_clientId);
return _client; return _client;
} }
}); });

View File

@ -154,7 +154,6 @@ public class TestMITREidDataService_1_2 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -170,7 +169,6 @@ public class TestMITREidDataService_1_2 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -234,7 +232,6 @@ public class TestMITREidDataService_1_2 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -273,7 +270,6 @@ public class TestMITREidDataService_1_2 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity();
token1.setId(1L); token1.setId(1L);
@ -291,10 +287,8 @@ public class TestMITREidDataService_1_2 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class);
when(mockRefreshToken2.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity();
token2.setId(2L); token2.setId(2L);
@ -365,7 +359,6 @@ public class TestMITREidDataService_1_2 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -559,13 +552,6 @@ public class TestMITREidDataService_1_2 {
return _site; return _site;
} }
}); });
when(wlSiteRepository.getById(anyLong())).thenAnswer(new Answer<WhitelistedSite>() {
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
Long _id = (Long) invocation.getArguments()[0];
return fakeDb.get(_id);
}
});
dataService.importData(reader); dataService.importData(reader);
verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture()); verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture());
@ -585,7 +571,6 @@ public class TestMITREidDataService_1_2 {
Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH);
OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class);
when(mockToken1.getId()).thenReturn(1L);
ApprovedSite site1 = new ApprovedSite(); ApprovedSite site1 = new ApprovedSite();
site1.setId(1L); site1.setId(1L);
@ -594,7 +579,6 @@ public class TestMITREidDataService_1_2 {
site1.setAccessDate(accessDate1); site1.setAccessDate(accessDate1);
site1.setUserId("user1"); site1.setUserId("user1");
site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); site1.setAllowedScopes(ImmutableSet.of("openid", "phone"));
when(mockToken1.getApprovedSite()).thenReturn(site1);
Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH);
Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH);
@ -653,21 +637,11 @@ public class TestMITREidDataService_1_2 {
return fakeDb.get(_id); return fakeDb.get(_id);
} }
}); });
when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer<WhitelistedSite>() {
Long id = 432L;
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
WhitelistedSite _site = mock(WhitelistedSite.class);
when(_site.getId()).thenReturn(id++);
return _site;
}
});
when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() { when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
Long id = 245L; Long id = 245L;
@Override @Override
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class);
when(_token.getId()).thenReturn(id++);
return _token; return _token;
} }
}); });
@ -839,7 +813,6 @@ public class TestMITREidDataService_1_2 {
Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH);
ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class);
when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(), OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://foo.com", true, new HashSet<String>(), new HashSet<String>(), "http://foo.com",
@ -862,7 +835,6 @@ public class TestMITREidDataService_1_2 {
Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH);
ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class);
when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(), OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://bar.com", true, new HashSet<String>(), new HashSet<String>(), "http://bar.com",
@ -933,7 +905,6 @@ public class TestMITREidDataService_1_2 {
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
String _clientId = (String) invocation.getArguments()[0]; String _clientId = (String) invocation.getArguments()[0];
ClientDetailsEntity _client = mock(ClientDetailsEntity.class); ClientDetailsEntity _client = mock(ClientDetailsEntity.class);
when(_client.getClientId()).thenReturn(_clientId);
return _client; return _client;
} }
}); });

View File

@ -280,8 +280,6 @@ public class TestMITREidDataService_1_3 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity();
token1.setId(1L); token1.setId(1L);
token1.setClient(mockedClient1); token1.setClient(mockedClient1);
@ -296,8 +294,6 @@ public class TestMITREidDataService_1_3 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity();
token2.setId(2L); token2.setId(2L);
token2.setClient(mockedClient2); token2.setClient(mockedClient2);
@ -360,7 +356,6 @@ public class TestMITREidDataService_1_3 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -530,8 +525,6 @@ public class TestMITREidDataService_1_3 {
when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder1.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity();
token1.setId(1L); token1.setId(1L);
token1.setClient(mockedClient1); token1.setClient(mockedClient1);
@ -548,11 +541,7 @@ public class TestMITREidDataService_1_3 {
when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class);
when(mockedAuthHolder2.getId()).thenReturn(2L);
OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class);
when(mockRefreshToken2.getId()).thenReturn(1L);
OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity();
token2.setId(2L); token2.setId(2L);
token2.setClient(mockedClient2); token2.setClient(mockedClient2);
@ -622,7 +611,6 @@ public class TestMITREidDataService_1_3 {
@Override @Override
public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable {
AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class);
when(_auth.getId()).thenReturn(id);
id++; id++;
return _auth; return _auth;
} }
@ -1109,13 +1097,6 @@ public class TestMITREidDataService_1_3 {
return _site; return _site;
} }
}); });
when(wlSiteRepository.getById(anyLong())).thenAnswer(new Answer<WhitelistedSite>() {
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
Long _id = (Long) invocation.getArguments()[0];
return fakeDb.get(_id);
}
});
dataService.importData(reader); dataService.importData(reader);
verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture()); verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture());
@ -1135,7 +1116,6 @@ public class TestMITREidDataService_1_3 {
Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH);
OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class);
when(mockToken1.getId()).thenReturn(1L);
ApprovedSite site1 = new ApprovedSite(); ApprovedSite site1 = new ApprovedSite();
site1.setId(1L); site1.setId(1L);
@ -1144,7 +1124,6 @@ public class TestMITREidDataService_1_3 {
site1.setAccessDate(accessDate1); site1.setAccessDate(accessDate1);
site1.setUserId("user1"); site1.setUserId("user1");
site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); site1.setAllowedScopes(ImmutableSet.of("openid", "phone"));
when(mockToken1.getApprovedSite()).thenReturn(site1);
Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH);
Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH);
@ -1250,7 +1229,6 @@ public class TestMITREidDataService_1_3 {
Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH);
OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class);
when(mockToken1.getId()).thenReturn(1L);
ApprovedSite site1 = new ApprovedSite(); ApprovedSite site1 = new ApprovedSite();
site1.setId(1L); site1.setId(1L);
@ -1259,7 +1237,6 @@ public class TestMITREidDataService_1_3 {
site1.setAccessDate(accessDate1); site1.setAccessDate(accessDate1);
site1.setUserId("user1"); site1.setUserId("user1");
site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); site1.setAllowedScopes(ImmutableSet.of("openid", "phone"));
when(mockToken1.getApprovedSite()).thenReturn(site1);
Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH);
Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH);
@ -1318,21 +1295,11 @@ public class TestMITREidDataService_1_3 {
return fakeDb.get(_id); return fakeDb.get(_id);
} }
}); });
when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer<WhitelistedSite>() {
Long id = 432L;
@Override
public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable {
WhitelistedSite _site = mock(WhitelistedSite.class);
when(_site.getId()).thenReturn(id++);
return _site;
}
});
when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() { when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
Long id = 245L; Long id = 245L;
@Override @Override
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class);
when(_token.getId()).thenReturn(id++);
return _token; return _token;
} }
}); });
@ -1721,7 +1688,6 @@ public class TestMITREidDataService_1_3 {
Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH);
ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class);
when(mockedClient1.getClientId()).thenReturn("mocked_client_1");
OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(), OAuth2Request req1 = new OAuth2Request(new HashMap<String, String>(), "client1", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://foo.com", true, new HashSet<String>(), new HashSet<String>(), "http://foo.com",
@ -1744,7 +1710,6 @@ public class TestMITREidDataService_1_3 {
Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH);
ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class);
when(mockedClient2.getClientId()).thenReturn("mocked_client_2");
OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(), OAuth2Request req2 = new OAuth2Request(new HashMap<String, String>(), "client2", new ArrayList<GrantedAuthority>(),
true, new HashSet<String>(), new HashSet<String>(), "http://bar.com", true, new HashSet<String>(), new HashSet<String>(), "http://bar.com",
@ -1815,7 +1780,6 @@ public class TestMITREidDataService_1_3 {
public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable {
String _clientId = (String) invocation.getArguments()[0]; String _clientId = (String) invocation.getArguments()[0];
ClientDetailsEntity _client = mock(ClientDetailsEntity.class); ClientDetailsEntity _client = mock(ClientDetailsEntity.class);
when(_client.getClientId()).thenReturn(_clientId);
return _client; return _client;
} }
}); });

154
pom.xml
View File

@ -358,23 +358,24 @@
<spring-security-oauth2.version>2.4.1.RELEASE</spring-security-oauth2.version> <spring-security-oauth2.version>2.4.1.RELEASE</spring-security-oauth2.version>
<jackson.version>2.11.0</jackson.version> <jackson.version>2.11.0</jackson.version>
<jstl.version>1.2</jstl.version> <jstl.version>1.2</jstl.version>
<servlet-api.version>4.0.1</servlet-api.version> <servlet-api.version>2.5</servlet-api.version>
<jsp-api.version>2.2</jsp-api.version> <jsp-api.version>2.2</jsp-api.version>
<mysql.version>8.0.20</mysql.version> <mysql.version>8.0.20</mysql.version>
<hsqldb.version>2.4.0</hsqldb.version> <hsqldb.version>2.4.0</hsqldb.version>
<eclipse-persistence.version>2.7.7</eclipse-persistence.version> <eclipse-persistence.version>2.7.7</eclipse-persistence.version>
<javax-persistence.version>2.2.1</javax-persistence.version> <javax-persistence.version>2.2.1</javax-persistence.version>
<hikari.version>2.12.0</hikari.version> <hikari.version>3.4.5</hikari.version>
<logback.version>1.2.3</logback.version>
<org.slf4j-version>1.7.30</org.slf4j-version> <org.slf4j-version>1.7.30</org.slf4j-version>
<log4j-core.version>2.13.3</log4j-core.version> <log4j-core.version>2.13.3</log4j-core.version>
<junit-jupiter-api.version>5.6.2</junit-jupiter-api.version> <junit.version>4.13</junit.version>
<easymock.version>4.2</easymock.version> <easymock.version>4.2</easymock.version>
<mockito-all.version>1.10.19</mockito-all.version> <mockito.version>3.2.4</mockito.version>
<guava.version>29.0-jre</guava.version> <guava.version>29.0-jre</guava.version>
<gson.version>2.8.6</gson.version> <gson.version>2.8.6</gson.version>
<httpclient.version>4.5.12</httpclient.version> <httpclient.version>4.5.12</httpclient.version>
<nimbus-jose-jwt.version>8.17.1</nimbus-jose-jwt.version> <nimbus-jose-jwt.version>8.17.1</nimbus-jose-jwt.version>
<bcprov-jdk15on.version>1.65.01</bcprov-jdk15on.version> <bcprov-jdk15on.version>1.65</bcprov-jdk15on.version>
<commons-io.version>2.7</commons-io.version> <commons-io.version>2.7</commons-io.version>
<wro4j-extensions.version>1.9.0</wro4j-extensions.version> <wro4j-extensions.version>1.9.0</wro4j-extensions.version>
</properties> </properties>
@ -402,6 +403,12 @@
<artifactId>spring-security-oauth2</artifactId> <artifactId>spring-security-oauth2</artifactId>
<version>${spring-security-oauth2.version}</version> <version>${spring-security-oauth2.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
<version>${spring.version}</version>
<scope>test</scope>
</dependency>
<!-- Jackson --> <!-- Jackson -->
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
@ -416,14 +423,14 @@
<!-- Servlet --> <!-- Servlet -->
<dependency> <dependency>
<groupId>javax.servlet.jsp.jstl</groupId> <groupId>javax.servlet.jsp.jstl</groupId>
<artifactId>jstl</artifactId> <artifactId>jstl-api</artifactId>
<version>${jstl.version}</version> <version>${jstl.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>javax.servlet</groupId> <groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId> <artifactId>servlet-api</artifactId>
<version>${servlet-api.version}</version> <version>${servlet-api.version}</version>
<scope>provided</scope> <scope>compile</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>javax.servlet.jsp</groupId> <groupId>javax.servlet.jsp</groupId>
@ -448,11 +455,6 @@
<artifactId>org.eclipse.persistence.jpa</artifactId> <artifactId>org.eclipse.persistence.jpa</artifactId>
<version>${eclipse-persistence.version}</version> <version>${eclipse-persistence.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.eclipse.persistence</groupId>
<artifactId>javax.persistence</artifactId>
<version>${javax-persistence.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.zaxxer</groupId> <groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId> <artifactId>HikariCP</artifactId>
@ -460,68 +462,21 @@
</dependency> </dependency>
<!-- Logging --> <!-- Logging -->
<dependency> <dependency>
<groupId>org.slf4j</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>slf4j-api</artifactId> <artifactId>logback-classic</artifactId>
<version>${org.slf4j-version}</version> <version>${logback.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<scope>test</scope>
<version>${org.slf4j-version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>jcl-over-slf4j</artifactId>
<version>${org.slf4j-version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>${org.slf4j-version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-core</artifactId>
<version>${log4j-core.version}</version>
<exclusions>
<exclusion>
<groupId>javax.mail</groupId>
<artifactId>mail</artifactId>
</exclusion>
<exclusion>
<groupId>javax.jms</groupId>
<artifactId>jms</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.jdmk</groupId>
<artifactId>jmxtools</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.jmx</groupId>
<artifactId>jmxri</artifactId>
</exclusion>
</exclusions>
<scope>runtime</scope>
</dependency> </dependency>
<!-- Test --> <!-- Test -->
<dependency> <dependency>
<groupId>org.junit.jupiter</groupId> <groupId>junit</groupId>
<artifactId>junit-jupiter-api</artifactId> <artifactId>junit</artifactId>
<version>${junit-jupiter-api.version}</version> <version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<version>${easymock.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.mockito</groupId> <groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId> <artifactId>mockito-core</artifactId>
<version>${mockito-all.version}</version> <version>${mockito.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<!-- MITREid Connect components --> <!-- MITREid Connect components -->
@ -530,18 +485,6 @@
<artifactId>openid-connect-server</artifactId> <artifactId>openid-connect-server</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.mitre</groupId>
<artifactId>openid-connect-server-webapp</artifactId>
<version>${project.version}</version>
<type>war</type>
</dependency>
<dependency>
<groupId>org.mitre</groupId>
<artifactId>openid-connect-server-webapp</artifactId>
<version>${project.version}</version>
<type>warpath</type>
</dependency>
<!-- Other libraries --> <!-- Other libraries -->
<dependency> <dependency>
<groupId>com.google.guava</groupId> <groupId>com.google.guava</groupId>
@ -557,12 +500,6 @@
<groupId>org.apache.httpcomponents</groupId> <groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId> <artifactId>httpclient</artifactId>
<version>${httpclient.version}</version> <version>${httpclient.version}</version>
<exclusions>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.nimbusds</groupId> <groupId>com.nimbusds</groupId>
@ -580,54 +517,11 @@
<version>${eclipse-persistence.version}</version> <version>${eclipse-persistence.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>commons-io</groupId>
<artifactId>commons-io</artifactId> <artifactId>commons-io</artifactId>
<version>${commons-io.version}</version> <version>${commons-io.version}</version>
</dependency> </dependency>
<dependency>
<groupId>ro.isdc.wro4j</groupId>
<artifactId>wro4j-extensions</artifactId>
<version>${wro4j-extensions.version}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
<exclusions>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>servlet-api</artifactId>
</dependency>
<dependency>
<groupId>javax.servlet.jsp</groupId>
<artifactId>jsp-api</artifactId>
</dependency>
</dependencies>
</project> </project>