process transactions and redirect based interactions

pull/1536/head
Justin Richer 2020-05-11 14:08:48 -04:00
parent 6c4f25b18b
commit 80eb57402c
3 changed files with 702 additions and 0 deletions

View File

@ -0,0 +1,112 @@
package org.mitre.xyz;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.function.Function;
import org.bouncycastle.jcajce.provider.digest.SHA1;
import org.bouncycastle.jcajce.provider.digest.SHA256;
import org.bouncycastle.jcajce.provider.digest.SHA3;
import org.bouncycastle.jcajce.provider.digest.SHA512;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Joiner;
/**
* @author jricher
*
*/
public abstract class Hash {
private static final Logger log = LoggerFactory.getLogger(Hash.class);
public enum Method {
SHA3("sha3", Hash::SHA3_512_encode),
SHA2("sha2", Hash::SHA2_512_encode)
;
private String name;
private Function<String, String> function;
/**
* @param string
* @param object
*/
Method(String name, Function<String, String> function) {
this.name = name;
this.function = function;
}
public String getName() {
return name;
}
public Function<String, String> getFunction() {
return function;
}
public static Method fromJson(String key) {
return key == null ? null :
valueOf(key.toUpperCase());
}
public String toJson() {
return name().toLowerCase();
}
}
public static String SHA3_512_encode(String input) {
MessageDigest digest = new SHA3.Digest512();
byte[] output = digest.digest(input.getBytes());
byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output);
return new String(encoded);
}
public static String SHA2_512_encode(String input) {
MessageDigest digest = new SHA512.Digest();
byte[] output = digest.digest(input.getBytes());
byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output);
return new String(encoded);
}
public static String CalculateInteractHash(String clientNonce, String serverNonce, String interact, Method method) {
return method.getFunction().apply(
Joiner.on('\n')
.join(clientNonce,
serverNonce,
interact));
}
public static String SHA256_encode(String input) {
if (input == null || input.isEmpty()) {
return null;
}
MessageDigest digest = new SHA256.Digest();
byte[] output = digest.digest(input.getBytes());
byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output);
return new String(encoded);
}
public static String SHA1_digest(byte[] input) {
if (input == null || input.length == 0) {
return null;
}
MessageDigest digest = new SHA1.Digest();
byte[] output = digest.digest(input);
byte[] encoded = Base64.getEncoder().encode(output);
return new String(encoded);
}
}

View File

@ -0,0 +1,247 @@
package org.mitre.xyz;
import java.net.URISyntaxException;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.apache.http.client.utils.URIBuilder;
import org.mitre.oauth2.model.AuthenticationHolderEntity;
import org.mitre.oauth2.model.SystemScope;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.SystemScopeService;
import org.mitre.openid.connect.model.UserInfo;
import org.mitre.openid.connect.service.ScopeClaimTranslationService;
import org.mitre.openid.connect.service.StatsService;
import org.mitre.openid.connect.service.UserInfoService;
import org.mitre.openid.connect.token.TofuUserApprovalHandler;
import org.mitre.openid.connect.view.HttpCodeView;
import org.mitre.xyz.TxEndpoint.Status;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2RequestFactory;
import org.springframework.security.oauth2.provider.endpoint.RedirectResolver;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.view.RedirectView;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import com.google.gson.JsonObject;
/**
* @author jricher
*
*/
@Controller
@RequestMapping("/interact")
public class IxEndpoint {
@Autowired
private ClientDetailsEntityService clientService;
@Autowired
private SystemScopeService scopeService;
@Autowired
private ScopeClaimTranslationService scopeClaimTranslationService;
@Autowired
private UserInfoService userInfoService;
@Autowired
private StatsService statsService;
@Autowired
private RedirectResolver redirectResolver;
@Autowired
private TxService txService;
@Autowired
private OAuth2RequestFactory oAuth2RequestFactory;
@Autowired
private TofuUserApprovalHandler userApprovalHandler;
@PreAuthorize("hasRole('ROLE_USER')")
@RequestMapping(method = RequestMethod.GET, path = "/{id}")
public String interact(@PathVariable("id") String id, Model m, Authentication auth) {
TxEntity tx = txService.loadByInteractUrl(id);
if (tx == null) {
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return HttpCodeView.VIEWNAME;
}
m.addAttribute("client", tx.getClient());
m.addAttribute("redirect_uri", tx.getCallbackUri());
Set<SystemScope> scopes = scopeService.fromStrings(tx.getScope());
Set<SystemScope> sortedScopes = new LinkedHashSet<>(scopes.size());
Set<SystemScope> systemScopes = scopeService.getAll();
// sort scopes for display based on the inherent order of system scopes
for (SystemScope s : systemScopes) {
if (scopes.contains(s)) {
sortedScopes.add(s);
}
}
// add in any scopes that aren't system scopes to the end of the list
sortedScopes.addAll(Sets.difference(scopes, systemScopes));
m.addAttribute("scopes", sortedScopes);
// get the userinfo claims for each scope
UserInfo user = userInfoService.getByUsername(auth.getName());
Map<String, Map<String, String>> claimsForScopes = new HashMap<>();
if (user != null) {
JsonObject userJson = user.toJson();
for (SystemScope systemScope : sortedScopes) {
Map<String, String> claimValues = new HashMap<>();
Set<String> claims = scopeClaimTranslationService.getClaimsForScope(systemScope.getValue());
for (String claim : claims) {
if (userJson.has(claim) && userJson.get(claim).isJsonPrimitive()) {
// TODO: this skips the address claim
claimValues.put(claim, userJson.get(claim).getAsString());
}
}
claimsForScopes.put(systemScope.getValue(), claimValues);
}
}
m.addAttribute("claims", claimsForScopes);
// client stats
Integer count = statsService.getCountForClientId(tx.getClient().getClientId()).getApprovedSiteCount();
m.addAttribute("count", count);
// contacts
if (tx.getClient().getContacts() != null) {
String contacts = Joiner.on(", ").join(tx.getClient().getContacts());
m.addAttribute("contacts", contacts);
}
// if the client is over a week old and has more than one registration, don't give such a big warning
// instead, tag as "Generally Recognized As Safe" (gras)
Date lastWeek = new Date(System.currentTimeMillis() - (60 * 60 * 24 * 7 * 1000));
if (count > 1 && tx.getClient().getCreatedAt() != null && tx.getClient().getCreatedAt().before(lastWeek)) {
m.addAttribute("gras", true);
} else {
m.addAttribute("gras", false);
}
m.addAttribute("form_target", "interact/" + id);
return "approve";
}
@PreAuthorize("hasRole('ROLE_USER')")
@RequestMapping(method = RequestMethod.POST, path = "/{id}", params = OAuth2Utils.USER_OAUTH_APPROVAL)
public ModelAndView approveOrDeny(@PathVariable("id") String id,
@RequestParam Map<String, String> approvalParameters,
Model m, Authentication auth) {
TxEntity tx = txService.loadByInteractUrl(id);
if (tx == null) {
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
new ModelAndView(HttpCodeView.VIEWNAME);
}
// FIXME: this is using a simplified constructor
AuthorizationRequest ar = new AuthorizationRequest(tx.getClient().getClientId(), tx.getScope());
ar.setRedirectUri(tx.getCallbackUri());
ar.setApprovalParameters(approvalParameters);
userApprovalHandler.updateAfterApproval(ar, auth);
boolean approved = userApprovalHandler.isApproved(ar, auth);
ar.setApproved(approved);
AuthenticationHolderEntity ah = tx.getAuthenticationHolder();
OAuth2Authentication o2a = new OAuth2Authentication(ah.getAuthentication().getOAuth2Request(), auth);
ah.setAuthentication(o2a);
ah.setApproved(approved);
tx.setAuthenticationHolder(ah);
tx.setStatus(approved ? Status.AUTHORIZED : Status.DENIED);
if (!Strings.isNullOrEmpty(tx.getCallbackUri())) {
String interactRef = UUID.randomUUID().toString();
String hash = Hash.CalculateInteractHash(tx.getClientNonce(), tx.getServerNonce(), interactRef, tx.getHashMethod());
try {
String redirectTo = new URIBuilder(tx.getCallbackUri())
.addParameter("interact", interactRef)
.addParameter("hash", hash).build().toString();
tx.setInteractionRef(interactRef);
txService.save(tx);
return new ModelAndView(new RedirectView(redirectTo));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
} else {
// no callback, show completion page
// pre-process the scopes
Set<SystemScope> scopes = scopeService.fromStrings(tx.getScope());
Set<SystemScope> sortedScopes = new LinkedHashSet<>(scopes.size());
Set<SystemScope> systemScopes = scopeService.getAll();
// sort scopes for display based on the inherent order of system scopes
for (SystemScope s : systemScopes) {
if (scopes.contains(s)) {
sortedScopes.add(s);
}
}
// add in any scopes that aren't system scopes to the end of the list
sortedScopes.addAll(Sets.difference(scopes, systemScopes));
m.addAttribute("scopes", sortedScopes);
m.addAttribute("approved", true);
txService.save(tx);
// TODO: we are re-using the device approval page here
return new ModelAndView("deviceApproved");
}
}
}

View File

@ -0,0 +1,343 @@
package org.mitre.xyz;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.ParseException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import javax.servlet.http.HttpServletRequest;
import org.mitre.oauth2.model.AuthenticationHolderEntity;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.OAuth2TokenEntityService;
import org.mitre.oauth2.service.SystemScopeService;
import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
import org.mitre.openid.connect.view.HttpCodeView;
import org.mitre.openid.connect.view.JsonEntityView;
import org.mitre.openid.connect.view.JsonErrorView;
import org.mitre.xyz.Hash.Method;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObject;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.util.Base64URL;
/**
* @author jricher
*
*/
@Controller
public class TxEndpoint {
public enum Status {
NEW, // newly created transaction, nothing's been done to it yet
ISSUED, // an access token has been issued
AUTHORIZED, // the user has authorized but a token has not been issued yet
WAITING, // we are waiting for the user
DENIED; // the user denied the transaction
}
@Autowired
private ClientDetailsEntityService clientService;
@Autowired
private TxService txService;
@Autowired
private SystemScopeService scopeService;
@Autowired
private OAuth2TokenEntityService tokenService;
@Autowired
private ConfigurationPropertiesBean config;
@RequestMapping(consumes = MediaType.APPLICATION_JSON_VALUE,
produces = MediaType.APPLICATION_JSON_VALUE,
path = "/transaction",
method = RequestMethod.POST)
public String transaction(@RequestBody String incoming,
@RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String auth,
@RequestHeader(name = "Signature", required = false) String signature,
@RequestHeader(name = "Digest", required = false) String digest,
@RequestHeader(name = "Detached-JWS", required = false) String jwsd,
@RequestHeader(name = "DPoP", required = false) String dpop,
@RequestHeader(name = "PoP", required = false) String oauthPop,
Model m,
HttpServletRequest req) {
JsonObject json = JsonParser.parseString(incoming).getAsJsonObject();
TxEntity tx;
if (json.has("handle")) {
// it's a handle to an existing transaction, load it up and wire things in
tx = txService.loadByHandle(json.get("handle").getAsString());
} else {
// otherwise we build one from the parts
tx = new TxEntity();
// client ID is passed in as the handle for the key object
// we don't support ephemeral keys (yet)
String clientId = json.get("keys").getAsString();
// first, load the client
ClientDetailsEntity client = clientService.loadClientByClientId(clientId);
if (client == null) {
m.addAttribute(JsonErrorView.ERROR, "unknown_key");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented does not match a client.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return JsonErrorView.VIEWNAME;
}
tx.setClient(client);
// scopes are passed in as handles for the resources
JsonArray resources = json.get("resources").getAsJsonArray();
Set<String> scopes = StreamSupport.stream(resources.spliterator(), false)
.filter( e -> e.isJsonPrimitive() ) // filter out anything that's not a handle
.map( e -> e.getAsString() )
.collect(Collectors.toSet());
tx.setScope(scopes);
tx.setStatus(Status.NEW);
}
// process transaction
// check signatures
// get the only key
if (tx.getClient().getJwks() == null || tx.getClient().getJwks().getKeys().size() == 0) {
m.addAttribute(JsonErrorView.ERROR, "unknown_key");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented does not map to a key.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return JsonErrorView.VIEWNAME;
}
JWKSet clientJwks = tx.getClient().getJwks();
if (clientJwks.getKeys().size() != 1) {
m.addAttribute(JsonErrorView.ERROR, "unknown_key");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented maps to multiple keys.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return JsonErrorView.VIEWNAME;
}
// TODO: this doesn't allow for multiple keys to be selected, this could be a client property
// TODO: this doesn't allow for jwks_uri loaded keys
JWK clientJwk = clientJwks.getKeys().get(0);
// check the signature on the incoming request
// TODO: make this configurable on clients, for now assume JWSD
checkDetachedJws(jwsd, incoming, clientJwk);
// process the transaction based on its current state
switch (tx.getStatus()) {
case NEW:
// now make sure the client is asking for scopes that it's allowed to
if (!scopeService.scopesMatch(tx.getClient().getScope(), tx.getScope())) {
m.addAttribute(JsonErrorView.ERROR, "resource_not_allowed");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The client requested resources it does not have access to.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return JsonErrorView.VIEWNAME;
}
OAuth2Request o2r = new OAuth2Request(
Collections.emptyMap(), tx.getClient().getClientId(),
tx.getClient().getAuthorities(), tx.getStatus().equals(Status.AUTHORIZED),
tx.getScope(), null, tx.getCallbackUri(), null, null);
OAuth2Authentication o2a = new OAuth2Authentication(o2r, null);
AuthenticationHolderEntity ah = new AuthenticationHolderEntity();
ah.setAuthentication(o2a);
ah.setApproved(false);
tx.setAuthenticationHolder(ah);
// look back at the request to process the interaction parameters
JsonObject interact = json.get("interact").getAsJsonObject();
if (interact == null) {
if (tx.getClient().getGrantTypes().contains("client_credentials")) {
// TODO client can do credentials-only grant, issue a token
}
} else {
Map<String, Object> map = new HashMap<>();
// we support "redirect" and "callback" here
if (interact.has("redirect")) {
// generate an interaction URL
String interactPage = UUID.randomUUID().toString();
tx.setInteraction(interactPage);
String interactUrl = config.getIssuer() + "interact/" + interactPage;
map.put("interaction_url", interactUrl);
}
if (interact.has("callback")) {
JsonObject callback = interact.get("callback").getAsJsonObject();
String callbackString = callback.get("uri").getAsString();
Path callbackPath = Paths.get(URI.create(callbackString).getPath());
// we do sub-path matching for the callback
// FIXME: this is a really simplistic filter that definitely has holes in it
boolean callbackMatches = tx.getClient().getRedirectUris().stream()
.filter(s -> callbackString.startsWith(s))
.map(URI::create)
.map(URI::getPath)
.map(Paths::get)
.anyMatch(path ->
callbackPath.startsWith(path)
);
if (!callbackMatches) {
m.addAttribute(JsonErrorView.ERROR, "invalid_callback_uri");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The client presented a callback URI that did not match one registered.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
return JsonErrorView.VIEWNAME;
}
tx.setCallbackUri(callbackString);
tx.setClientNonce(callback.get("nonce").getAsString());
if (callback.has("hash_method")) {
tx.setHashMethod(Hash.Method.fromJson(callback.get("hash_method").getAsString()));
} else {
tx.setHashMethod(Method.SHA3);
}
String serverNonce = UUID.randomUUID().toString();
tx.setServerNonce(serverNonce);
map.put("server_nonce", serverNonce);
}
// rotate the handle
String handle = UUID.randomUUID().toString();
tx.setHandle(handle);
Map<String, String> h = new HashMap<>();
h.put("value", handle);
h.put("presentation", "bearer");
map.put("handle", h);
txService.save(tx);
m.addAttribute(JsonEntityView.ENTITY, map);
return JsonEntityView.VIEWNAME;
}
break;
case AUTHORIZED:
Map<String, Object> map = new HashMap<>();
OAuth2Authentication storedAuth = tx.getAuthenticationHolder().getAuthentication();
OAuth2AccessToken accessToken = tokenService.createAccessToken(storedAuth);
Map<String, String> at = new HashMap<>();
at.put("value", accessToken.getValue());
at.put("presentation", accessToken.getTokenType().toLowerCase());
at.put("expiration", accessToken.getExpiration().toInstant().toString());
map.put("access_token", at);
tx.setStatus(Status.ISSUED);
if (accessToken.getAdditionalInformation().containsKey("id_token")) {
// add in the ID token if it's included
Map<String, String> c = new HashMap<>();
c.put("oidc_id_token", (String) accessToken.getAdditionalInformation().get("id_token"));
map.put("claims", c);
// TODO: save the claims request and translate that directly
}
// rotate the handle
String handle = UUID.randomUUID().toString();
tx.setHandle(handle);
Map<String, String> h = new HashMap<>();
h.put("value", handle);
h.put("presentation", "bearer");
map.put("handle", h);
txService.save(tx);
m.addAttribute(JsonEntityView.ENTITY, map);
return JsonEntityView.VIEWNAME;
case DENIED:
break;
case ISSUED:
break;
case WAITING:
break;
default:
break;
}
m.addAttribute(JsonErrorView.ERROR, "transaction_error");
m.addAttribute(JsonErrorView.ERROR_MESSAGE, "There was an error processing the transaction.");
m.addAttribute(HttpCodeView.CODE, HttpStatus.INTERNAL_SERVER_ERROR);
return JsonErrorView.VIEWNAME;
}
private void checkDetachedJws(String jwsd, String requestBody, JWK clientKey) {
try {
Base64URL[] parts = JOSEObject.split(jwsd);
Payload payload = new Payload(requestBody.getBytes());
JWSObject jwsObject = new JWSObject(parts[0], payload, parts[2]);
JWSVerifier verifier = new DefaultJWSVerifierFactory().createJWSVerifier(jwsObject.getHeader(),
((RSAKey)clientKey).toRSAPublicKey());
if (!jwsObject.verify(verifier)) {
throw new RuntimeException("Unable to verify JWS");
}
} catch (ParseException | JOSEException e) {
throw new RuntimeException("Bad JWS", e);
}
}
}