diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java index 94c180a1a..d1f896ba3 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java @@ -35,14 +35,14 @@ public interface StatsService { * * @return */ - public Map calculateSummaryStats(); + public Map getSummaryStats(); /** * Calculate usage count for all clients * * @return a map of id of client object to number of approvals */ - public Map calculateByClientId(); + public Map getByClientId(); /** * Calculate the usage count for a single client @@ -50,6 +50,11 @@ public interface StatsService { * @param id the id of the client to search on * @return */ - public Integer countForClientId(Long id); + public Integer getCountForClientId(Long id); + + /** + * Trigger the stats to be recalculated upon next update. + */ + public void resetCache(); } diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java index 1aac429e0..f8c77eac9 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java @@ -30,6 +30,7 @@ import org.mitre.oauth2.service.ClientDetailsEntityService; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.service.ApprovedSiteService; import org.mitre.openid.connect.service.BlacklistedSiteService; +import org.mitre.openid.connect.service.StatsService; import org.mitre.openid.connect.service.WhitelistedSiteService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.oauth2.common.exceptions.InvalidClientException; @@ -56,6 +57,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt @Autowired private BlacklistedSiteService blacklistedSiteService; + @Autowired + private StatsService statsService; @Override public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) { @@ -87,7 +90,11 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt // timestamp this to right now client.setCreatedAt(new Date()); - return clientRepository.saveClient(client); + ClientDetailsEntity c = clientRepository.saveClient(client); + + statsService.resetCache(); + + return c; } /** @@ -143,6 +150,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt // take care of the client itself clientRepository.deleteClient(client); + statsService.resetCache(); + } /** diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java index f3a58965e..e19b9423a 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java @@ -20,10 +20,12 @@ import java.util.Collection; import java.util.Date; import java.util.Set; +import org.mitre.oauth2.repository.OAuth2TokenRepository; import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.repository.ApprovedSiteRepository; import org.mitre.openid.connect.service.ApprovedSiteService; +import org.mitre.openid.connect.service.StatsService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -49,6 +51,12 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { @Autowired private ApprovedSiteRepository approvedSiteRepository; + @Autowired + private OAuth2TokenRepository tokenRepository; + + @Autowired + private StatsService statsService; + @Override public Collection getAll() { return approvedSiteRepository.getAll(); @@ -57,7 +65,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { @Override @Transactional public ApprovedSite save(ApprovedSite approvedSite) { - return approvedSiteRepository.save(approvedSite); + ApprovedSite a = approvedSiteRepository.save(approvedSite); + statsService.resetCache(); + return a; } @Override @@ -69,6 +79,8 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { @Transactional public void remove(ApprovedSite approvedSite) { approvedSiteRepository.remove(approvedSite); + + statsService.resetCache(); } @Override @@ -124,7 +136,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { Collection approvedSites = approvedSiteRepository.getByClientId(client.getClientId()); if (approvedSites != null) { for (ApprovedSite approvedSite : approvedSites) { - approvedSiteRepository.remove(approvedSite); + remove(approvedSite); } } } @@ -137,7 +149,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { Collection expiredSites = getExpired(); if (expiredSites != null) { for (ApprovedSite expired : expiredSites) { - approvedSiteRepository.remove(expired); + remove(expired); } } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java index c9aa2007d..f5c91c2cc 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java @@ -53,25 +53,32 @@ public class DefaultStatsService implements StatsService { private ClientDetailsEntityService clientService; // stats cache - private Supplier> summaryCache = Suppliers.memoizeWithExpiration(new Supplier>() { - @Override - public Map get() { - return computeSummaryStats(); - } + private Supplier> summaryCache = createSummaryCache(); + + private Supplier> createSummaryCache() { + return Suppliers.memoizeWithExpiration(new Supplier>() { + @Override + public Map get() { + return computeSummaryStats(); + } + + }, 10, TimeUnit.MINUTES); + } - }, 10, TimeUnit.MINUTES); - - private Supplier> byClientIdCache = Suppliers.memoizeWithExpiration(new Supplier>() { - - @Override - public Map get() { - return computeByClientId(); - } - - }, 10, TimeUnit.MINUTES); + private Supplier> byClientIdCache = createByClientIdCache(); + + private Supplier> createByClientIdCache() { + return Suppliers.memoizeWithExpiration(new Supplier>() { + @Override + public Map get() { + return computeByClientId(); + } + + }, 10, TimeUnit.MINUTES); + } @Override - public Map calculateSummaryStats() { + public Map getSummaryStats() { return summaryCache.get(); } @@ -100,7 +107,7 @@ public class DefaultStatsService implements StatsService { * @see org.mitre.openid.connect.service.StatsService#calculateByClientId() */ @Override - public Map calculateByClientId() { + public Map getByClientId() { return byClientIdCache.get(); } @@ -126,9 +133,9 @@ public class DefaultStatsService implements StatsService { * @see org.mitre.openid.connect.service.StatsService#countForClientId(java.lang.String) */ @Override - public Integer countForClientId(Long id) { + public Integer getCountForClientId(Long id) { - Map counts = calculateByClientId(); + Map counts = getByClientId(); return counts.get(id); } @@ -147,4 +154,13 @@ public class DefaultStatsService implements StatsService { return counts; } + /** + * Reset both stats caches on a trigger (before the timer runs out). Resets the timers. + */ + @Override + public void resetCache() { + summaryCache = createSummaryCache(); + byClientIdCache = createByClientIdCache(); + } + } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java index c7f4c06c7..170e8ebfb 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java @@ -38,7 +38,7 @@ public class ManagerController { @RequestMapping({"", "home", "index"}) public String showHomePage(ModelMap m) { - Map summary = statsService.calculateSummaryStats(); + Map summary = statsService.getSummaryStats(); m.put("statsSummary", summary); return "home"; @@ -53,7 +53,7 @@ public class ManagerController { @RequestMapping({"stats", "stats/"}) public String showStatsPage(ModelMap m) { - Map summary = statsService.calculateSummaryStats(); + Map summary = statsService.getSummaryStats(); m.put("statsSummary", summary); return "stats"; diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java index 47b37127f..aa133ed50 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java @@ -37,7 +37,7 @@ public class StatsAPI { @RequestMapping(value = "summary", produces = "application/json") public String statsSummary(ModelMap m) { - Map e = statsService.calculateSummaryStats(); + Map e = statsService.getSummaryStats(); m.put("entity", e); @@ -47,7 +47,7 @@ public class StatsAPI { @RequestMapping(value = "byclientid", produces = "application/json") public String statsByClient(ModelMap m) { - Map e = statsService.calculateByClientId(); + Map e = statsService.getByClientId(); m.put("entity", e); @@ -56,7 +56,7 @@ public class StatsAPI { @RequestMapping(value = "byclientid/{id}", produces = "application/json") public String statsByClientId(@PathVariable("id") Long id, ModelMap m) { - Integer e = statsService.countForClientId(id); + Integer e = statsService.getCountForClientId(id); m.put("entity", e); diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java index 61c3a9f64..d19146cc9 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java @@ -59,6 +59,8 @@ public class TestDefaultStatsService { private ApprovedSite ap2 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap3 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap4 = Mockito.mock(ApprovedSite.class); + private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class); + private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class); private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class); private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class); @@ -95,6 +97,12 @@ public class TestDefaultStatsService { Mockito.when(ap4.getUserId()).thenReturn(userId2); Mockito.when(ap4.getClientId()).thenReturn(clientId3); + Mockito.when(ap5.getUserId()).thenReturn(userId2); + Mockito.when(ap5.getClientId()).thenReturn(clientId1); + + Mockito.when(ap6.getUserId()).thenReturn(userId1); + Mockito.when(ap6.getClientId()).thenReturn(clientId4); + Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4)); Mockito.when(client1.getId()).thenReturn(1L); @@ -114,7 +122,7 @@ public class TestDefaultStatsService { Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet()); - Map stats = service.calculateSummaryStats(); + Map stats = service.getSummaryStats(); assertThat(stats.get("approvalCount"), is(0)); assertThat(stats.get("userCount"), is(0)); @@ -123,7 +131,7 @@ public class TestDefaultStatsService { @Test public void calculateSummaryStats() { - Map stats = service.calculateSummaryStats(); + Map stats = service.getSummaryStats(); assertThat(stats.get("approvalCount"), is(4)); assertThat(stats.get("userCount"), is(2)); @@ -135,7 +143,7 @@ public class TestDefaultStatsService { Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet()); - Map stats = service.calculateByClientId(); + Map stats = service.getByClientId(); assertThat(stats.get(1L), is(0)); assertThat(stats.get(2L), is(0)); @@ -146,7 +154,7 @@ public class TestDefaultStatsService { @Test public void calculateByClientId() { - Map stats = service.calculateByClientId(); + Map stats = service.getByClientId(); assertThat(stats.get(1L), is(2)); assertThat(stats.get(2L), is(1)); @@ -157,9 +165,38 @@ public class TestDefaultStatsService { @Test public void countForClientId() { - assertThat(service.countForClientId(1L), is(2)); - assertThat(service.countForClientId(2L), is(1)); - assertThat(service.countForClientId(3L), is(1)); - assertThat(service.countForClientId(4L), is(0)); + assertThat(service.getCountForClientId(1L), is(2)); + assertThat(service.getCountForClientId(2L), is(1)); + assertThat(service.getCountForClientId(3L), is(1)); + assertThat(service.getCountForClientId(4L), is(0)); + } + + @Test + public void cacheAndReset() { + + Map stats = service.getSummaryStats(); + + assertThat(stats.get("approvalCount"), is(4)); + assertThat(stats.get("userCount"), is(2)); + assertThat(stats.get("clientCount"), is(3)); + + Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6)); + + Map stats2 = service.getSummaryStats(); + + // cache should remain the same due to memoized functions + assertThat(stats2.get("approvalCount"), is(4)); + assertThat(stats2.get("userCount"), is(2)); + assertThat(stats2.get("clientCount"), is(3)); + + // reset the cache and make sure the count goes up + service.resetCache(); + + Map stats3 = service.getSummaryStats(); + + assertThat(stats3.get("approvalCount"), is(6)); + assertThat(stats3.get("userCount"), is(2)); + assertThat(stats3.get("clientCount"), is(4)); + } }