From 80eb57402c0fd9d645b10c7ece1e545911239300 Mon Sep 17 00:00:00 2001 From: Justin Richer Date: Mon, 11 May 2020 14:08:48 -0400 Subject: [PATCH] process transactions and redirect based interactions --- .../src/main/java/org/mitre/xyz/Hash.java | 112 ++++++ .../main/java/org/mitre/xyz/IxEndpoint.java | 247 +++++++++++++ .../main/java/org/mitre/xyz/TxEndpoint.java | 343 ++++++++++++++++++ 3 files changed, 702 insertions(+) create mode 100644 openid-connect-server/src/main/java/org/mitre/xyz/Hash.java create mode 100644 openid-connect-server/src/main/java/org/mitre/xyz/IxEndpoint.java create mode 100644 openid-connect-server/src/main/java/org/mitre/xyz/TxEndpoint.java diff --git a/openid-connect-server/src/main/java/org/mitre/xyz/Hash.java b/openid-connect-server/src/main/java/org/mitre/xyz/Hash.java new file mode 100644 index 000000000..e8443bfd8 --- /dev/null +++ b/openid-connect-server/src/main/java/org/mitre/xyz/Hash.java @@ -0,0 +1,112 @@ +package org.mitre.xyz; + +import java.security.MessageDigest; +import java.util.Base64; +import java.util.function.Function; + +import org.bouncycastle.jcajce.provider.digest.SHA1; +import org.bouncycastle.jcajce.provider.digest.SHA256; +import org.bouncycastle.jcajce.provider.digest.SHA3; +import org.bouncycastle.jcajce.provider.digest.SHA512; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Joiner; + +/** + * @author jricher + * + */ +public abstract class Hash { + + private static final Logger log = LoggerFactory.getLogger(Hash.class); + + public enum Method { + SHA3("sha3", Hash::SHA3_512_encode), + SHA2("sha2", Hash::SHA2_512_encode) + ; + + private String name; + private Function function; + + /** + * @param string + * @param object + */ + Method(String name, Function function) { + this.name = name; + this.function = function; + } + + public String getName() { + return name; + } + public Function getFunction() { + return function; + } + + public static Method fromJson(String key) { + return key == null ? null : + valueOf(key.toUpperCase()); + } + + public String toJson() { + return name().toLowerCase(); + } + + } + + public static String SHA3_512_encode(String input) { + MessageDigest digest = new SHA3.Digest512(); + byte[] output = digest.digest(input.getBytes()); + + byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output); + + return new String(encoded); + + } + + public static String SHA2_512_encode(String input) { + MessageDigest digest = new SHA512.Digest(); + byte[] output = digest.digest(input.getBytes()); + + byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output); + + return new String(encoded); + + } + + public static String CalculateInteractHash(String clientNonce, String serverNonce, String interact, Method method) { + return method.getFunction().apply( + Joiner.on('\n') + .join(clientNonce, + serverNonce, + interact)); + } + + public static String SHA256_encode(String input) { + if (input == null || input.isEmpty()) { + return null; + } + + MessageDigest digest = new SHA256.Digest(); + byte[] output = digest.digest(input.getBytes()); + + byte[] encoded = Base64.getUrlEncoder().withoutPadding().encode(output); + + return new String(encoded); + } + + public static String SHA1_digest(byte[] input) { + if (input == null || input.length == 0) { + return null; + } + + MessageDigest digest = new SHA1.Digest(); + byte[] output = digest.digest(input); + + byte[] encoded = Base64.getEncoder().encode(output); + + return new String(encoded); + } +} diff --git a/openid-connect-server/src/main/java/org/mitre/xyz/IxEndpoint.java b/openid-connect-server/src/main/java/org/mitre/xyz/IxEndpoint.java new file mode 100644 index 000000000..2bd0e3e6e --- /dev/null +++ b/openid-connect-server/src/main/java/org/mitre/xyz/IxEndpoint.java @@ -0,0 +1,247 @@ +package org.mitre.xyz; + +import java.net.URISyntaxException; +import java.util.Date; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +import org.apache.http.client.utils.URIBuilder; +import org.mitre.oauth2.model.AuthenticationHolderEntity; +import org.mitre.oauth2.model.SystemScope; +import org.mitre.oauth2.service.ClientDetailsEntityService; +import org.mitre.oauth2.service.SystemScopeService; +import org.mitre.openid.connect.model.UserInfo; +import org.mitre.openid.connect.service.ScopeClaimTranslationService; +import org.mitre.openid.connect.service.StatsService; +import org.mitre.openid.connect.service.UserInfoService; +import org.mitre.openid.connect.token.TofuUserApprovalHandler; +import org.mitre.openid.connect.view.HttpCodeView; +import org.mitre.xyz.TxEndpoint.Status; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.common.util.OAuth2Utils; +import org.springframework.security.oauth2.provider.AuthorizationRequest; +import org.springframework.security.oauth2.provider.OAuth2Authentication; +import org.springframework.security.oauth2.provider.OAuth2RequestFactory; +import org.springframework.security.oauth2.provider.endpoint.RedirectResolver; +import org.springframework.stereotype.Controller; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.servlet.view.RedirectView; + +import com.google.common.base.Joiner; +import com.google.common.base.Strings; +import com.google.common.collect.Sets; +import com.google.gson.JsonObject; + +/** + * @author jricher + * + */ +@Controller +@RequestMapping("/interact") +public class IxEndpoint { + + @Autowired + private ClientDetailsEntityService clientService; + + @Autowired + private SystemScopeService scopeService; + + @Autowired + private ScopeClaimTranslationService scopeClaimTranslationService; + + @Autowired + private UserInfoService userInfoService; + + @Autowired + private StatsService statsService; + + @Autowired + private RedirectResolver redirectResolver; + + @Autowired + private TxService txService; + + @Autowired + private OAuth2RequestFactory oAuth2RequestFactory; + + @Autowired + private TofuUserApprovalHandler userApprovalHandler; + + + @PreAuthorize("hasRole('ROLE_USER')") + @RequestMapping(method = RequestMethod.GET, path = "/{id}") + public String interact(@PathVariable("id") String id, Model m, Authentication auth) { + + TxEntity tx = txService.loadByInteractUrl(id); + + if (tx == null) { + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return HttpCodeView.VIEWNAME; + } + + m.addAttribute("client", tx.getClient()); + + m.addAttribute("redirect_uri", tx.getCallbackUri()); + + Set scopes = scopeService.fromStrings(tx.getScope()); + + Set sortedScopes = new LinkedHashSet<>(scopes.size()); + Set systemScopes = scopeService.getAll(); + + // sort scopes for display based on the inherent order of system scopes + for (SystemScope s : systemScopes) { + if (scopes.contains(s)) { + sortedScopes.add(s); + } + } + + // add in any scopes that aren't system scopes to the end of the list + sortedScopes.addAll(Sets.difference(scopes, systemScopes)); + + m.addAttribute("scopes", sortedScopes); + + // get the userinfo claims for each scope + UserInfo user = userInfoService.getByUsername(auth.getName()); + Map> claimsForScopes = new HashMap<>(); + if (user != null) { + JsonObject userJson = user.toJson(); + + for (SystemScope systemScope : sortedScopes) { + Map claimValues = new HashMap<>(); + + Set claims = scopeClaimTranslationService.getClaimsForScope(systemScope.getValue()); + for (String claim : claims) { + if (userJson.has(claim) && userJson.get(claim).isJsonPrimitive()) { + // TODO: this skips the address claim + claimValues.put(claim, userJson.get(claim).getAsString()); + } + } + + claimsForScopes.put(systemScope.getValue(), claimValues); + } + } + + m.addAttribute("claims", claimsForScopes); + + // client stats + Integer count = statsService.getCountForClientId(tx.getClient().getClientId()).getApprovedSiteCount(); + m.addAttribute("count", count); + + + // contacts + if (tx.getClient().getContacts() != null) { + String contacts = Joiner.on(", ").join(tx.getClient().getContacts()); + m.addAttribute("contacts", contacts); + } + + // if the client is over a week old and has more than one registration, don't give such a big warning + // instead, tag as "Generally Recognized As Safe" (gras) + Date lastWeek = new Date(System.currentTimeMillis() - (60 * 60 * 24 * 7 * 1000)); + if (count > 1 && tx.getClient().getCreatedAt() != null && tx.getClient().getCreatedAt().before(lastWeek)) { + m.addAttribute("gras", true); + } else { + m.addAttribute("gras", false); + } + + m.addAttribute("form_target", "interact/" + id); + + return "approve"; + + } + + @PreAuthorize("hasRole('ROLE_USER')") + @RequestMapping(method = RequestMethod.POST, path = "/{id}", params = OAuth2Utils.USER_OAUTH_APPROVAL) + public ModelAndView approveOrDeny(@PathVariable("id") String id, + @RequestParam Map approvalParameters, + Model m, Authentication auth) { + + TxEntity tx = txService.loadByInteractUrl(id); + + if (tx == null) { + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + new ModelAndView(HttpCodeView.VIEWNAME); + } + + // FIXME: this is using a simplified constructor + AuthorizationRequest ar = new AuthorizationRequest(tx.getClient().getClientId(), tx.getScope()); + ar.setRedirectUri(tx.getCallbackUri()); + + ar.setApprovalParameters(approvalParameters); + + userApprovalHandler.updateAfterApproval(ar, auth); + boolean approved = userApprovalHandler.isApproved(ar, auth); + ar.setApproved(approved); + + AuthenticationHolderEntity ah = tx.getAuthenticationHolder(); + + OAuth2Authentication o2a = new OAuth2Authentication(ah.getAuthentication().getOAuth2Request(), auth); + + ah.setAuthentication(o2a); + ah.setApproved(approved); + + tx.setAuthenticationHolder(ah); + + tx.setStatus(approved ? Status.AUTHORIZED : Status.DENIED); + + if (!Strings.isNullOrEmpty(tx.getCallbackUri())) { + + String interactRef = UUID.randomUUID().toString(); + + String hash = Hash.CalculateInteractHash(tx.getClientNonce(), tx.getServerNonce(), interactRef, tx.getHashMethod()); + + try { + String redirectTo = new URIBuilder(tx.getCallbackUri()) + .addParameter("interact", interactRef) + .addParameter("hash", hash).build().toString(); + + tx.setInteractionRef(interactRef); + + txService.save(tx); + + return new ModelAndView(new RedirectView(redirectTo)); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + + } else { + // no callback, show completion page + // pre-process the scopes + Set scopes = scopeService.fromStrings(tx.getScope()); + + Set sortedScopes = new LinkedHashSet<>(scopes.size()); + Set systemScopes = scopeService.getAll(); + + // sort scopes for display based on the inherent order of system scopes + for (SystemScope s : systemScopes) { + if (scopes.contains(s)) { + sortedScopes.add(s); + } + } + + // add in any scopes that aren't system scopes to the end of the list + sortedScopes.addAll(Sets.difference(scopes, systemScopes)); + + m.addAttribute("scopes", sortedScopes); + m.addAttribute("approved", true); + + txService.save(tx); + + // TODO: we are re-using the device approval page here + return new ModelAndView("deviceApproved"); + + } + } + + +} diff --git a/openid-connect-server/src/main/java/org/mitre/xyz/TxEndpoint.java b/openid-connect-server/src/main/java/org/mitre/xyz/TxEndpoint.java new file mode 100644 index 000000000..e14792107 --- /dev/null +++ b/openid-connect-server/src/main/java/org/mitre/xyz/TxEndpoint.java @@ -0,0 +1,343 @@ +package org.mitre.xyz; + +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.text.ParseException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import javax.servlet.http.HttpServletRequest; + +import org.mitre.oauth2.model.AuthenticationHolderEntity; +import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.service.ClientDetailsEntityService; +import org.mitre.oauth2.service.OAuth2TokenEntityService; +import org.mitre.oauth2.service.SystemScopeService; +import org.mitre.openid.connect.config.ConfigurationPropertiesBean; +import org.mitre.openid.connect.view.HttpCodeView; +import org.mitre.openid.connect.view.JsonEntityView; +import org.mitre.openid.connect.view.JsonErrorView; +import org.mitre.xyz.Hash.Method; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.common.OAuth2AccessToken; +import org.springframework.security.oauth2.provider.OAuth2Authentication; +import org.springframework.security.oauth2.provider.OAuth2Request; +import org.springframework.stereotype.Controller; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JOSEObject; +import com.nimbusds.jose.JWSObject; +import com.nimbusds.jose.JWSVerifier; +import com.nimbusds.jose.Payload; +import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.util.Base64URL; + +/** + * @author jricher + * + */ +@Controller +public class TxEndpoint { + + public enum Status { + + NEW, // newly created transaction, nothing's been done to it yet + ISSUED, // an access token has been issued + AUTHORIZED, // the user has authorized but a token has not been issued yet + WAITING, // we are waiting for the user + DENIED; // the user denied the transaction + } + + @Autowired + private ClientDetailsEntityService clientService; + + @Autowired + private TxService txService; + + @Autowired + private SystemScopeService scopeService; + + @Autowired + private OAuth2TokenEntityService tokenService; + + @Autowired + private ConfigurationPropertiesBean config; + + @RequestMapping(consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE, + path = "/transaction", + method = RequestMethod.POST) + public String transaction(@RequestBody String incoming, + @RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String auth, + @RequestHeader(name = "Signature", required = false) String signature, + @RequestHeader(name = "Digest", required = false) String digest, + @RequestHeader(name = "Detached-JWS", required = false) String jwsd, + @RequestHeader(name = "DPoP", required = false) String dpop, + @RequestHeader(name = "PoP", required = false) String oauthPop, + Model m, + HttpServletRequest req) { + + JsonObject json = JsonParser.parseString(incoming).getAsJsonObject(); + + TxEntity tx; + + if (json.has("handle")) { + // it's a handle to an existing transaction, load it up and wire things in + tx = txService.loadByHandle(json.get("handle").getAsString()); + } else { + // otherwise we build one from the parts + + tx = new TxEntity(); + // client ID is passed in as the handle for the key object + // we don't support ephemeral keys (yet) + String clientId = json.get("keys").getAsString(); + + // first, load the client + ClientDetailsEntity client = clientService.loadClientByClientId(clientId); + + if (client == null) { + m.addAttribute(JsonErrorView.ERROR, "unknown_key"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented does not match a client."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return JsonErrorView.VIEWNAME; + } + + + tx.setClient(client); + + // scopes are passed in as handles for the resources + JsonArray resources = json.get("resources").getAsJsonArray(); + Set scopes = StreamSupport.stream(resources.spliterator(), false) + .filter( e -> e.isJsonPrimitive() ) // filter out anything that's not a handle + .map( e -> e.getAsString() ) + .collect(Collectors.toSet()); + tx.setScope(scopes); + + tx.setStatus(Status.NEW); + } + + // process transaction + + // check signatures + + // get the only key + if (tx.getClient().getJwks() == null || tx.getClient().getJwks().getKeys().size() == 0) { + m.addAttribute(JsonErrorView.ERROR, "unknown_key"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented does not map to a key."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return JsonErrorView.VIEWNAME; + } + + JWKSet clientJwks = tx.getClient().getJwks(); + + if (clientJwks.getKeys().size() != 1) { + m.addAttribute(JsonErrorView.ERROR, "unknown_key"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The key handle presented maps to multiple keys."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return JsonErrorView.VIEWNAME; + } + + // TODO: this doesn't allow for multiple keys to be selected, this could be a client property + // TODO: this doesn't allow for jwks_uri loaded keys + JWK clientJwk = clientJwks.getKeys().get(0); + + // check the signature on the incoming request + // TODO: make this configurable on clients, for now assume JWSD + checkDetachedJws(jwsd, incoming, clientJwk); + + // process the transaction based on its current state + switch (tx.getStatus()) { + case NEW: + // now make sure the client is asking for scopes that it's allowed to + if (!scopeService.scopesMatch(tx.getClient().getScope(), tx.getScope())) { + m.addAttribute(JsonErrorView.ERROR, "resource_not_allowed"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The client requested resources it does not have access to."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return JsonErrorView.VIEWNAME; + } + + OAuth2Request o2r = new OAuth2Request( + Collections.emptyMap(), tx.getClient().getClientId(), + tx.getClient().getAuthorities(), tx.getStatus().equals(Status.AUTHORIZED), + tx.getScope(), null, tx.getCallbackUri(), null, null); + + OAuth2Authentication o2a = new OAuth2Authentication(o2r, null); + + AuthenticationHolderEntity ah = new AuthenticationHolderEntity(); + ah.setAuthentication(o2a); + ah.setApproved(false); + + tx.setAuthenticationHolder(ah); + + // look back at the request to process the interaction parameters + JsonObject interact = json.get("interact").getAsJsonObject(); + if (interact == null) { + if (tx.getClient().getGrantTypes().contains("client_credentials")) { + // TODO client can do credentials-only grant, issue a token + + } + } else { + Map map = new HashMap<>(); + + // we support "redirect" and "callback" here + if (interact.has("redirect")) { + // generate an interaction URL + String interactPage = UUID.randomUUID().toString(); + tx.setInteraction(interactPage); + + String interactUrl = config.getIssuer() + "interact/" + interactPage; + map.put("interaction_url", interactUrl); + } + + if (interact.has("callback")) { + JsonObject callback = interact.get("callback").getAsJsonObject(); + + String callbackString = callback.get("uri").getAsString(); + Path callbackPath = Paths.get(URI.create(callbackString).getPath()); + + // we do sub-path matching for the callback + // FIXME: this is a really simplistic filter that definitely has holes in it + boolean callbackMatches = tx.getClient().getRedirectUris().stream() + .filter(s -> callbackString.startsWith(s)) + .map(URI::create) + .map(URI::getPath) + .map(Paths::get) + .anyMatch(path -> + callbackPath.startsWith(path) + ); + + if (!callbackMatches) { + m.addAttribute(JsonErrorView.ERROR, "invalid_callback_uri"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "The client presented a callback URI that did not match one registered."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.BAD_REQUEST); + return JsonErrorView.VIEWNAME; + } + + tx.setCallbackUri(callbackString); + + tx.setClientNonce(callback.get("nonce").getAsString()); + + if (callback.has("hash_method")) { + tx.setHashMethod(Hash.Method.fromJson(callback.get("hash_method").getAsString())); + } else { + tx.setHashMethod(Method.SHA3); + } + + String serverNonce = UUID.randomUUID().toString(); + tx.setServerNonce(serverNonce); + + map.put("server_nonce", serverNonce); + } + + // rotate the handle + String handle = UUID.randomUUID().toString(); + tx.setHandle(handle); + + Map h = new HashMap<>(); + h.put("value", handle); + h.put("presentation", "bearer"); + map.put("handle", h); + + txService.save(tx); + + m.addAttribute(JsonEntityView.ENTITY, map); + return JsonEntityView.VIEWNAME; + } + + break; + case AUTHORIZED: + + Map map = new HashMap<>(); + + OAuth2Authentication storedAuth = tx.getAuthenticationHolder().getAuthentication(); + + OAuth2AccessToken accessToken = tokenService.createAccessToken(storedAuth); + + Map at = new HashMap<>(); + at.put("value", accessToken.getValue()); + at.put("presentation", accessToken.getTokenType().toLowerCase()); + at.put("expiration", accessToken.getExpiration().toInstant().toString()); + map.put("access_token", at); + + tx.setStatus(Status.ISSUED); + + if (accessToken.getAdditionalInformation().containsKey("id_token")) { + // add in the ID token if it's included + Map c = new HashMap<>(); + c.put("oidc_id_token", (String) accessToken.getAdditionalInformation().get("id_token")); + map.put("claims", c); + + // TODO: save the claims request and translate that directly + } + + // rotate the handle + String handle = UUID.randomUUID().toString(); + tx.setHandle(handle); + + Map h = new HashMap<>(); + h.put("value", handle); + h.put("presentation", "bearer"); + map.put("handle", h); + + txService.save(tx); + m.addAttribute(JsonEntityView.ENTITY, map); + return JsonEntityView.VIEWNAME; + case DENIED: + break; + case ISSUED: + break; + case WAITING: + break; + default: + break; + + } + + m.addAttribute(JsonErrorView.ERROR, "transaction_error"); + m.addAttribute(JsonErrorView.ERROR_MESSAGE, "There was an error processing the transaction."); + m.addAttribute(HttpCodeView.CODE, HttpStatus.INTERNAL_SERVER_ERROR); + return JsonErrorView.VIEWNAME; + } + + + private void checkDetachedJws(String jwsd, String requestBody, JWK clientKey) { + try { + + Base64URL[] parts = JOSEObject.split(jwsd); + Payload payload = new Payload(requestBody.getBytes()); + + JWSObject jwsObject = new JWSObject(parts[0], payload, parts[2]); + + JWSVerifier verifier = new DefaultJWSVerifierFactory().createJWSVerifier(jwsObject.getHeader(), + ((RSAKey)clientKey).toRSAPublicKey()); + + if (!jwsObject.verify(verifier)) { + throw new RuntimeException("Unable to verify JWS"); + } + + } catch (ParseException | JOSEException e) { + throw new RuntimeException("Bad JWS", e); + } + } + +}