diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java index 8bf56bbb2..319987b82 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java @@ -35,14 +35,14 @@ public interface StatsService { * * @return */ - public Map calculateSummaryStats(); + public Map getSummaryStats(); /** * Calculate usage count for all clients * * @return a map of id of client object to number of approvals */ - public Map calculateByClientId(); + public Map getByClientId(); /** * Calculate the usage count for a single client @@ -50,6 +50,11 @@ public interface StatsService { * @param id the id of the client to search on * @return */ - public Integer countForClientId(Long id); + public Integer getCountForClientId(Long id); + + /** + * Trigger the stats to be recalculated upon next update. + */ + public void resetCache(); } diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java index 3d538cead..9a7219f73 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java @@ -37,6 +37,7 @@ import org.mitre.oauth2.service.SystemScopeService; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.service.ApprovedSiteService; import org.mitre.openid.connect.service.BlacklistedSiteService; +import org.mitre.openid.connect.service.StatsService; import org.mitre.openid.connect.service.WhitelistedSiteService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -76,6 +77,9 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt @Autowired private SystemScopeService scopeService; + + @Autowired + private StatsService statsService; // map of sector URI -> list of redirect URIs private LoadingCache> sectorRedirects = CacheBuilder.newBuilder() @@ -136,7 +140,11 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt // make sure a client doesn't get any special system scopes client.setScope(scopeService.removeRestrictedScopes(client.getScope())); - return clientRepository.saveClient(client); + ClientDetailsEntity c = clientRepository.saveClient(client); + + statsService.resetCache(); + + return c; } /** @@ -192,6 +200,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt // take care of the client itself clientRepository.deleteClient(client); + statsService.resetCache(); + } /** diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java b/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java index 6021332db..a47781d3f 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java @@ -171,7 +171,7 @@ public class OAuthConfirmationController { model.put("claims", claimsForScopes); // client stats - Integer count = statsService.countForClientId(client.getId()); + Integer count = statsService.getCountForClientId(client.getId()); model.put("count", count); diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java index 358a1dcc8..c313f9e53 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java @@ -26,6 +26,7 @@ import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.repository.ApprovedSiteRepository; import org.mitre.openid.connect.service.ApprovedSiteService; +import org.mitre.openid.connect.service.StatsService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -53,6 +54,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { @Autowired private OAuth2TokenRepository tokenRepository; + + @Autowired + private StatsService statsService; @Override public Collection getAll() { @@ -62,7 +66,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { @Override @Transactional public ApprovedSite save(ApprovedSite approvedSite) { - return approvedSiteRepository.save(approvedSite); + ApprovedSite a = approvedSiteRepository.save(approvedSite); + statsService.resetCache(); + return a; } @Override @@ -85,6 +91,8 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { } approvedSiteRepository.remove(approvedSite); + + statsService.resetCache(); } @Override @@ -140,7 +148,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { Collection approvedSites = approvedSiteRepository.getByClientId(client.getClientId()); if (approvedSites != null) { for (ApprovedSite approvedSite : approvedSites) { - approvedSiteRepository.remove(approvedSite); + remove(approvedSite); } } } @@ -153,7 +161,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { Collection expiredSites = getExpired(); if (expiredSites != null) { for (ApprovedSite expired : expiredSites) { - approvedSiteRepository.remove(expired); + remove(expired); } } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java index fc7e0e18a..d9dcc3173 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java @@ -53,25 +53,32 @@ public class DefaultStatsService implements StatsService { private ClientDetailsEntityService clientService; // stats cache - private Supplier> summaryCache = Suppliers.memoizeWithExpiration(new Supplier>() { - @Override - public Map get() { - return computeSummaryStats(); - } + private Supplier> summaryCache = createSummaryCache(); + + private Supplier> createSummaryCache() { + return Suppliers.memoizeWithExpiration(new Supplier>() { + @Override + public Map get() { + return computeSummaryStats(); + } + + }, 10, TimeUnit.MINUTES); + } - }, 10, TimeUnit.MINUTES); - - private Supplier> byClientIdCache = Suppliers.memoizeWithExpiration(new Supplier>() { - - @Override - public Map get() { - return computeByClientId(); - } - - }, 10, TimeUnit.MINUTES); + private Supplier> byClientIdCache = createByClientIdCache(); + + private Supplier> createByClientIdCache() { + return Suppliers.memoizeWithExpiration(new Supplier>() { + @Override + public Map get() { + return computeByClientId(); + } + + }, 10, TimeUnit.MINUTES); + } @Override - public Map calculateSummaryStats() { + public Map getSummaryStats() { return summaryCache.get(); } @@ -100,7 +107,7 @@ public class DefaultStatsService implements StatsService { * @see org.mitre.openid.connect.service.StatsService#calculateByClientId() */ @Override - public Map calculateByClientId() { + public Map getByClientId() { return byClientIdCache.get(); } @@ -126,9 +133,9 @@ public class DefaultStatsService implements StatsService { * @see org.mitre.openid.connect.service.StatsService#countForClientId(java.lang.String) */ @Override - public Integer countForClientId(Long id) { + public Integer getCountForClientId(Long id) { - Map counts = calculateByClientId(); + Map counts = getByClientId(); return counts.get(id); } @@ -147,4 +154,13 @@ public class DefaultStatsService implements StatsService { return counts; } + /** + * Reset both stats caches on a trigger (before the timer runs out). Resets the timers. + */ + @Override + public void resetCache() { + summaryCache = createSummaryCache(); + byClientIdCache = createByClientIdCache(); + } + } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java index edc9a07c7..78f1285f5 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java @@ -38,7 +38,7 @@ public class ManagerController { @RequestMapping({"", "home", "index"}) public String showHomePage(ModelMap m) { - Map summary = statsService.calculateSummaryStats(); + Map summary = statsService.getSummaryStats(); m.put("statsSummary", summary); return "home"; @@ -53,7 +53,7 @@ public class ManagerController { @RequestMapping({"stats", "stats/"}) public String showStatsPage(ModelMap m) { - Map summary = statsService.calculateSummaryStats(); + Map summary = statsService.getSummaryStats(); m.put("statsSummary", summary); return "stats"; diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java index 5665d9afd..8de307ea8 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java @@ -37,7 +37,7 @@ public class StatsAPI { @RequestMapping(value = "summary", produces = "application/json") public String statsSummary(ModelMap m) { - Map e = statsService.calculateSummaryStats(); + Map e = statsService.getSummaryStats(); m.put("entity", e); @@ -47,7 +47,7 @@ public class StatsAPI { @RequestMapping(value = "byclientid", produces = "application/json") public String statsByClient(ModelMap m) { - Map e = statsService.calculateByClientId(); + Map e = statsService.getByClientId(); m.put("entity", e); @@ -56,7 +56,7 @@ public class StatsAPI { @RequestMapping(value = "byclientid/{id}", produces = "application/json") public String statsByClientId(@PathVariable("id") Long id, ModelMap m) { - Integer e = statsService.countForClientId(id); + Integer e = statsService.getCountForClientId(id); m.put("entity", e); diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java index d75f772e0..8e0338f53 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java @@ -59,6 +59,8 @@ public class TestDefaultStatsService { private ApprovedSite ap2 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap3 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap4 = Mockito.mock(ApprovedSite.class); + private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class); + private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class); private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class); private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class); @@ -95,6 +97,12 @@ public class TestDefaultStatsService { Mockito.when(ap4.getUserId()).thenReturn(userId2); Mockito.when(ap4.getClientId()).thenReturn(clientId3); + Mockito.when(ap5.getUserId()).thenReturn(userId2); + Mockito.when(ap5.getClientId()).thenReturn(clientId1); + + Mockito.when(ap6.getUserId()).thenReturn(userId1); + Mockito.when(ap6.getClientId()).thenReturn(clientId4); + Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4)); Mockito.when(client1.getId()).thenReturn(1L); @@ -114,7 +122,7 @@ public class TestDefaultStatsService { Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet()); - Map stats = service.calculateSummaryStats(); + Map stats = service.getSummaryStats(); assertThat(stats.get("approvalCount"), is(0)); assertThat(stats.get("userCount"), is(0)); @@ -123,7 +131,7 @@ public class TestDefaultStatsService { @Test public void calculateSummaryStats() { - Map stats = service.calculateSummaryStats(); + Map stats = service.getSummaryStats(); assertThat(stats.get("approvalCount"), is(4)); assertThat(stats.get("userCount"), is(2)); @@ -135,7 +143,7 @@ public class TestDefaultStatsService { Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet()); - Map stats = service.calculateByClientId(); + Map stats = service.getByClientId(); assertThat(stats.get(1L), is(0)); assertThat(stats.get(2L), is(0)); @@ -146,7 +154,7 @@ public class TestDefaultStatsService { @Test public void calculateByClientId() { - Map stats = service.calculateByClientId(); + Map stats = service.getByClientId(); assertThat(stats.get(1L), is(2)); assertThat(stats.get(2L), is(1)); @@ -157,9 +165,38 @@ public class TestDefaultStatsService { @Test public void countForClientId() { - assertThat(service.countForClientId(1L), is(2)); - assertThat(service.countForClientId(2L), is(1)); - assertThat(service.countForClientId(3L), is(1)); - assertThat(service.countForClientId(4L), is(0)); + assertThat(service.getCountForClientId(1L), is(2)); + assertThat(service.getCountForClientId(2L), is(1)); + assertThat(service.getCountForClientId(3L), is(1)); + assertThat(service.getCountForClientId(4L), is(0)); + } + + @Test + public void cacheAndReset() { + + Map stats = service.getSummaryStats(); + + assertThat(stats.get("approvalCount"), is(4)); + assertThat(stats.get("userCount"), is(2)); + assertThat(stats.get("clientCount"), is(3)); + + Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6)); + + Map stats2 = service.getSummaryStats(); + + // cache should remain the same due to memoized functions + assertThat(stats2.get("approvalCount"), is(4)); + assertThat(stats2.get("userCount"), is(2)); + assertThat(stats2.get("clientCount"), is(3)); + + // reset the cache and make sure the count goes up + service.resetCache(); + + Map stats3 = service.getSummaryStats(); + + assertThat(stats3.get("approvalCount"), is(6)); + assertThat(stats3.get("userCount"), is(2)); + assertThat(stats3.get("clientCount"), is(4)); + } }